"""
Specifies the interface for Experiments, which take in run parameters set by the user and Resources from a Stage, bring about some interaction between those Resources, generate data for plotting and for saving to / loading from disk, and perform analysis that may lead to the characterization (i.e. gain in knowledge) of some Resource parameters.
"""

from collections import Counter
import contextlib
import dataclasses as dc
from datetime import datetime
from typing import Any

from labctrl.datasaver import DataSaver
from labctrl.dataset import Dataset
from labctrl.logger import logger
from labctrl.parameter import parametrize
from labctrl.plotter import Plotter
from labctrl.resource import Resource
from labctrl.settings import Settings
from labctrl.sweep import Sweep


class DatasetSpecificationError(Exception):
    """
    Raised while an Experiment prepares its datasets, when a Dataset declared in the class definition lists something in its 'axes' that is not one of the Experiment's declared Sweeps.
    """


class SweepSpecificationError(Exception):
    """
    Raised while an Experiment prepares its sweeps, when a name declared as a Sweep in the class definition has not been set as an instance attribute of the experiment.
    """


class ResourceSpecificationError(Exception):
    """
    Raised when the Resources held by an Experiment instance do not match, type for type and count for count, the Resources declared as annotations in its class definition.
    """


class ExperimentMetaclass(type):
    """
    Metaclass of Experiment. Whenever a class using it is created, it inspects the class body and attaches the resource, dataset and sweep specifications of that class.
    """

    def __init__(cls, name, bases, kwds) -> None:
        """
        Initialize the new class, then attach three class attributes to it:

        resourcespec: list of the annotated types that are Resource subclasses, in annotation order. Repeated types are kept so that they can be counted when resources are checked.
        dataspec: dict of Dataset class variables, as returned by parametrize(cls, filter=Dataset)
        sweepspec: dict of Sweep class variables, as returned by parametrize(cls, filter=Sweep)

        Annotations that are not classes are treated as "not a Resource" instead of making issubclass() raise TypeError. Examples are list[int], int | None, and plain strings.
        NOTE: string annotations, e.g. under `from __future__ import annotations`, are not resolved, so a Resource annotated that way is not part of resourcespec.
        """
        super().__init__(name, bases, kwds)

        def is_resource_type(annotation: Any) -> bool:
            # issubclass() only accepts a class as its first argument and raises
            # TypeError for anything else (strings, generic aliases, unions...)
            try:
                return issubclass(annotation, Resource)
            except TypeError:
                return False

        annotations = cls.__annotations__
        cls.resourcespec = [v for v in annotations.values() if is_resource_type(v)]
        cls.dataspec: dict[str, Dataset] = parametrize(cls, filter=Dataset)
        cls.sweepspec: dict[str, Sweep] = parametrize(cls, filter=Sweep)

    def __repr__(cls) -> str:
        """Represent Experiment classes by their bare class name."""
        return f"<class '{cls.__name__}'>"


class Experiment(metaclass=ExperimentMetaclass):
    """
    Base class. Responsible for resource check and preparing datasets for saving and plotting.

    To write child classes, inherit Experiment, use class annotations to declare Resource spec, use class variables to declare Dataset(s), Sweep(s), Parameter(s). The names of these variables may be passed as arguments to __init__().

    class ExperimentSubclass(Experiment):
        # first indicate Resources used in the Experiment as annotations, e.g.

        instrument: Instrument
        sample: Resource

        # initialize all parameters that can be swept
        # including those which might not be swept during runtime.
        # set attributes "units", "dtype", and "save" as these do not change at runtime
        # note that sweep save settings cannot be controlled at runtime

        frequency = Sweep(units="Hz")
        power = Sweep(units="dBm")

        # initialize all Datasets this Experiment will generate (both raw and derived)
        # set attributes "axes", "units", "dtype", "chunks", "save", "plot" here
        # axes must be a list/tuple, even if it contains one entry
        # for plotted datasets, set "errfn" and "fitfn" too if needed
        # for derived datasets, set "datafn" too
        # whether or not the datasets are plotted / saved can be changed in run()
        # the number of repetitions N will be added as the outermost dimension of each dataset at runtime if N > 1
        # any uninitialized sweeps will be removed from the axes at runtime

        I = Dataset(axes=[power, frequency], units="AU")

        # TODO settle datafn, errfn, fitfn etc
    """

    def __init__(self, N: int, project: str, nametag: str = "") -> None:
        """
        N: (int) number of repetitions of this experiment
        project: (str) name of the project this experiment belongs to. This name is used as the name of the subfolder the datafile generated by this experiment is saved to.
        nametag: (str) optional suffix for the datafile name.

        the datafile is saved to the automatically generated path:
        datapath / <project> / <YYYY-MM-DD> / <HH-MM-SS>_<experimentname>[_<nametag>].h5
        datapath is to be specified in labctrl Settings
        experimentname is the name of the Experiment subclass
        the "_<nametag>" suffix is only added when nametag is non-empty
        """
        self.name = self.__class__.__name__
        self.N = N  # number of repetitions

        self.project = project
        self.nametag = nametag
        self._filepath: str | None = None  # generated lazily by the filepath property

        # these are set by run()
        self.datasaver: DataSaver | None = None
        self.plotter: Plotter | None = None

    def __repr__(self) -> str:
        """ """
        return f"Experiment '{self.name}'"

    @property
    def filepath(self) -> str:
        """
        Path of this experiment's datafile.

        Built from the current date and time on first access, then cached, so every later access (including repeated run() calls) returns the same path.
        """
        if self._filepath is None:
            date, time = datetime.now().strftime("%Y-%m-%d %H-%M-%S").split()
            foldername = Settings().datapath + f"/{self.project}/{date}/"
            filesuffix = f"_{self.nametag}" if self.nametag else ""
            filename = f"{time}_{self.name}{filesuffix}.h5"
            self._filepath = foldername + filename
        return self._filepath

    def snapshot(self) -> dict[str, Any]:
        """
        Return the instance attributes of this experiment as a dict.

        Excluded are:
        - attributes whose name starts with "_";
        - the "datasaver" and "plotter" attributes;
        - attributes that are instances of Resource, Sweep or Dataset.
        """
        xcls = (Resource, Sweep, Dataset)  # excluded classes
        xkeys = ("datasaver", "plotter")  # excluded keys
        return {
            k: v
            for k, v in self.__dict__.items()
            if not isinstance(v, xcls) and not k.startswith("_") and k not in xkeys
        }

    @property
    def metadata(self) -> dict[str | None, dict[str, Any]]:
        """
        Metadata saved by run(): one entry per Resource instance attribute, keyed by the resource's name and holding its snapshot(), plus this experiment's own snapshot() under the key None.
        """
        resources = [v for v in self.__dict__.values() if isinstance(v, Resource)]
        metadata = {resource.name: resource.snapshot() for resource in resources}
        return {**metadata, None: self.snapshot()}

    def run(
        self,
        save: tuple[Dataset, ...] | None = None,
        plot: tuple[Dataset, ...] | None = None,
    ) -> None:
        """
        Check resources and prepare sweeps and datasets. Then set which datasets to save/plot, enter the datasaver and plotter contexts if needed, and call sequence().

        save / plot:
            None (default): defer to the save/plot flags specified for each dataset in the class definition.
            empty tuple or False: do not save/plot any dataset.
            tuple of one or more Dataset objects: only save/plot those datasets (matched by name).

        The datasaver is only created if at least one dataset is to be saved, and the plotter only if at least one dataset is to be plotted. Otherwise self.datasaver / self.plotter are None while sequence() runs.
        """
        # drop references left over from a previous run(): those objects have
        # already exited their context and must not be used by this sequence()
        self.datasaver = None
        self.plotter = None

        self._check_resources()
        self._prepare_sweeps()
        datasets = self._prepare_datasets(save=save, plot=plot)

        # plot/save as long as one dataset needs to be plotted/saved
        do_save = any(dataset.save for dataset in datasets)
        do_plot = any(dataset.plot for dataset in datasets)
        with contextlib.ExitStack() as stack:
            if do_save:
                datasaver = DataSaver(self.filepath, *datasets)
                self.datasaver = stack.enter_context(datasaver)
                self.datasaver.save_metadata(self.metadata)
                logger.debug("Set up datasaver and saved metadata!")
            if do_plot:
                plotter = Plotter(*datasets)
                self.plotter = stack.enter_context(plotter)
                logger.debug("Set up plotter!")
            logger.debug(f"Running {self.name} sequence...")
            self.sequence()

    def sequence(self) -> None:
        """
        the experimental sequence called by run(). in it, you can generate your expt data, know what pos to insert it in, and call:
        self.datasaver.save(self.<dataset_name>, data, pos)  # batch saving
        self.datasaver.save(self.<sweep_name>, data)  # whole dataset saving
        self.plotter.plot(self.<dataset_name>, data)
        """
        raise NotImplementedError("Subclass(es) must implement sequence()!")

    def _check_resources(self) -> None:
        """
        Compare the Resources held by this experiment with the declared resource spec.

        Raises ResourceSpecificationError unless the classes of the Resource instance attributes, counted with multiplicity, equal the Resource types declared as class annotations (cls.resourcespec).
        """
        # NOTE(review): classes are compared exactly, so an attribute whose type is a
        # subclass of the annotated type counts as a mismatch; confirm this is intended
        spec = [v.__class__ for v in self.__dict__.values() if isinstance(v, Resource)]
        spec = dict(Counter(spec))
        expectedspec = dict(Counter(self.__class__.resourcespec))
        if spec != expectedspec:
            message = f"Expect resource specification {expectedspec}, got {spec}."
            logger.error(message)
            raise ResourceSpecificationError(message)

    def _prepare_sweeps(self) -> None:
        """
        Prepare sweeps before running the experiment.

        Sweep preparation protocol:
        For each sweep name declared in the experiment's class definition, check that the experiment has an instance attribute of the same name. If not, raise SweepSpecificationError.
        If the attribute is a Sweep, replace it with a copy carrying the dtype, units and name declared in the class definition. Attributes that are not Sweeps are left as they are.
        At runtime, these declared sweep variables may or may not be swept, but they must nevertheless be set as instance attributes of the experiment.
        """
        for name, sweep in self.__class__.sweepspec.items():
            sweep.name = name  # to identify sweep name from sweep object later on
            try:
                value = self.__dict__[name]
            except KeyError:
                message = (
                    f"Name '{name}' is declared as a Sweep variable of {self}"
                    f" but is not set as an attribute."
                )
                logger.error(message)
                raise SweepSpecificationError(message) from None
            else:
                if isinstance(value, Sweep):
                    changes = {"dtype": sweep.dtype, "units": sweep.units, "name": name}
                    sweep = dc.replace(value, **changes)
                    self.__dict__[name] = sweep
                    logger.debug(f"Found {sweep}.")

    def _prepare_datasets(
        self, save: tuple[Dataset, ...] | None, plot: tuple[Dataset, ...] | None
    ) -> list[Dataset]:
        """
        Prepare datasets before running the experiment.

        Dataset preparation protocol:
        For each dataset declared in the experiment's class definition, replace the Sweep objects in its 'axes' attribute with the prepared sweeps (the instance attributes of the same name).
        If the number of repetitions of the experiment is > 1, add the dimension 'N' as the outermost sweep dimension.
        Each prepared dataset is set as an instance attribute under its declared name.
        Finally, set the runtime save/plot flags (see run() for the meaning of save/plot).

        Returns the list of prepared datasets.
        Raises DatasetSpecificationError if an axis is not one of the declared sweeps.
        """
        sweeps = self.__class__.sweepspec.values()
        datasets: list[Dataset] = []

        for name, dataset in self.__class__.dataspec.items():
            dataset.name = name  # to identify dataset name from dataset object later on
            axes = []

            for sweep in dataset.axes:
                if sweep in sweeps:
                    # getattr as instance (not class) attribute is the prepared sweep
                    # NOTE(review): the instance attribute is appended whether or not it
                    # is a Sweep, although the class docstring says uninitialized sweeps
                    # are removed from the axes; confirm which behavior is intended
                    axes.append(getattr(self, sweep.name))
                else:
                    message = (
                        f"Invalid {sweep = } declared in Dataset {name} 'axes'. "
                        f"Axis must be of type 'Sweep'."
                    )
                    logger.error(message)
                    raise DatasetSpecificationError(message) from None

            # add number of repetitions dimension to the head of the axes list if N > 1
            if self.N > 1:
                n = Sweep(start=1, stop=self.N, num=self.N, save=False, name="N")
                axes.insert(0, n)

            dataset = dc.replace(dataset, axes=axes)
            setattr(self, name, dataset)  # used to save/plot datasets in sequence()
            datasets.append(dataset)
            logger.debug(f"Found {dataset}.")

        # set runtime dataset plot/save args
        # these must be concrete sets, not generators: `in` consumes a generator, so
        # membership tests for later datasets would silently return wrong results
        savenames = {dataset.name for dataset in save} if save else set()
        plotnames = {dataset.name for dataset in plot} if plot else set()
        for dataset in datasets:
            dataset.save = dataset.save if save is None else dataset.name in savenames
            dataset.plot = dataset.plot if plot is None else dataset.name in plotnames

        return datasets
