import asyncio
import inspect

from typing import TYPE_CHECKING, Dict, List

from kombu import Queue
from kombu.messaging import (
    Producer as SyncProducer,
    Consumer as SyncConsumer,
    is_connection,
    maybe_channel
)
from kombu.utils.functional import ChannelPromise

from celery.aio.common import maybe_declare
from celery.aio.utils import AsyncChannelPromise


class Producer(SyncProducer):
    """Asyncio variant of :class:`kombu.messaging.Producer`.

    Publishing and declaring are coroutines, and the :attr:`channel` and
    :attr:`connection` properties return awaitables (await them to get the
    concrete channel / connection).
    """

    async def publish(self, body, routing_key=None, delivery_mode=None,
                      mandatory=False, immediate=False, priority=0,
                      content_type=None, content_encoding=None,
                      serializer=None, headers=None, compression=None,
                      exchange=None, retry=False, retry_policy=None,
                      declare=None, expiration=None, timeout=None,
                      **properties):
        """Publish message to the specified exchange.

        Arguments:
            body (Any): Message body.
            routing_key (str): Message routing key.
            delivery_mode (enum): See :attr:`delivery_mode`.
            mandatory (bool): Currently not supported.
            immediate (bool): Currently not supported.
            priority (int): Message priority. A number between 0 and 9.
            content_type (str): Content type. Default is auto-detect.
            content_encoding (str): Content encoding. Default is auto-detect.
            serializer (str): Serializer to use. Default is auto-detect.
            compression (str): Compression method to use.  Default is none.
            headers (Dict): Mapping of arbitrary headers to pass along
                with the message body.
            exchange (kombu.entity.Exchange, str): Override the exchange.
                Note that this exchange must have been declared.
            declare (Sequence[EntityT]): Optional list of required entities
                that must have been declared before publishing the message.
                The entities will be declared using
                :func:`~kombu.common.maybe_declare`.  The sequence passed
                in is copied, never modified.
            retry (bool): Retry publishing, or declaring entities if the
                connection is lost.
            retry_policy (Dict): Retry configuration, this is the keywords
                supported by :meth:`~kombu.Connection.ensure`.
            expiration (float): A TTL in seconds can be specified per message.
                Default is no expiration.
            timeout (float): Set timeout to wait maximum timeout second
                for message to publish.
            **properties (Any): Additional message properties, see AMQP spec.

        Returns:
            Whatever the awaited ``channel.basic_publish`` call returns.
        """
        _publish = self._publish

        # Copy ``declare``: the auto-declared exchange is appended below and
        # must not leak into a list owned by the caller.
        declare = [] if declare is None else list(declare)
        headers = {} if headers is None else headers
        retry_policy = {} if retry_policy is None else retry_policy
        routing_key = self.routing_key if routing_key is None else routing_key
        compression = self.compression if compression is None else compression

        exchange_name, properties['delivery_mode'] = self._delivery_details(
            exchange or self.exchange, delivery_mode,
        )

        if expiration is not None:
            # AMQP expects the per-message TTL as a string of milliseconds.
            properties['expiration'] = str(int(expiration * 1000))

        body, content_type, content_encoding = self._prepare(
            body, serializer, content_type, content_encoding,
            compression, headers)

        if self.auto_declare and self.exchange.name:
            if self.exchange not in declare:
                declare.append(self.exchange)

        if retry:
            conn = await self.connection
            # NOTE(review): kombu's synchronous ``Connection.ensure`` only
            # guards the call to ``_publish``, which merely creates the
            # coroutine; errors raised while it is awaited below are not
            # retried unless this connection's ``ensure`` is async-aware.
            # Confirm against the connection class in use.
            _publish = conn.ensure(self, _publish, **retry_policy)
        return await _publish(
            body, priority, content_type, content_encoding,
            headers, properties, routing_key, mandatory, immediate,
            exchange_name, declare, timeout
        )

    async def _publish(self, body, priority, content_type, content_encoding,
                       headers, properties, routing_key, mandatory,
                       immediate, exchange, declare, timeout=None):
        """Declare required entities, then build and send the message."""
        channel = await self.channel

        # Declare one at a time, in list order, so an entity never races
        # one it may depend on (e.g. a queue bound to an exchange), and a
        # failure does not leave sibling declarations running unobserved.
        for entity in declare or ():
            await self.maybe_declare(entity)

        # Handle autogenerated queue names for reply_to.  Done after the
        # declarations (a declared queue may only get its name then) and
        # before ``prepare_message`` (a transport may copy ``properties``
        # into the message, so later edits to the dict would be lost).
        reply_to = properties.get('reply_to')
        if isinstance(reply_to, Queue):
            properties['reply_to'] = reply_to.name

        message = channel.prepare_message(
            body, priority, content_type,
            content_encoding, headers, properties,
        )
        return await channel.basic_publish(
            message,
            exchange=exchange, routing_key=routing_key,
            mandatory=mandatory, immediate=immediate,
            timeout=timeout
        )

    async def maybe_declare(self, entity, retry=False, **retry_policy):
        """Declare exchange if not already declared during this session.

        Falsy ``entity`` values are ignored and yield ``None``.
        """
        if entity:
            return await maybe_declare(
                entity, await self.channel, retry, **retry_policy)

    async def _get_channel(self):
        """Return the concrete channel, resolving a pending promise first.

        On resolution the exchange is revived on the real channel and the
        ``on_return`` callback (if set) is registered for ``basic_return``.
        """
        channel = self._channel
        if isinstance(channel, AsyncChannelPromise):
            channel = self._channel = await channel()
            self.exchange.revive(channel)
            if self.on_return:
                channel.events['basic_return'].add(self.on_return)
        return channel

    def _set_channel(self, channel):
        """Store ``channel`` (a concrete channel or a promise) as-is."""
        self._channel = channel

    # Reading ``producer.channel`` yields a coroutine; await it.
    channel = property(_get_channel, _set_channel)

    def revive(self, channel):
        """Revive the producer after connection loss.

        Arguments:
            channel: A connection (its ``default_channel`` is wrapped in an
                :class:`AsyncChannelPromise` and the connection remembered),
                a channel promise, or a concrete channel.
        """
        if is_connection(channel):
            connection = channel
            self.__connection__ = connection
            channel = AsyncChannelPromise(connection.default_channel)
        # Check both promise types: the async promise created above must not
        # fall through to the concrete-channel branch, which uses ``.events``.
        if isinstance(channel, (ChannelPromise, AsyncChannelPromise)):
            self._channel = channel
            self.exchange = self.exchange(channel)
        else:
            # Channel already concrete
            self._channel = channel
            if self.on_return:
                self._channel.events['basic_return'].add(self.on_return)
            self.exchange = self.exchange(channel)

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc_info):
        self.release()

    def release(self):
        """No-op; present for API compatibility with kombu's producer."""
        pass

    close = release

    async def _get_connection(self):
        """Return the producer's connection, or ``None`` if none is known.

        Prefers the connection recorded by :meth:`revive`; otherwise uses the
        client of the (resolved) channel's connection.
        """
        if self.__connection__ is not None:
            return self.__connection__
        channel = await self.channel
        try:
            return channel.connection.client
        except AttributeError:
            # No channel yet (``None``) or a channel without a client.
            return None

    @property
    def connection(self):
        """Awaitable resolving to the producer's connection (or ``None``)."""
        return self._get_connection()


class Consumer(SyncConsumer):
    """Asyncio variant of :class:`kombu.messaging.Consumer`.

    The channel is only stored at construction time; await :meth:`init`
    afterwards to revive the consumer (bind queues, declare, apply QoS).
    """
    if TYPE_CHECKING:
        from celery.aio.entity import Queue

        _queues: Dict[str, Queue]
        @property
        def queues(self) -> List[Queue]: ...  # noqa

    def __init__(
        self, channel, queues=None, no_ack=None, auto_declare=None,
        callbacks=None, on_decode_error=None, on_message=None,
        accept=None, prefetch_count=None, tag_prefix=None
    ):
        """Create the consumer; the channel is attached but not revived.

        ``None`` is passed to the base initializer instead of ``channel``
        (presumably so the synchronous base class does no channel work);
        the real channel is assigned afterwards and revived by :meth:`init`.
        """
        super().__init__(
            None, queues, no_ack, auto_declare,
            callbacks, on_decode_error, on_message,
            accept, prefetch_count, tag_prefix
        )
        self.channel = channel

    async def init(self):
        """Revive the consumer on its channel, if one was given."""
        if self.channel:
            await self.revive(self.channel)

    async def revive(self, channel):
        """Re-bind every queue to ``channel`` after (re)connection.

        Clears active consumer tags, re-keys queues by their current name,
        declares them when ``auto_declare`` is set and applies
        ``prefetch_count`` QoS when configured.
        """
        self._active_tags.clear()
        channel = self.channel = maybe_channel(channel)
        # Iterate over a snapshot: the dict is re-keyed inside the loop.
        for qname, queue in list(self._queues.items()):
            # name may have changed after declare
            self._queues.pop(qname, None)
            queue = self._queues[queue.name] = queue(self.channel)
            queue.revive(channel)

        if self.auto_declare:
            await self.declare()

        if self.prefetch_count is not None:
            # ``qos`` is inherited from kombu and returns the channel's
            # ``basic_qos`` result; on an async channel that is an awaitable
            # which must be awaited or the QoS request is never sent.
            result = self.qos(prefetch_count=self.prefetch_count)
            if inspect.isawaitable(result):
                await result

    async def declare(self):
        """Declare every queue, one at a time."""
        # Snapshot so another task mutating ``_queues`` during an ``await``
        # cannot break the iteration.
        for queue in list(self._queues.values()):
            await queue.declare()

    def _basic_consume(self, queue: Queue, consumer_tag=None,
                       no_ack=None, nowait=True):
        """Start consuming ``queue`` unless it already has an active tag.

        Returns:
            The consumer tag for ``queue`` (existing or newly added).
        """
        tag = self._active_tags.get(queue.name)
        if tag is None:
            tag = self._add_tag(queue, consumer_tag)
            # NOTE(review): the result of ``queue.consume`` is not awaited;
            # if the async Queue entity's ``consume`` is a coroutine function
            # this never runs. Confirm against celery.aio.entity.Queue.
            queue.consume(tag, self._receive_callback,
                          no_ack=no_ack, nowait=nowait)
        return tag
