from __future__ import annotations

import asyncio
from asyncio import Future
from typing import AsyncIterable, AsyncIterator, Iterable, TypeAlias, TypeVar

from astream.stream import Stream
from astream.utils import ensure_async_iterator

# Generic item type used by the helpers in this module.
T = TypeVar("T")

# One link of a per-clone linked list of futures: the future resolves to ``(item, next_future)``,
# where ``next_future`` is the link that will carry the following item. ``atee`` builds one such
# chain per clone and terminates it when the source is exhausted.
ItemAndNextFuture: TypeAlias = Future[tuple[T, "ItemAndNextFuture[T]"]]


def atee(source: AsyncIterable[T], n_clones: int) -> tuple[Stream[T], ...]:
    """Create n clones of an async iterable, each receiving every item generated by the source.

    Note that this is not a true copy, as the source iterable is not duplicated. Instead, each
    clone is a separate iterator that receives the same items from the source iterable. The
    elements are not copied, so if a list is yielded from the source iterable, each clone will
    receive the same list object.

    The original iterable should not be iterated over after the clones are created, as this will
    cause the clones to miss items.

    The source is consumed by a background task that is started immediately on the running event
    loop. Objects yielded from the source iterable stay in memory until every clone has advanced
    past them. This can cause memory issues in a long-running program if the clones are not
    consumed at the same rate as the source iterable.

    TODO: Perhaps use weak references to allow the source iterable to be garbage collected.

    If the source raises an exception, every clone re-raises it after yielding the items received
    before the failure. A clone whose consuming task is cancelled while waiting for an item is
    dropped, without affecting the other clones.

    Args:
        source: The source iterable to clone.
        n_clones: The number of clones to create.

    Returns:
        A tuple of ``n_clones`` streams, each receiving the same items from the source iterable.

    Raises:
        RuntimeError: If called while no event loop is running.
    """
    loop = asyncio.get_running_loop()
    # Each link resolves to ``(item, next_link)``, or to ``None`` once the source is exhausted.
    heads: list[ItemAndNextFuture[T]] = [loop.create_future() for _ in range(n_clones)]

    async def _cloned_aiter(
        link: ItemAndNextFuture[T], copier: asyncio.Task[None]
    ) -> AsyncIterator[T]:
        """Yield the items of one clone by walking its chain of futures."""
        # ``copier`` is intentionally unused: holding it in this frame keeps a strong reference to
        # the background task while the clone is alive (the event loop only keeps weak ones).
        while True:
            # Cancellation of the consuming task propagates from here instead of being swallowed.
            resolved = await link
            if resolved is None:
                return
            item, link = resolved
            yield item

    async def _copier() -> None:
        """Pull items from ``source`` and fan each one out to every live clone."""
        links = heads
        try:
            async for item in source:
                next_links: list[ItemAndNextFuture[T]] = []
                for link in links:
                    # A consumer cancelled while awaiting its link cancels that future. Its clone
                    # is finished, so drop it instead of failing with InvalidStateError.
                    if link.cancelled():
                        continue
                    next_link: ItemAndNextFuture[T] = loop.create_future()
                    link.set_result((item, next_link))
                    next_links.append(next_link)
                links = next_links
        except Exception as exc:
            # Deliver the source's failure to every clone instead of leaving them waiting forever.
            for link in links:
                if not link.done():
                    link.set_exception(exc)
        else:
            # Normal exhaustion: the ``None`` marker ends every remaining clone.
            for link in links:
                if not link.done():
                    link.set_result(None)
        finally:
            # Pending links only remain here if the copier itself was cancelled (or hit another
            # BaseException); cancel them so that no clone waits forever.
            for link in links:
                if not link.done():
                    link.cancel()

    copier = loop.create_task(_copier())
    return tuple(Stream(_cloned_aiter(head, copier)) for head in heads)


class ClonableAsyncIterableWrapper(AsyncIterable[T]):
    """Wrap a sync or async iterable so that clones of it can be created on demand.

    Each call to ``aclone`` splits the wrapped iterator in two with ``atee``. One half replaces
    the wrapped iterator and the other half is returned to the caller.
    """

    def __init__(self, source: AsyncIterable[T] | Iterable[T]) -> None:
        """Store ``source`` as the wrapped iterator.

        Args:
            source: A sync or async iterable, passed to ``ensure_async_iterator`` with
                ``to_thread=True``.
        """
        wrapped = ensure_async_iterator(source, to_thread=True)
        self._source = wrapped

    def aclone(self) -> Stream[T]:
        """Create a clone of this async iterable.

        Returns:
            A stream receiving the items this wrapper produces from this point on.
        """
        kept, handed_out = atee(self._source, 2)
        self._source = kept
        return handed_out

    def __aiter__(self) -> Stream[T]:
        """Return a stream over the currently wrapped iterator."""
        current = self._source
        return Stream(current)


# Public API of this module.
__all__ = (
    "atee",
    "ClonableAsyncIterableWrapper",
)


if __name__ == "__main__":
    import asyncio

    from astream.stream_utils import amerge, arange

    async def main() -> None:
        """Demo: merge two clones of a range with the wrapper's own remaining stream."""
        it = ClonableAsyncIterableWrapper(arange(50))
        a = it.aclone()
        # NOTE(review): ``/`` on a Stream presumably maps each item (here doubling it) — confirm.
        b = it.aclone() / (lambda x: x * 2)

        # ``a``, ``b`` and ``it`` each receive every item produced by the wrapped range.
        async for num in amerge(a, b, it):
            print(num)

    asyncio.run(main())

"""
The Python __index__ method returns a lossless integer representation of an object.
It is called when an object is used in a context that requires a true integer, such as
indexing or slicing a sequence, or passing it to operator.index(), bin(), hex(), or oct().
"""
