import torch
import cv2
from torchstain.normalizers.he_normalizer import HENormalizer
from torchstain.utils import cov, percentile

"""
Source code ported from: https://github.com/schaugf/HEnorm_python
Original implementation: https://github.com/mitkovetta/staining-normalization
"""
class TorchMacenkoNormalizer(HENormalizer):
    """Macenko H&E stain normalization implemented with PyTorch tensors.

    An image's 3x2 stain matrix (one optical-density column per stain,
    hematoxylin first) and its 99th-percentile stain concentrations are
    estimated; the concentrations are then rescaled to the reference maxima
    and re-rendered with the reference stain matrix.
    """

    def __init__(self):
        super().__init__()

        # Default reference stain matrix: rows are the R, G, B optical
        # densities, columns are the hematoxylin and eosin stain vectors.
        # Overwritten by fit().
        self.HERef = torch.tensor([[0.5626, 0.2159],
                                   [0.7201, 0.8012],
                                   [0.4062, 0.5581]])
        # Default reference 99th-percentile concentrations (H, E).
        # Overwritten by fit().
        self.maxCRef = torch.tensor([1.9705, 1.0308])

    def __convert_rgb2od(self, I, Io, beta):
        """Convert a [C, H, W] RGB image to optical density.

        Returns:
            OD: tensor of shape [H*W, C], the optical density of every pixel.
            ODhat: the rows of OD with no channel below ``beta``; pixels with
                any channel below ``beta`` are treated as transparent.
        """
        I = I.permute(1, 2, 0)

        # calculate optical density; +1 avoids log(0) for zero-valued pixels
        OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1) / Io)

        # remove transparent pixels
        ODhat = OD[~torch.any(OD < beta, dim=1)]

        return OD, ODhat

    def __find_HE(self, ODhat, eigvecs, alpha):
        """Estimate the 3x2 stain matrix from non-transparent optical densities.

        Pixels are projected onto the plane spanned by ``eigvecs``; the
        ``alpha`` and ``100 - alpha`` percentiles of the projected angles give
        the two stain directions, ordered hematoxylin first by a heuristic.
        """
        # project on the plane spanned by the eigenvectors corresponding to the two
        # largest eigenvalues
        That = torch.matmul(ODhat, eigvecs)
        phi = torch.atan2(That[:, 1], That[:, 0])

        # create the angles on phi's device/dtype so the matmuls below also
        # work for inputs that are not float32 CPU tensors
        minPhi = torch.tensor(percentile(phi, alpha), dtype=phi.dtype, device=phi.device)
        maxPhi = torch.tensor(percentile(phi, 100 - alpha), dtype=phi.dtype, device=phi.device)

        # stack() of two 0-d tensors is already 1-D, so no transpose is needed
        vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1)
        vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1)

        # a heuristic to make the vector corresponding to hematoxylin first and the
        # one corresponding to eosin second
        if vMin[0] > vMax[0]:
            HE = torch.cat((vMin, vMax), dim=1)
        else:
            HE = torch.cat((vMax, vMin), dim=1)

        return HE

    def __find_concentration(self, OD, HE):
        """Solve ``HE @ C = OD.T`` in the least-squares sense.

        Returns the 2 x (H*W) concentration matrix (row 0: H, row 1: E).
        """
        # rows correspond to channels (RGB), columns to OD values
        Y = OD.T

        # torch.lstsq was removed from PyTorch; torch.linalg.lstsq takes (A, B)
        # and returns exactly the n x k solution, so no [:2] slice is needed.
        return torch.linalg.lstsq(HE, Y).solution

    def __compute_matrices(self, I, Io, alpha, beta):
        """Return the stain matrix, the concentrations, and their 99th percentiles."""
        OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

        # compute eigenvectors; eigh returns eigenvalues in ascending order, so
        # columns 1 and 2 belong to the two largest eigenvalues
        _, eigvecs = torch.linalg.eigh(cov(ODhat.T))
        eigvecs = eigvecs[:, [1, 2]]

        HE = self.__find_HE(ODhat, eigvecs, alpha)

        C = self.__find_concentration(OD, HE)
        maxC = torch.tensor([percentile(C[0, :], 99), percentile(C[1, :], 99)],
                            dtype=C.dtype, device=C.device)

        return HE, C, maxC

    def fit(self, I, Io=240, alpha=1, beta=0.15):
        """Set the reference stain matrix and maximum concentrations from image ``I``.

        Takes the same ``I``, ``Io``, ``alpha`` and ``beta`` arguments as
        ``normalize``.
        """
        HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)

        self.HERef = HE
        self.maxCRef = maxC

    def __reconstruct(self, stains, conc, Io, h, w, c):
        """Render ``Io * exp(-stains @ conc)`` as an [h, w, c] int32 image.

        ``stains`` is 3 x k and ``conc`` is k x (h*w). Values above 255 are
        clipped to 255.
        """
        img = Io * torch.exp(-torch.matmul(stains, conc))
        img = torch.clamp(img, max=255)
        return img.T.reshape(h, w, c).int()

    def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
        ''' Normalize staining appearance of H&E stained images

        Example use:
            see test.py

        Input:
            I: RGB input image: tensor of shape [C, H, W] and type uint8
            Io: (optional) transmitted light intensity
            alpha: percentile (in percent) used for the robust extreme stain angles
            beta: optical-density threshold below which a pixel is treated as transparent
            stains: if true, return also H & E components

        Output:
            Inorm: normalized image, int32 tensor of shape [H, W, C], values clipped to at most 255
            H: hematoxylin image (same layout as Inorm), or None if stains is False
            E: eosin image (same layout as Inorm), or None if stains is False

        Reference:
            A method for normalizing histology slides for quantitative analysis. M.
            Macenko et al., ISBI 2009
        '''
        c, h, w = I.shape

        _, C, maxC = self.__compute_matrices(I, Io, alpha, beta)

        # the reference tensors may live on another device/dtype (the defaults
        # are CPU float32); align them with the computed concentrations
        HERef = self.HERef.to(device=C.device, dtype=C.dtype)
        maxCRef = self.maxCRef.to(device=C.device, dtype=C.dtype)

        # normalize stain concentrations
        C *= (maxCRef / maxC).unsqueeze(-1)

        # recreate the image using reference mixing matrix
        Inorm = self.__reconstruct(HERef, C, Io, h, w, c)

        H, E = None, None

        if stains:
            # single-stain images: one column of the reference matrix combined
            # with the matching row of concentrations
            H = self.__reconstruct(HERef[:, 0:1], C[0:1, :], Io, h, w, c)
            E = self.__reconstruct(HERef[:, 1:2], C[1:2, :], Io, h, w, c)

        return Inorm, H, E
