from __future__ import annotations

from itertools import product
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union, no_type_check
from uuid import UUID, uuid4
from warnings import warn

import numpy as np
from pydantic import Field, PrivateAttr, root_validator, validator

from ._base_model import UseqModel
from ._channel import Channel
from ._mda_event import MDAEvent
from ._position import Position
from ._time import AnyTimePlan, NoT
from ._z import AnyZPlan, NoZ

# Single-character axis keys, used both in `axis_order` strings and as the keys of
# each event's `index` mapping.
TIME, CHANNEL, POSITION, Z = "t", "c", "p", "z"
# Canonical axis order; "".join(INDICES) == "tpcz" is the default `axis_order`.
INDICES = (TIME, POSITION, CHANNEL, Z)

# Sentinel meaning "argument not supplied" (distinct from an explicit None).
Undefined = object()


class MDASequence(UseqModel):
    """A sequence of MDA (Multi-Dimensional Acquisition) events.

    This is the core of the useq library, and is used to define a sequence of events
    to be run on a microscope. It may be constructed manually, or from file (e.g.
    json or yaml).

    The object itself acts as an iterator for `useq.MDAEvent` objects (see Examples).

    Parameters
    ----------
    metadata : Dict[str, Any]
        Arbitrary metadata to associate with the sequence. Defaults to an empty dict.
    axis_order : str
        The order of the axes in the sequence. Must consist of unique characters
        taken from "tpcz" (case-insensitive); axes may be omitted. The default is
        "tpcz".
    stage_positions : Tuple[Position, ...]
        The stage positions to visit. (each with `x`, `y`, `z`, `name`, and `z_plan`,
        all of which are optional).
    channels : Tuple[Channel, ...]
        The channels to acquire. See `Channel`.
    time_plan : AnyTimePlan
        The time plan to follow. One of `TIntervalDuration`, `TIntervalLoops`,
        `TDurationLoops`, `MultiPhaseTimePlan`, or `NoT`.
    z_plan : AnyZPlan
        The z plan to follow. One of `ZTopBottom`, `ZRangeAround`, `ZAboveBelow`,
        `ZRelativePositions`, `ZAbsolutePositions`, or `NoZ`.

    Examples
    --------
    >>> from useq import MDASequence, Position, Channel, TIntervalDuration
    >>> seq = MDASequence(
    ...     time_plan={"interval": 0.1, "loops": 2},
    ...     stage_positions=[(1, 1, 1)],
    ...     z_plan={"range": 3, "step": 1},
    ...     channels=[{"config": "DAPI", "exposure": 1}]
    ... )
    >>> print(seq)
    Multi-Dimensional Acquisition ▶ nt: 2, np: 1, nc: 1, nz: 4

    >>> for event in seq:
    ...     print(event)

    >>> print(seq.yaml())
    channels:
    - config: DAPI
      exposure: 1.0
    stage_positions:
    - x: 1.0
      y: 1.0
      z: 1.0
    time_plan:
      interval: '0:00:00.100000'
      loops: 2
    z_plan:
      range: 3.0
      step: 1.0
    """

    metadata: Dict[str, Any] = Field(default_factory=dict)
    axis_order: str = "".join(INDICES)
    stage_positions: Tuple[Position, ...] = Field(default_factory=tuple)
    channels: Tuple[Channel, ...] = Field(default_factory=tuple)
    time_plan: AnyTimePlan = Field(default_factory=NoT)
    z_plan: AnyZPlan = Field(default_factory=NoZ)

    # Per-instance identifier (a private attribute, not a declared field); a new
    # one is generated every time a sequence is constructed.
    _uid: UUID = PrivateAttr(default_factory=uuid4)
    # Cached event count, filled lazily by __len__.
    _length: Optional[int] = PrivateAttr(default=None)

    @property
    def uid(self) -> UUID:
        """A unique identifier for this sequence."""
        return self._uid

    @no_type_check
    def replace(
        self,
        metadata: Dict[str, Any] = Undefined,
        axis_order: str = Undefined,
        stage_positions: Tuple[Position, ...] = Undefined,
        channels: Tuple[Channel, ...] = Undefined,
        time_plan: AnyTimePlan = Undefined,
        z_plan: AnyZPlan = Undefined,
    ) -> MDASequence:
        """Return a new `MDASequence` replacing specified fields with new values.

        Only arguments that are explicitly passed replace the corresponding field;
        every other field is copied from this sequence. Because a new object is
        constructed, the returned sequence has a different `uid`.
        """
        # At this point `locals()` contains `self` plus every parameter. `self` is
        # not a field and must not be forwarded to the constructor.
        kwargs = {
            k: v for k, v in locals().items() if k != "self" and v is not Undefined
        }
        state = self.dict(exclude={"uid"})
        return type(self)(**{**state, **kwargs})

    def __hash__(self) -> int:
        # NOTE(review): __eq__ compares field values, but the hash is derived from
        # the per-instance uid, so two equal sequences generally hash differently.
        # This breaks the usual hash/eq contract for sets and dict keys. It is kept
        # as-is because callers may rely on identity-like, uid-based hashing.
        return hash(self.uid)

    @validator("z_plan", pre=True)
    def validate_zplan(cls, v: Any) -> Union[dict, NoZ]:
        """Replace a falsy z_plan (e.g. None or an empty dict) with `NoZ()`."""
        return v or NoZ()

    @validator("time_plan", pre=True)
    def validate_time_plan(cls, v: Any) -> Union[dict, NoT]:
        """Wrap a list/tuple as multi-phase plan data; map falsy values to `NoT()`."""
        return {"phases": v} if isinstance(v, (tuple, list)) else v or NoT()

    @validator("stage_positions", pre=True)
    def validate_positions(cls, v: Any) -> list:
        """Convert numpy arrays to a list of positions (1-D = one position, 2-D = rows).

        Anything else, including arrays with more than 2 dimensions, is passed
        through unchanged.
        """
        if isinstance(v, np.ndarray):
            if v.ndim == 1:
                return [v]
            elif v.ndim == 2:
                return list(v)
        return v

    @validator("axis_order", pre=True)
    def validate_axis_order(cls, v: Any) -> str:
        """Lower-case `v` and check it only contains unique characters from "tpcz".

        Raises
        ------
        TypeError
            If `v` is not a string.
        ValueError
            If `v` contains an unknown axis or a duplicate axis.
        """
        if not isinstance(v, str):
            raise TypeError(f"acquisition order must be a string, got {type(v)}")
        order = v.lower()
        extra = {x for x in order if x not in INDICES}
        if extra:
            raise ValueError(
                f"Can only iterate over axes: {INDICES!r}. Got extra: {extra}"
            )
        if len(set(order)) < len(order):
            raise ValueError(f"Duplicate entries found in acquisition order: {order}")

        return order

    @root_validator
    def validate_mda(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Cross-check axis_order against the plans once all fields are validated."""
        # `axis_order` is absent from `values` if its own validator failed.
        if "axis_order" in values:
            values["axis_order"] = cls._check_order(
                values["axis_order"],
                z_plan=values.get("z_plan"),
                stage_positions=values.get("stage_positions", ()),
                channels=values.get("channels", ()),
            )

        return values

    def __eq__(self, other: Any) -> bool:
        """Sequences are equal when all their fields are equal (uid is ignored)."""
        if isinstance(other, MDASequence):
            return self.dict(exclude={"uid"}) == other.dict(exclude={"uid"})
        else:
            return False

    @staticmethod
    def _check_order(
        order: str,
        z_plan: AnyZPlan = None,
        stage_positions: Sequence[Position] = (),
        channels: Sequence[Channel] = (),
    ) -> str:
        """Validate `order` against the plans and return it unchanged.

        Raises
        ------
        ValueError
            If z precedes p in `order` while a global z_plan is set and any stage
            position also defines its own z_plan.

        Warns
        -----
        UserWarning
            If any channel skips frames (`acquire_every > 1`) but c precedes t.
        """
        if (
            Z in order
            and POSITION in order
            and order.index(Z) < order.index(POSITION)
            and z_plan
            and any(p.z_plan for p in stage_positions)
        ):
            raise ValueError(
                f"{Z!r} cannot precede {POSITION!r} in acquisition order if "
                "any position specifies a z_plan"
            )

        if (
            CHANNEL in order
            and TIME in order
            and any(c.acquire_every > 1 for c in channels)
            and order.index(CHANNEL) < order.index(TIME)
        ):
            warn(
                f"Channels with skipped frames detected, but {CHANNEL!r} precedes "
                f"{TIME!r} in the acquisition order: may not yield intended results."
            )

        return order

    def __str__(self) -> str:
        shape = [f"n{axis.lower()}: {size}" for axis, size in self.sizes.items()]
        return "Multi-Dimensional Acquisition ▶ " + ", ".join(shape)

    def __len__(self) -> int:
        """Return the number of events yielded by `iter_events` (cached)."""
        if self._length is None:
            # Count lazily rather than materializing every MDAEvent in a list.
            self._length = sum(1 for _ in self.iter_events())
        return self._length

    @property
    def shape(self) -> Tuple[int, ...]:
        """Sizes of the non-empty axes, in `axis_order`."""
        # NOTE: Doesn't account for jagged arrays, like skipped Z or channel frames
        return tuple(s for s in self.sizes.values() if s)

    @property
    def sizes(self) -> Dict[str, int]:
        """Mapping of each axis in `axis_order` to its number of items."""
        return {k: len(list(self.iter_axis(k))) for k in self.axis_order}

    @property
    def used_axes(self) -> str:
        """The axes of `axis_order` that have at least one item, in order."""
        # Compute the sizes once; `self.sizes` re-iterates every axis on each access.
        sizes = self.sizes
        return "".join(k for k in self.axis_order if sizes[k])

    def iter_axis(self, axis: str) -> Iterator[Union[Position, Channel, float]]:
        """Iterate over the items of a single axis ("t", "p", "z" or "c").

        Raises KeyError for any other axis key.
        """
        yield from {
            TIME: self.time_plan,
            POSITION: self.stage_positions,
            Z: self.z_plan,
            CHANNEL: self.channels,
        }[axis]

    def __iter__(self) -> Iterator[MDAEvent]:  # type: ignore
        yield from self.iter_events()

    class _SkipFrame(Exception):
        """Internal signal raised by `_combine_z` to drop a frame."""

    def iter_events(self) -> Iterator[MDAEvent]:
        """Iterate over all events in the MDA sequence.

        This does the job of iterating over all the frames in the MDA sequence,
        handling the merging of z plans in channels and stage positions.

        Yields
        ------
        Iterator[MDAEvent]
        """
        order = self.used_axes

        # Cartesian product over the non-empty axes; each element is a tuple of
        # (index, value) pairs, one per axis in `order`.
        event_iterator = (enumerate(self.iter_axis(ax)) for ax in order)
        # NOTE: global_index counts every combination, including skipped frames,
        # so the yielded values are not necessarily contiguous.
        for global_index, item in enumerate(product(*event_iterator)):
            if not item:  # the case with no events
                continue

            _ev = dict(zip(order, item))
            index = {k: _ev[k][0] for k in INDICES if k in _ev}

            position: Optional[Position] = _ev[POSITION][1] if POSITION in _ev else None
            channel: Optional[Channel] = _ev[CHANNEL][1] if CHANNEL in _ev else None
            time: Optional[int] = _ev[TIME][1] if TIME in _ev else None

            # skip channels that only acquire on every Nth timepoint
            if channel and TIME in index and index[TIME] % channel.acquire_every:
                continue

            try:
                z_pos = (
                    self._combine_z(_ev[Z][1], index[Z], channel, position)
                    if Z in _ev
                    else position.z
                    if position
                    else None
                )
            except self._SkipFrame:
                continue

            _channel = (
                {"config": channel.config, "group": channel.group} if channel else None
            )
            yield MDAEvent(
                index=index,
                min_start_time=time,
                pos_name=getattr(position, "name", None),
                x_pos=getattr(position, "x", None),
                y_pos=getattr(position, "y", None),
                z_pos=z_pos,
                exposure=getattr(channel, "exposure", None),
                channel=_channel,
                sequence=self,
                global_index=global_index,
            )

    def _combine_z(
        self,
        z_pos: float,
        z_ind: int,
        channel: Optional[Channel],
        position: Optional[Position],
    ) -> float:
        """Return the z position of one frame after applying channel/position offsets.

        Raises `_SkipFrame` when `channel.do_stack` is false and `z_ind` is not the
        middle plane of the z plan.
        """
        if channel:
            # only acquire on the middle plane:
            if not channel.do_stack and z_ind != len(self.z_plan) // 2:
                raise self._SkipFrame()
            if channel.z_offset is not None:
                z_pos += channel.z_offset
        if self.z_plan.is_relative:
            # TODO: either disallow without position z, or add concept of "current"
            z_pos += getattr(position, Z, None) or 0
        return z_pos

    def to_pycromanager(self) -> list[dict]:
        """Convert every event in the sequence with `MDAEvent.to_pycromanager`."""
        return [event.to_pycromanager() for event in self]
