# Copyright (C) 2024 Andrea Raffo <andrea.raffo@ibv.uio.no>
#
# SPDX-License-Identifier: MIT

from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import numpy.typing as npt
import pandas as pd

from stripepy.data_structures import Stripe


class Result(object):
    """
    A class used to represent the results generated by stripepy call.

    Most attributes are stored once per location: "LT" (alias "lower", attributes
    prefixed with ``_lt_``) and "UT" (alias "upper", attributes prefixed with ``_ut_``).
    Such attributes are registered with :py:meth:`set` and retrieved with :py:meth:`get`.
    """

    # Stripe descriptors exported with an integer dtype by get_stripes_descriptor();
    # every other descriptor is exported as float.
    _INT_DESCRIPTORS = frozenset(
        {
            "seed",
            "left_bound",
            "right_bound",
            "top_bound",
            "bottom_bound",
            "outer_lsize",
            "outer_rsize",
        }
    )

    # Columns (in order) of the table returned by get_stripe_geo_descriptors().
    _GEO_DESCRIPTORS = (
        "seed",
        "top_persistence",
        "left_bound",
        "right_bound",
        "top_bound",
        "bottom_bound",
    )

    # Columns (in order) of the table returned by get_stripe_bio_descriptors().
    _BIO_DESCRIPTORS = (
        "inner_mean",
        "outer_mean",
        "rel_change",
        "inner_std",
    )

    def __init__(self, chrom_name: str, chrom_size: int):
        """
        Parameters
        ----------
        chrom_name
            chromosome name
        chrom_size
            chromosome size. Must be greater than 0
        """
        # NOTE(review): assert statements are stripped when Python runs with -O;
        # consider raising ValueError instead (kept as-is to preserve the exception type).
        assert chrom_size > 0

        self._chrom = (chrom_name, chrom_size)
        self._roi = None
        self._min_persistence = None

        self._ut_all_minimum_points = None
        self._ut_all_maximum_points = None
        self._ut_persistence_of_all_minimum_points = None
        self._ut_persistence_of_all_maximum_points = None

        self._lt_all_minimum_points = None
        self._lt_all_maximum_points = None
        self._lt_persistence_of_all_minimum_points = None
        self._lt_persistence_of_all_maximum_points = None

        self._ut_persistent_minimum_points = None
        self._ut_persistent_maximum_points = None
        self._ut_persistence_of_minimum_points = None
        self._ut_persistence_of_maximum_points = None
        self._ut_pseudodistribution = None

        self._lt_persistent_minimum_points = None
        self._lt_persistent_maximum_points = None
        self._lt_persistence_of_minimum_points = None
        self._lt_persistence_of_maximum_points = None
        self._lt_pseudodistribution = None

        self._ut_stripes = None
        self._lt_stripes = None

    @staticmethod
    def _normalize_location(location: str) -> str:
        """
        Map a location (or one of its aliases) to its canonical name.

        Parameters
        ----------
        location
            one of "LT", "UT", "lower" (alias of "LT") or "upper" (alias of "UT")

        Returns
        -------
        the canonical location: either "LT" or "UT"

        Raises
        ------
        ValueError
            if the location is not recognized
        """
        if location == "lower":
            return "LT"
        if location == "upper":
            return "UT"
        if location not in {"LT", "UT"}:
            raise ValueError("Location should be UT or LT")
        return location

    def _resolve_attr_name(self, name: str, location: str) -> Tuple[str, str]:
        """
        Translate a public attribute name and location into the private attribute name.

        Parameters
        ----------
        name
            public name of the attribute (e.g. "stripes")
        location
            location of the attribute. Should be "LT"/"lower" or "UT"/"upper"

        Returns
        -------
        a tuple with the canonical location ("LT" or "UT") and the private attribute name

        Raises
        ------
        ValueError
            if the location is not recognized
        AttributeError
            if no attribute with the given name exists
        """
        location = self._normalize_location(location)
        attr_name = f"_{location.lower()}_{name}"
        if not hasattr(self, attr_name):
            raise AttributeError(
                f"No attribute named \"{name}\". Valid attributes are: {', '.join(self._valid_attributes)}"
            )
        return location, attr_name

    @property
    def _valid_attributes(self) -> List[str]:
        """
        Get the list of valid attribute names (without the location prefix)
        """
        # Every per-location attribute exists for both locations, so listing the "_lt_" ones is enough.
        return [a.removeprefix("_lt_") for a in dir(self) if a.startswith("_lt_")]

    @property
    def empty(self) -> bool:
        """
        True when no stripe has been registered with the :py:class:`Result` instance for either location
        """
        lower_is_empty = self._lt_stripes is None or len(self._lt_stripes) == 0
        upper_is_empty = self._ut_stripes is None or len(self._ut_stripes) == 0
        return lower_is_empty and upper_is_empty

    @property
    def chrom(self) -> Tuple[str, int]:
        """
        The name and length of the chromosome to which the :py:class:`Result` instance belongs
        """
        return self._chrom

    @property
    def roi(self) -> Optional[Dict[str, List[int]]]:
        """
        The region of interest associated with the :py:class:`Result` instance (None when not set)
        """
        return self._roi

    @property
    def min_persistence(self) -> float:
        """
        The minimum persistence used during computation

        Raises
        ------
        RuntimeError
            if the minimum persistence has not been set
        """
        if self._min_persistence is None:
            raise RuntimeError('Attribute "min_persistence" is not set')

        return self._min_persistence

    def get(self, name: str, location: str) -> Union[List[Stripe], npt.NDArray[int], npt.NDArray[float]]:
        """
        Get the value associated with the given attribute name and location.

        Parameters
        ----------
        name
            name of the attribute to be fetched
        location
            location of the attribute to be fetched. Should be "LT" or "UT"
            ("lower" and "upper" are accepted as aliases)

        Returns
        -------
        the value associated with the given name and location, as stored by :py:meth:`set`.
        When no stripes have been registered, fetching "stripes" returns an empty list.

        Raises
        ------
        ValueError
            if the location is not recognized
        AttributeError
            if no attribute with the given name exists
        RuntimeError
            if the attribute (other than "stripes") has not been set
        """
        location, attr_name = self._resolve_attr_name(name, location)

        attr = getattr(self, attr_name)
        if attr is None:
            # Unset stripes are reported as "no stripes" rather than as an error.
            if name == "stripes":
                return []
            raise RuntimeError(f'Attribute "{name}" for "{location}" is not set')

        return attr

    def get_stripes_descriptor(self, descriptor: str, location: str) -> Union[npt.NDArray[float], npt.NDArray[int]]:
        """
        Get the stripe descriptor for the given location.

        Parameters
        ----------
        descriptor
            name of the descriptor to be fetched (must be an attribute of :py:class:`Stripe`)
        location
            location of the attribute to be fetched. Should be "LT" or "UT"
            ("lower" and "upper" are accepted as aliases, consistently with :py:meth:`get`)

        Returns
        -------
        a 1D array with one value per stripe registered for the given location.
        Bound, size and seed descriptors use an integer dtype; all others use float.

        Raises
        ------
        ValueError
            if the location is not recognized
        AttributeError
            if :py:class:`Stripe` has no attribute with the given name
        """
        location = self._normalize_location(location)

        if not hasattr(Stripe, descriptor):
            raise AttributeError(f'Stripe instance does not have an attribute named "{descriptor}"')

        stripes = self.get("stripes", location)
        dtype = int if descriptor in self._INT_DESCRIPTORS else float

        return np.array([getattr(stripe, descriptor) for stripe in stripes], dtype=dtype)

    def get_stripe_geo_descriptors(self, location: str) -> pd.DataFrame:
        """
        Fetch all geometric descriptors at once.

        Parameters
        ----------
        location
            location of the attribute to be fetched. Should be "LT" or "UT"
            ("lower" and "upper" are accepted as aliases)

        Returns
        -------
        the table with the geometric descriptors associated with the :py:class:`Result` instance
        (one row per stripe, one column per descriptor)
        """
        return pd.DataFrame(
            {descriptor: self.get_stripes_descriptor(descriptor, location) for descriptor in self._GEO_DESCRIPTORS}
        )

    def get_stripe_bio_descriptors(self, location: str) -> pd.DataFrame:
        """
        Fetch all biological descriptors at once.

        Parameters
        ----------
        location
            location of the attribute to be fetched. Should be "LT" or "UT"
            ("lower" and "upper" are accepted as aliases)

        Returns
        -------
        the table with the biological descriptors associated with the :py:class:`Result` instance
        (one row per stripe, one column per descriptor)
        """
        return pd.DataFrame(
            {descriptor: self.get_stripes_descriptor(descriptor, location) for descriptor in self._BIO_DESCRIPTORS}
        )

    def set_roi(self, coords: Dict[str, List[int]]):
        """
        Set the region of interest (RoI) for the current :py:class:`Result` instance.

        Parameters
        ----------
        coords
            a dictionary with the coordinates of the region of interest.
            The dictionary should contain two keys: "genomic" and "matrix".
            The value associated with the "genomic" key should be a list of 4 integers
            representing the genomic coordinates of the region of interest.
            The value associated with the "matrix" key should be a list of 4 integers
            representing the matrix coordinates of the region of interest.

        Raises
        ------
        RuntimeError
            if the region of interest has already been set
        """
        if self._roi is not None:
            raise RuntimeError("roi has already been set")

        self._roi = coords

    def set_min_persistence(self, min_persistence: float):
        """
        Set the minimum persistence used during computation.

        Parameters
        ----------
        min_persistence
            the minimum persistence value

        Raises
        ------
        RuntimeError
            if the minimum persistence has already been set
        """
        if self._min_persistence is not None:
            raise RuntimeError("min_persistence has already been set")

        self._min_persistence = min_persistence

    def set(
        self,
        name: str,
        data: Union[Sequence[int], Sequence[float], Sequence[Stripe]],
        location: str,
        force: bool = False,
    ):
        """
        Set the attribute corresponding to the given attribute name and location.

        Parameters
        ----------
        name
            name of the attribute to be set.
            Supported attributes are:

            * all_minimum_points
            * all_maximum_points
            * persistence_of_all_minimum_points
            * persistence_of_all_maximum_points
            * persistent_minimum_points
            * persistent_maximum_points
            * persistence_of_minimum_points
            * persistence_of_maximum_points
            * pseudodistribution
            * stripes

        data
            data to be registered with the :py:class:`Result` instance.
            The data is copied into a new NumPy array via ``np.array(data)``.
        location
            location of the attribute to be registered. Should be "LT" or "UT"
            ("lower" and "upper" are accepted as aliases)
        force
            force overwrite existing values

        Raises
        ------
        ValueError
            if the location is not recognized
        AttributeError
            if no attribute with the given name exists
        RuntimeError
            if the attribute has already been set and force is False
        """
        location, attr_name = self._resolve_attr_name(name, location)

        if not force and getattr(self, attr_name) is not None:
            raise RuntimeError(f'Attribute "{name}" for {location} has already been set')

        setattr(self, attr_name, np.array(data))