# elatentlpips: https://github.com/mingukkang/elatentlpips
# The CC-BY-NC license
# See license file or visit https://github.com/mingukkang/elatentlpips for details

# elatentlpips/elatentlpips.py


from __future__ import absolute_import

from torch.autograd import Variable
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import elatentlpips
from elatentlpips.vgg16 import CalibratedLatentVGG16BN
from elatentlpips.ada_aug import AdaAugment


# AdaAugment op groups keyed by a single letter. Every preset name below is a
# concatenation of these letters, and a preset enables (with value 1) every op
# of each of its groups, in group order. (Letters presumably follow the
# StyleGAN2-ADA blit / geom / color / cutout naming -- TODO confirm.)
_ADA_OP_GROUPS = {
    'b': ('xflip', 'rotate90', 'xint'),
    'g': ('scale', 'rotate', 'aniso', 'xfrac'),
    'c': ('brightness', 'contrast', 'saturation'),
    'o': ('cutout',),
}

# Preset name -> keyword arguments for AdaAugment(**kwargs).
ada_augpipe = {
    preset: {op: 1 for letter in preset for op in _ADA_OP_GROUPS[letter]}
    for preset in ('b', 'g', 'c', 'o', 'co', 'bg', 'bgc', 'bgco')
}


def print_network(net):
    """Print ``net``'s repr, then the total number of elements in its parameters.

    Parameters
    ----------
    net : torch.nn.Module (or anything exposing ``parameters()`` whose items have ``numel()``)
    """
    total = sum(param.numel() for param in net.parameters())
    print('Network', net)
    print('Total number of parameters: %d' % total)

def spatial_average(in_tens, keepdim=True):
    """Mean of ``in_tens`` over dims 2 and 3 (the spatial axes, assuming NCHW layout).

    With ``keepdim=True`` the reduced dims are kept with size 1.
    """
    spatial_dims = (2, 3)
    return torch.mean(in_tens, dim=spatial_dims, keepdim=keepdim)

def upsample(in_tens, out_HW=(64, 64)):
    """Bilinearly resize ``in_tens`` to spatial size ``out_HW``.

    Parameters
    ----------
    in_tens : torch.Tensor
        4-D (N, C, H, W) tensor (bilinear interpolation requires 4-D input).
    out_HW : int, tuple or torch.Size
        Target (H, W). An explicit size is used, so H and W may scale by
        different factors.

    Returns
    -------
    torch.Tensor of shape (N, C, *out_HW).
    """
    # Same call nn.Upsample(size=..., mode='bilinear', align_corners=False)
    # makes internally, without building a throwaway module per call.
    return F.interpolate(in_tens, size=out_HW, mode='bilinear', align_corners=False)

# Latent version of learned perceptual metric (E-LatentLPIPS)
class ELatentLPIPS(nn.Module):
    """Latent-space learned perceptual metric (E-LatentLPIPS).

    Compares two batches of encoder latents. The trunk network
    (``CalibratedLatentVGG16BN``) is built for the encoder's latent channel
    count. Squared differences of channel-normalized trunk features are
    reduced per layer and summed, with optional differentiable augmentation
    of the inputs and an optional L1 term on the latents.
    """

    # Latent channel count produced by each supported encoder.
    _LATENT_CHANNELS = {'sd15': 4, 'sd21': 4, 'sdxl': 4, 'sd3': 16, 'flux': 16}

    # (shift, scale) applied as ``(x - shift) * scale`` by forward(normalize=True).
    _LATENT_NORM = {
        'sd15': (0.0, 0.18215),
        'sd21': (0.0, 0.18215),
        'sdxl': (0.0, 0.13025),
        'sd3': (0.0609, 1.5305),
        'flux': (0.1159, 0.3611),
    }

    def __init__(self, pretrained=True, net='vgg16', elatentlpips=True, spatial=False,
        pnet_rand=False, pnet_tune=False, use_dropout=True, eval_mode=True,
        verbose=True, encoder="sd3", augment="bg"):
        """ Initializes a perceptual loss torch.nn.Module

        Parameters
        ---------------------------------
        pretrained : bool
            Only in effect when elatentlpips=True.
            [True] downloads (if not already cached in ./ckpt) the tuned checkpoint for
                   `encoder` and loads it strictly into the whole module (trunk and linear heads)
            [False] linear layers keep their random initialization
        net : str
            ['vgg16'] (alias 'vgg') is the only base/trunk network available
        elatentlpips : bool
            [True] reduces each layer's squared feature difference with a learned 1x1 linear head
            [False] baseline: sums the squared feature difference over channels (no learned heads)
        spatial : bool
            [True] forward returns a per-pixel distance map upsampled to the input's spatial size
            [False] forward returns the spatially averaged distance
        pnet_rand : bool
            Passed to the trunk as pretrained=not pnet_rand.
            [False] means trunk loaded with its pretrained weights
            [True] means randomly initialized trunk
        verbose : bool
            Print setup / checkpoint messages.
        encoder : str
            Specifies the type of latent space generated by the encoder.
            Available options: ['sd15', 'sd21', 'sdxl', 'sd3', 'flux']. Default is ['sd3'].
        augment : str or None
            Preset of differentiable augmentations applied to the input when forward(ensembling=True).
            Available options: ['b', 'g', 'c', 'o', 'co', 'bg', 'bgc', 'bgco']. Default is ['bg'].
            None disables augmentation (forward must then be called with ensembling=False).

        The following parameters should only be changed if training the network

        pnet_tune
            [False] keep base/trunk frozen
            [True] tune the base/trunk network
        use_dropout : bool
            [True] to use dropout when training linear layers
            [False] for no dropout when training linear layers
        eval_mode : bool
            [True] is for test mode (default)
            [False] is for training mode

        Raises
        ------
        NotImplementedError
            If `encoder` or `net` is not supported.
        """
        super().__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], spatial [%s]'%
                 ('ELatentLPIPS' if elatentlpips else 'baseline', net, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.elatentlpips = elatentlpips # false means baseline of just summing over channels
        self.encoder = encoder
        self.augment = augment

        # Validate the encoder up front: previously an unknown encoder left the
        # channel count unbound and crashed with UnboundLocalError.
        if self.encoder not in self._LATENT_CHANNELS:
            raise NotImplementedError('Encoder %s not implemented' % self.encoder)

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = CalibratedLatentVGG16BN
            self.chns = [64, 128, 256, 512, 512]
        else:
            raise NotImplementedError('Network %s not implemented' % net)

        self.L = len(self.chns)

        num_latent_channels = self._LATENT_CHANNELS[self.encoder]
        self.net = net_type(num_latent_channels, self.encoder, pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if elatentlpips:
            # Each head is registered twice (as lin<k> and as lins[k]), so
            # state_dict() exposes both key sets; the strict checkpoint load
            # below must match them, so keep both registrations.
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = nn.ModuleList([self.lin0, self.lin1, self.lin2, self.lin3, self.lin4])

            if pretrained:
                # NOTE(review): with net='vgg' the file name becomes
                # '<encoder>_latest_vgg_tuned.pth'; confirm that file exists upstream.
                ckpt_name = f'{self.encoder}_latest_{self.pnet_type}_tuned.pth'
                url_path = f'https://huggingface.co/Mingguksky/elatentlpips/resolve/main/elatentlpips_ckpt/{ckpt_name}'
                ckpt_path = f"./ckpt/{ckpt_name}"
                os.makedirs('./ckpt', exist_ok=True)
                if not os.path.exists(ckpt_path):
                    torch.hub.download_url_to_file(url_path, ckpt_path)

                if verbose:
                    print(f"Loading linear heads from: {url_path}")

                # map_location='cpu' lets a checkpoint saved from GPU tensors load on
                # CPU-only hosts; load_state_dict copies into this module's own tensors.
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.load_state_dict(ckpt, strict=True)

        if augment is not None:
            # Loading above happens before the augmenter is registered, so its
            # state is not part of the strict checkpoint match.
            self.augment = AdaAugment(**ada_augpipe[augment]).train().requires_grad_(False)
            self.augment.p = torch.tensor(1.0)

        if eval_mode:
            # NOTE(review): eval() recurses into child modules, so this also switches
            # the augmenter out of the .train() mode set above -- confirm AdaAugment
            # does not depend on self.training.
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False, ensembling=True, add_l1_loss=True):
        """Compute the E-LatentLPIPS distance between two latent batches.

        Parameters
        ----------
        in0, in1 : torch.Tensor
            Encoder latents of identical shape; assumed (N, C, H, W) (the L1 term reduces dims 1-3).
        retPerLayer : bool
            Also return the list of per-layer distances.
        normalize : bool
            Inputs are raw encoder outputs: apply the encoder-specific ``(x - shift) * scale`` first.
        ensembling : bool
            Pass both inputs together through ``self.augment`` before the trunk
            (presumably the same random transform for both -- TODO confirm in AdaAugment).
        add_l1_loss : bool
            Add the per-sample mean absolute difference of the (normalized, un-augmented) latents.

        Returns
        -------
        torch.Tensor of shape (N, 1, 1, 1), or (N, 1, H, W) when ``spatial`` is set;
        ``(val, per_layer_list)`` when retPerLayer is True.

        Raises
        ------
        ValueError
            If ensembling is requested but the module was built with augment=None.
        """
        if normalize:
            # Both in0 and in1 should be the outputs of the encoder model, and they should not be normalized.
            shift, scale = self._LATENT_NORM[self.encoder]
            in0 = (in0 - shift) * scale
            in1 = (in1 - shift) * scale

        if add_l1_loss:
            # Shaped (N, 1, 1, 1) so it broadcasts onto the perceptual term.
            l1_loss = F.l1_loss(in0, in1, reduction='none').mean(dim=(1, 2, 3))[:, None, None, None]

        if ensembling:
            if self.augment is None:
                raise ValueError("Augmentation is not enabled.")
            in0, in1 = self.augment(in0, in1)

        outs0, outs1 = self.net.forward(in0), self.net.forward(in1)

        # Squared difference of channel-normalized features, one entry per trunk layer.
        diffs = []
        for kk in range(self.L):
            feat0 = elatentlpips.normalize_tensor(outs0[kk])
            feat1 = elatentlpips.normalize_tensor(outs1[kk])
            diffs.append((feat0 - feat1) ** 2)

        # Collapse channels to a single map per layer: learned heads or plain sum.
        if self.elatentlpips:
            per_layer = [self.lins[kk](diffs[kk]) for kk in range(self.L)]
        else:
            per_layer = [d.sum(dim=1, keepdim=True) for d in diffs]

        if self.spatial:
            res = [upsample(p, out_HW=in0.shape[2:]) for p in per_layer]
        else:
            res = [spatial_average(p, keepdim=True) for p in per_layer]

        val = sum(res)

        if add_l1_loss:
            val = val + l1_loss

        if retPerLayer:
            return (val, res)
        return val

class NetLinLayer(nn.Module):
    """Bias-free 1x1 convolution (a per-pixel linear map over channels).

    With ``use_dropout`` the conv is preceded by ``nn.Dropout()``. The layers live
    in ``self.model`` (an ``nn.Sequential``), so the conv weight's state_dict key is
    ``model.0.weight`` without dropout and ``model.1.weight`` with it.
    """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        conv = nn.Conv2d(chn_in, chn_out, kernel_size=1, stride=1, padding=0, bias=False)
        if use_dropout:
            self.model = nn.Sequential(nn.Dropout(), conv)
        else:
            self.model = nn.Sequential(conv)

    def forward(self, x):
        return self.model(x)

class Dist2LogitLayer(nn.Module):
    """Map a pair of distance maps to a score, in (0, 1) when ``use_sigmoid`` is True.

    Five per-pixel features are derived from (d0, d1) and fed through three
    1x1 convolutions with LeakyReLU(0.2) between them.
    """
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super().__init__()

        def pointwise(c_in, c_out):
            # 1x1 conv with bias: a fully connected layer applied at every pixel.
            return nn.Conv2d(c_in, c_out, 1, stride=1, padding=0, bias=True)

        stack = [
            pointwise(5, chn_mid),
            nn.LeakyReLU(0.2, True),
            pointwise(chn_mid, chn_mid),
            nn.LeakyReLU(0.2, True),
            pointwise(chn_mid, 1),
        ]
        if use_sigmoid:
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, d0, d1, eps=0.1):
        # Raw distances, their difference, and eps-stabilized ratios (5 channels total
        # when d0 and d1 each have one channel).
        features = (d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps))
        return self.model.forward(torch.cat(features, dim=1))

class BCERankingLoss(nn.Module):
    """Binary cross-entropy between a learned two-distance score and a judgment.

    ``judge`` is mapped with ``(judge + 1) / 2`` to form the BCE target (so it is
    presumably in [-1, 1]). The network output of the latest call is kept on
    ``self.logit``.
    """
    def __init__(self, chn_mid=32):
        super().__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = nn.BCELoss()

    def forward(self, d0, d1, judge):
        target = (judge + 1.) / 2.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, target)

# L2, DSSIM metrics
class FakeNet(nn.Module):
    """Parameter-free base for the reference metrics below.

    Stores ``use_gpu`` (whether subclasses move their result with ``.cuda()``)
    and ``colorspace`` (which comparison space subclasses use).
    """
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super().__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu

class L2(FakeNet):
    """Mean squared error between two images (batch size 1 only).

    ``self.colorspace`` selects the comparison space:
    'RGB' compares the tensors directly; 'Lab' converts both with the
    package's ``tensor2tensorlab``/``tensor2np`` helpers and uses
    ``elatentlpips.l2`` with range 100.
    """
    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 distance between ``in0`` and ``in1``.

        ``retPerLayer`` is accepted for call compatibility and ignored.

        Returns
        -------
        'RGB': tensor of shape (N,) holding the per-sample mean squared difference
        (on the inputs' device). 'Lab': 1-element tensor, moved with ``.cuda()``
        when ``use_gpu``.

        Raises
        ------
        ValueError
            If the batch size is not 1, or ``colorspace`` is neither 'RGB' nor 'Lab'
            (previously this silently returned None).
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if in0.size(0) != 1:
            raise ValueError('L2 currently only supports batch size 1, got %d' % in0.size(0))

        if self.colorspace == 'RGB':
            # Mean over channel and both spatial dims, one value per sample.
            return ((in0 - in1) ** 2).mean(dim=(1, 2, 3))

        if self.colorspace == 'Lab':
            value = elatentlpips.l2(elatentlpips.tensor2np(elatentlpips.tensor2tensorlab(in0.data, to_norm=False)),
                elatentlpips.tensor2np(elatentlpips.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
            # Variable() is a no-op wrapper in modern PyTorch; build the tensor directly.
            ret_var = torch.Tensor((value,))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var

        raise ValueError("Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))

class DSSIM(FakeNet):
    """Structural dissimilarity between two images (batch size 1 only).

    'RGB' compares ``elatentlpips.tensor2im`` images with range 255; 'Lab'
    compares the ``tensor2tensorlab``/``tensor2np`` conversions with range 100.
    Both use ``elatentlpips.dssim``.
    """
    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM between ``in0`` and ``in1`` as a 1-element tensor.

        ``retPerLayer`` is accepted for call compatibility and ignored. The result
        is moved with ``.cuda()`` when ``use_gpu``.

        Raises
        ------
        ValueError
            If the batch size is not 1, or ``colorspace`` is neither 'RGB' nor 'Lab'
            (previously this crashed with UnboundLocalError).
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if in0.size(0) != 1:
            raise ValueError('DSSIM currently only supports batch size 1, got %d' % in0.size(0))

        if self.colorspace == 'RGB':
            value = elatentlpips.dssim(1.*elatentlpips.tensor2im(in0.data), 1.*elatentlpips.tensor2im(in1.data), range=255.).astype('float')
        elif self.colorspace == 'Lab':
            value = elatentlpips.dssim(elatentlpips.tensor2np(elatentlpips.tensor2tensorlab(in0.data, to_norm=False)),
                elatentlpips.tensor2np(elatentlpips.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float')
        else:
            raise ValueError("Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))

        # Variable() is a no-op wrapper in modern PyTorch; build the tensor directly.
        ret_var = torch.Tensor((value,))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var