import torch
import torch.nn.functional as F

from cellseg_models_pytorch.utils import tensor_one_hot

from ..weighted_base_loss import WeightedBaseLoss

__all__ = ["DiceLoss"]


class DiceLoss(WeightedBaseLoss):
    def __init__(
        self,
        apply_sd: bool = False,
        apply_ls: bool = False,
        apply_svls: bool = False,
        edge_weight: float = None,
        class_weights: torch.Tensor = None,
        **kwargs,
    ) -> None:
        """Sørensen-Dice Coefficient Loss.

        Optionally applies weights at the object edges and classes.

        Parameters
        ----------
            apply_sd : bool, default=False
                If True, Spectral decoupling regularization will be applied to the
                loss matrix.
            apply_ls : bool, default=False
                If True, Label smoothing will be applied to the target.
            apply_svls : bool, default=False
                If True, spatially varying label smoothing will be applied to the
                target.
            edge_weight : float, default=None
                Weight that is added to object borders.
            class_weights : torch.Tensor, default=None
                Class weights. A tensor of shape (n_classes,).
            **kwargs
                Accepted for interface compatibility; not used by this constructor.
        """
        # NOTE(review): arguments are passed positionally, so this order
        # (class_weights before edge_weight) must match
        # WeightedBaseLoss.__init__ -- verify if the base signature changes.
        super().__init__(apply_sd, apply_ls, apply_svls, class_weights, edge_weight)
        # Lower bound for the dice denominator to avoid division by zero.
        self.eps = 1e-8

    def forward(
        self,
        yhat: torch.Tensor,
        target: torch.Tensor,
        target_weight: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the DICE loss.

        The dice score is computed per pixel over the class dimension. It is
        turned into a per-pixel loss matrix ``1 - dice``. The optional spectral
        decoupling, class weighting and edge weighting are applied to that
        loss matrix, which is then averaged.

        Parameters
        ----------
            yhat : torch.Tensor
                The prediction map (unnormalized; softmax is applied here).
                Shape (B, C, H, W).
            target : torch.Tensor
                The ground truth annotations. Shape (B, H, W).
            target_weight : torch.Tensor, default=None
                The edge weight map. Shape (B, H, W). Required when the loss
                was constructed with ``edge_weight``.
            **kwargs
                Forwarded to the label smoothing helpers.

        Raises
        ------
            ValueError
                If the one-hot encoded target does not match the shape of
                ``yhat``, or if ``edge_weight`` is set but ``target_weight``
                is None.

        Returns
        -------
            torch.Tensor:
                Computed DICE loss (scalar).
        """
        # Fail early with a clear message instead of deep inside the base helper.
        if self.edge_weight is not None and target_weight is None:
            raise ValueError(
                "`target_weight` must be given when `edge_weight` is set."
            )

        yhat_soft = F.softmax(yhat, dim=1)
        num_classes = yhat.shape[1]
        target_one_hot = tensor_one_hot(target, n_classes=num_classes)

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if target_one_hot.shape != yhat.shape:
            raise ValueError(
                f"One-hot target shape {tuple(target_one_hot.shape)} does not "
                f"match prediction shape {tuple(yhat.shape)}."
            )

        if self.apply_svls:
            target_one_hot = self.apply_svls_to_target(
                target_one_hot, num_classes, **kwargs
            )

        if self.apply_ls:
            target_one_hot = self.apply_ls_to_target(
                target_one_hot, num_classes, **kwargs
            )

        # Reduce over the class dimension only -> per-pixel dice, (B, H, W).
        intersection = torch.sum(yhat_soft * target_one_hot, 1)
        union = torch.sum(yhat_soft + target_one_hot, 1)
        dice = 2.0 * intersection / union.clamp_min(self.eps)

        # Regularization and weights must act on the *loss* (lower is better).
        # Applying them to the dice score instead would, e.g., subtract an
        # additive spectral-decoupling penalty from the final loss and thus
        # reward large logits rather than penalize them.
        loss_matrix = 1.0 - dice

        if self.apply_sd:
            loss_matrix = self.apply_spectral_decouple(loss_matrix, yhat)

        if self.class_weights is not None:
            loss_matrix = self.apply_class_weights(loss_matrix, target)

        if self.edge_weight is not None:
            loss_matrix = self.apply_edge_weights(loss_matrix, target_weight)

        return loss_matrix.mean()
