# Copyright 2023 Q-CTRL. All rights reserved.
#
# Licensed under the Q-CTRL Terms of service (the "License"). Unauthorized
# copying or use of this file, via any medium, is strictly prohibited.
# Proprietary and confidential. You may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#    https://q-ctrl.com/terms
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS. See the
# License for the specific language.
from abc import ABC
from dataclasses import (
    dataclass,
    field,
)
from typing import Optional

import numpy as np
from qctrlcommons.preconditions import check_argument

from boulderopal._validation import (
    ArrayDType,
    ScalarDType,
    nullable,
)


@dataclass
class FilterFunction:
    """
    A class to store information about the controls applied in the noise
    reconstruction process, in the form of filter functions.

    The filter function specifies how sensitive a set of controls is to a specific
    noise, at a certain frequency. Filter functions can be calculated using
    the ``execute_graph`` function together with the ``graph.filter_function``
    operator. The output of that operation can then be passed directly to this
    class.

    Parameters
    ----------
    frequencies : np.ndarray
        A 1D array of the frequencies where the filter function is sampled.
        The frequencies must be provided in strictly ascending order.
    inverse_powers : np.ndarray
        A 1D array of the values of the filter function at the frequencies
        where it is sampled. Must have the same length as `frequencies`,
        and all its values must be greater than or equal to zero.
    uncertainties : np.ndarray or None, optional
        The uncertainties associated with each sampled point of the filter
        function. If provided, it must be a 1D array with the same length as
        `frequencies`, and all its values must be greater than or equal to zero.
        These values are not used for noise reconstruction.
    """

    frequencies: np.ndarray
    inverse_powers: np.ndarray
    uncertainties: Optional[np.ndarray] = None

    def __post_init__(self):
        # Validate each array as real-valued and one-dimensional; the filter
        # function values and their (optional) uncertainties must be non-negative.
        self.frequencies = ArrayDType.REAL(self.frequencies, "frequencies", ndim=1)
        self.inverse_powers = ArrayDType.REAL(
            self.inverse_powers, "inverse_powers", ndim=1, min_=0, min_inclusive=True
        )
        self.uncertainties = nullable(
            ArrayDType.REAL,
            self.uncertainties,
            "uncertainties",
            ndim=1,
            min_=0,
            min_inclusive=True,
        )

        check_argument(
            len(self.frequencies) == len(self.inverse_powers),
            "The frequencies must have the same length as the inverse powers.",
            {"frequencies": self.frequencies, "inverse_powers": self.inverse_powers},
            extras={
                "len(frequencies)": len(self.frequencies),
                "len(inverse_powers)": len(self.inverse_powers),
            },
        )
        # Uncertainties are given per sampled point, so they must line up with
        # the frequencies when provided.
        if self.uncertainties is not None:
            check_argument(
                len(self.uncertainties) == len(self.frequencies),
                "The uncertainties must have the same length as the frequencies.",
                {"frequencies": self.frequencies, "uncertainties": self.uncertainties},
                extras={
                    "len(frequencies)": len(self.frequencies),
                    "len(uncertainties)": len(self.uncertainties),
                },
            )
        # Strictly increasing: every consecutive difference must be positive.
        # Use the vectorized np.all rather than iterating the array in Python.
        check_argument(
            bool(np.all(np.diff(self.frequencies) > 0)),
            "The frequencies must be provided in ascending order.",
            {"frequencies": self.frequencies},
        )

    @property
    def sample_count(self) -> int:
        """
        The number of samples in the filter function.
        """
        # Computed on access so it stays consistent if `frequencies` is
        # reassigned after construction (the dataclass is not frozen).
        return len(self.frequencies)


@dataclass
class NoiseReconstructionMethod(ABC):
    """
    Base class for the noise reconstruction methods.

    Each concrete method declares a `method_name` field with a default value and
    ``init=False``, so the name is fixed by the class and cannot be passed to, or
    later reassigned on, an instance. Any attempt to assign `method_name` raises a
    ``RuntimeError``. (As a consequence, instantiating this base class directly
    also raises, because its generated ``__init__`` assigns `method_name`.)
    """

    method_name: str

    def __setattr__(self, name, value):
        # Subclasses declare `method_name` with a default and ``init=False``, so the
        # dataclass-generated ``__init__`` never assigns it and the class-level
        # default is read instead; every explicit assignment is rejected here.
        if name == "method_name":
            raise RuntimeError(
                "Mutating the `method_name` of the noise reconstruction method "
                "is not allowed."
            )
        super().__setattr__(name, value)


@dataclass
class ConvexOptimization(NoiseReconstructionMethod):
    r"""
    Configuration for noise reconstruction with the convex optimization (CVX) method.

    Parameters
    ----------
    power_density_lower_bound : float
        The lower bound for the reconstructed power spectral densities.
        It must be greater than or equal to 0.
    power_density_upper_bound : float
        The upper bound for the reconstructed power spectral densities.
        It must be greater than the `power_density_lower_bound`.
    regularization_hyperparameter : float
        The regularization hyperparameter :math:`\lambda`.
        It must be greater than or equal to 0.

    Notes
    -----
    The CVX method finds the estimation of the power spectral density (PSD) matrix
    :math:`{\mathbf S}` by solving the optimization problem:

    .. math::
        {\mathbf S}_{\mathrm{est}} = \mathrm{argmin}_{\textbf S} (\| F'{\mathbf S} -
        {\mathbf I} \|_2^2 + \lambda \| L_1 {\mathbf S} \|_2^2) ,

    where :math:`F^\prime` is the matrix of weighted filter functions,
    :math:`\| \bullet \|_2` denotes the Euclidean norm, and :math:`L_1` is the
    first-order derivative operator defined as

    .. math::
        \begin{align}
            L_1 =
              \begin{bmatrix}
                -1 &      1 &         &    \\
                   & \ddots &  \ddots &     \\
                   &        &      -1 & 1    \\
              \end{bmatrix}_{(K - 1) \times K} ,
        \end{align}

    with :math:`K` the number of elements of :math:`{\mathbf S}`.

    :math:`\lambda` is a non-negative regularization hyperparameter which determines
    the smoothness of :math:`{\mathbf S}_{\mathrm{est}}`; a value of 0 removes the
    smoothness term from the objective. If you provide uncertainties in
    measurements, this method calculates the uncertainties in estimation using a
    Monte Carlo method.
    """

    power_density_lower_bound: float
    power_density_upper_bound: float
    regularization_hyperparameter: float
    method_name: str = field(default="convex optimization", init=False)

    def __post_init__(self):
        self.power_density_lower_bound = ScalarDType.REAL(
            self.power_density_lower_bound,
            "power_density_lower_bound",
            min_=0,
            min_inclusive=True,
        )
        self.power_density_upper_bound = ScalarDType.REAL(
            self.power_density_upper_bound, "power_density_upper_bound"
        )
        # The bounds must describe a non-empty interval. Report the arguments
        # under the names the user passed them with.
        check_argument(
            self.power_density_lower_bound < self.power_density_upper_bound,
            "The power density lower bound must be less than the upper bound.",
            {
                "power_density_lower_bound": self.power_density_lower_bound,
                "power_density_upper_bound": self.power_density_upper_bound,
            },
        )

        # Zero is accepted: it disables the smoothness regularization term.
        self.regularization_hyperparameter = ScalarDType.REAL(
            self.regularization_hyperparameter,
            "regularization_hyperparameter",
            min_=0,
            min_inclusive=True,
        )


@dataclass
class SVDEntropyTruncation(NoiseReconstructionMethod):
    r"""
    Configuration for noise reconstruction with the singular value decomposition
    (SVD) method using entropy truncation.

    Parameters
    ----------
    rounding_threshold : float, optional
        The rounding threshold of the entropy. It must lie between 0 and 1, with
        both ends included. Defaults to 0.5.

    Notes
    -----
    The SVD method starts by approximating the matrix of weighted filter
    functions :math:`F^\prime` with a low rank factorization:

    .. math::
        F^\prime \approx U \Sigma V ,

    in which the matrices :math:`U` and :math:`V` satisfy
    :math:`U^\dagger U = VV^\dagger = \mathbb{I}_{n_{\mathrm{sv}} \times n_{\mathrm{sv}}}`
    and :math:`\Sigma` is a diagonal matrix holding :math:`n_{\mathrm{sv}}` truncated
    singular values. With entropy truncation, :math:`n_{\mathrm{sv}}` is set by the
    entropy :math:`E` of the singular values.

    To get :math:`n_{\mathrm{sv}}`, the method computes :math:`2^E` and rounds it to
    an integer: the rounding threshold you chose is added to :math:`2^E` and the
    floor of the result is taken. A small threshold therefore rounds down, while a
    large one rounds up. The resulting :math:`n_{\mathrm{sv}}` is used as the
    truncation value.

    The noise power spectral density (PSD) :math:`\mathbf S` is then estimated as:

    .. math::
        {\mathbf S}_{\mathrm{est}} = V^\dagger\Sigma^{-1}U^\dagger{\mathbf I} .

    When you provide measurement uncertainties, this method computes the
    uncertainties of the estimate with error propagation.
    """

    rounding_threshold: float = 0.5
    method_name: str = field(default="SVD entropy truncation", init=False)

    def __post_init__(self):
        # Only real values in the closed interval [0, 1] are accepted.
        threshold = ScalarDType.REAL(
            self.rounding_threshold,
            "rounding_threshold",
            min_=0,
            min_inclusive=True,
            max_=1,
            max_inclusive=True,
        )
        self.rounding_threshold = threshold


@dataclass
class SVDFixedLengthTruncation(NoiseReconstructionMethod):
    r"""
    Configuration for noise reconstruction with the singular value decomposition
    (SVD) method using fixed-length truncation.

    Parameters
    ----------
    singular_value_count : int or None, optional
        How many singular values to keep. It must be greater than or equal to 1.
        Defaults to None, which means that no truncation is performed.

    Notes
    -----
    The SVD method starts by approximating the matrix of weighted filter
    functions :math:`F^\prime` with a low rank factorization:

    .. math::
        F^\prime \approx U \Sigma V ,

    in which the matrices :math:`U` and :math:`V` satisfy
    :math:`U^\dagger U = VV^\dagger = \mathbb{I}_{n_{\mathrm{sv}} \times n_{\mathrm{sv}}}`
    and :math:`\Sigma` is a diagonal matrix holding :math:`n_{\mathrm{sv}}` truncated
    singular values. With fixed-length truncation, :math:`n_{\mathrm{sv}}` is given
    by the `singular_value_count` you provide.

    The noise power spectral density (PSD) :math:`\mathbf S` is then estimated as:

    .. math::
        {\mathbf S}_{\mathrm{est}} = V^\dagger\Sigma^{-1}U^\dagger{\mathbf I} .

    When you provide measurement uncertainties, this method computes the
    uncertainties of the estimate with error propagation.
    """

    singular_value_count: Optional[int] = None
    method_name: str = field(default="SVD fixed-length truncation", init=False)

    def __post_init__(self):
        # None keeps every singular value; otherwise an integer of at least 1
        # is required.
        count = nullable(
            ScalarDType.INT,
            self.singular_value_count,
            "singular_value_count",
            min_inclusive=True,
            min_=1,
        )
        self.singular_value_count = count


# The concrete noise reconstruction configuration classes defined in this module.
# NOTE(review): presumably used elsewhere to validate a user-supplied method —
# confirm against the callers.
ALLOWED_NOISE_RECONSTRUCTION_METHODS = tuple(
    [ConvexOptimization, SVDEntropyTruncation, SVDFixedLengthTruncation]
)
