import numpy as np
from scipy.linalg import solve

from typing import Tuple

class DiffTVR:
    """Differentiate uniformly spaced data with total variation regularization (TVR).

    The derivative ``u`` (length N-1) is estimated so that its running
    integral ``a_mat @ u`` matches the data relative to ``data[0]``.
    At the same time, ``alpha`` times a smoothed total-variation penalty on
    ``u`` discourages oscillation. Each optimization step re-weights the TV
    term around the current derivative and solves one dense linear system.

    Attributes:
        n (int): Number of data points N.
        dx (float): Spacing between data points.
        d_mat (np.ndarray): (N-1) x N forward-difference matrix, scaled by 1/dx.
        a_mat (np.ndarray): (N-1) x (N-1) lower-triangular integration matrix, scaled by dx.
        a_mat_t (np.ndarray): Transpose of ``a_mat``.
    """

    def __init__(self, n: int, dx: float):
        """Differentiate with TVR.

        Args:
            n (int): Number of points in data.
            dx (float): Spacing of data.
        """
        self.n = n
        self.dx = dx

        self.d_mat = self._make_d_mat()
        self.a_mat = self._make_a_mat()
        self.a_mat_t = self._make_a_mat_t()

    def _make_d_mat(self) -> np.ndarray:
        """Make the differentiation matrix with forward (Euler) differences.

        Row i holds -1 at column i and +1 at column i+1, all divided by dx.

        Returns:
            np.ndarray: N-1 x N
        """
        rows, cols = self.n - 1, self.n
        # k=1 puts the ones on the superdiagonal, i.e. column i+1 of row i.
        return (np.eye(rows, cols, k=1) - np.eye(rows, cols)) / self.dx

    def _make_a_mat(self) -> np.ndarray:
        """Make the integration matrix.

        Entry (i, j) is dx for i >= j and 0 otherwise, so ``a_mat @ u`` is
        the running (cumulative) sum of ``u`` times dx.

        Returns:
            np.ndarray: N-1 x N-1
        """
        m = self.n - 1
        return np.tril(np.ones((m, m))) * self.dx

    def _make_a_mat_t(self) -> np.ndarray:
        """Transpose of the integration matrix.

        Returns:
            np.ndarray: N-1 x N-1
        """
        return self._make_a_mat().T

    def make_en_mat(self, deriv_curr: np.ndarray, eps: float = 1e-6) -> np.ndarray:
        """Diffusion matrix: TV weights evaluated at the current derivative.

        Args:
            deriv_curr (np.ndarray): Current derivative of length N-1.
            eps (float, optional): Smoothing term added under the square root
                so the weights stay finite where the derivative is flat.
                Defaults to 1e-6.

        Returns:
            np.ndarray: N-2 x N-2 diagonal matrix with entries
            ``1 / sqrt((D u)_i**2 + eps)``, where ``D = d_mat[:-1, :-1]``.
        """
        diffs = self.d_mat[:-1, :-1] @ deriv_curr
        return np.diag(1.0 / np.sqrt(np.square(diffs) + eps))

    def make_ln_mat(self, en_mat: np.ndarray) -> np.ndarray:
        """Diffusivity term ``dx * D^T @ en_mat @ D`` with ``D = d_mat[:-1, :-1]``.

        Args:
            en_mat (np.ndarray): Result from make_en_mat

        Returns:
            np.ndarray: N-1 x N-1
        """
        d_sub = self.d_mat[:-1, :-1]
        return self.dx * d_sub.T @ en_mat @ d_sub

    def make_gn_vec(self, deriv_curr: np.ndarray, data: np.ndarray, alpha: float, ln_mat: np.ndarray) -> np.ndarray:
        """Negative right hand side of linear problem (gradient at ``deriv_curr``).

        Computes ``A^T (A u - b) + alpha * L u`` with ``b = (data - data[0])[1:]``.

        Args:
            deriv_curr (np.ndarray): Current derivative of size N-1
            data (np.ndarray): Data of size N
            alpha (float): Regularization parameter
            ln_mat (np.ndarray): Diffusivity term from make_ln_mat

        Returns:
            np.ndarray: Vector of length N-1
        """
        data = np.asarray(data)
        # Integrating the derivative from the first sample should reproduce
        # the data shifted so that it starts at zero.
        target = (data - data[0])[1:]
        return (
            self.a_mat_t @ self.a_mat @ deriv_curr
            - self.a_mat_t @ target
            + alpha * ln_mat @ deriv_curr
        )

    def make_hn_mat(self, alpha: float, ln_mat: np.ndarray) -> np.ndarray:
        """Matrix in linear problem: ``A^T A + alpha * L``.

        Args:
            alpha (float): Regularization parameter
            ln_mat (np.ndarray): Diffusivity term from make_ln_mat

        Returns:
            np.ndarray: N-1 x N-1
        """
        return self.a_mat_t @ self.a_mat + alpha * ln_mat

    def get_deriv_tvr_update(self, data: np.ndarray, deriv_curr: np.ndarray, alpha: float) -> np.ndarray:
        """Get the TVR update by solving ``H @ update = -g``.

        Args:
            data (np.ndarray): Data of size N
            deriv_curr (np.ndarray): Current deriv of size N-1
            alpha (float): Regularization parameter

        Returns:
            np.ndarray: Update vector of size N-1
        """
        en_mat = self.make_en_mat(deriv_curr=deriv_curr)
        ln_mat = self.make_ln_mat(en_mat=en_mat)
        hn_mat = self.make_hn_mat(alpha=alpha, ln_mat=ln_mat)
        gn_vec = self.make_gn_vec(
            deriv_curr=deriv_curr,
            data=data,
            alpha=alpha,
            ln_mat=ln_mat,
        )
        return solve(hn_mat, -gn_vec)

    def get_deriv_tvr(self,
        data: np.ndarray,
        deriv_guess: np.ndarray,
        alpha: float,
        no_opt_steps: int,
        return_progress: bool = False,
        return_interval: int = 1
        ) -> Tuple[np.ndarray, np.ndarray]:
        """Get derivative via TVR over optimization steps.

        The caller's ``deriv_guess`` is not modified.

        Args:
            data (np.ndarray): Data of size N
            deriv_guess (np.ndarray): Guess for derivative of size N-1
            alpha (float): Regularization parameter
            no_opt_steps (int): No. opt steps to run
            return_progress (bool, optional): True to return derivative progress during optimization. Defaults to False.
            return_interval (int, optional): Interval at which to store derivative if returning. Defaults to 1.

        Returns:
            Tuple[np.ndarray,np.ndarray]: First is the final derivative of size N-1.
            Second is the stored derivatives if return_progress=True, of size
            no_opt_steps+1 x N-1 (float); otherwise an empty array. Row k
            holds the derivative after the 0-based step ``k * return_interval``
            has been applied. Rows that are never written remain zero.

        Raises:
            ValueError: If return_progress is True and return_interval < 1.
        """
        if return_progress and return_interval < 1:
            raise ValueError("return_interval must be a positive integer")

        # Work on a float copy. Updating the caller's array in place would
        # mutate their guess, and an integer guess would fail to accept the
        # float update.
        deriv_curr = np.array(deriv_guess, dtype=float)

        if return_progress:
            # Must be float: an integer buffer truncates the stored derivatives.
            deriv_st = np.zeros((no_opt_steps + 1, len(deriv_curr)))
        else:
            deriv_st = np.array([])

        for opt_step in range(no_opt_steps):
            deriv_curr += self.get_deriv_tvr_update(
                data=data,
                deriv_curr=deriv_curr,
                alpha=alpha,
            )

            if return_progress and opt_step % return_interval == 0:
                deriv_st[opt_step // return_interval] = deriv_curr

        return (np.array(deriv_curr), deriv_st)