# Copyright 2021 Albert Garcia
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
DSET: an extension of the DSNT (soft-argmax) operations for use in PyTorch computation graphs.
"""

# Package metadata (version, authorship and credits).
__version__ = '0.1.2'
__author__ = 'Albert Garcia'
__credits__ = 'SnT - Interdisciplinary Centre for Security, Reliability and Trust'


from functools import reduce
from operator import mul

import torch
import torch.nn.functional

import numpy as np
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse


def linear_expectation(probs, values):
    """Compute per-axis expectations under the distribution(s) in ``probs``.

    The first two dimensions of ``probs`` are kept (e.g. batch and channel); every
    following dimension is a spatial axis. For spatial axis ``k`` the probabilities are
    marginalised over all *other* spatial axes, and the marginal is used to weight
    ``values[k]``.

    Args:
        probs (torch.Tensor): Probabilities with two leading dims followed by the spatial dims.
        values (sequence of torch.Tensor): One tensor per spatial axis, broadcastable against
            the marginal of that axis.

    Returns:
        torch.Tensor: The expectations stacked along a new last axis, one entry per spatial
        axis, in the same order as the spatial axes of ``probs``.
    """
    ndim = probs.ndimension()
    assert len(values) == ndim - 2
    results = []
    for axis, axis_values in zip(range(2, ndim), values):
        marginal = probs
        # Sum out the other spatial axes, highest index first so lower indices stay valid.
        for other in reversed(range(2, ndim)):
            if other == axis:
                continue
            marginal = marginal.sum(other, keepdim=False)
        results.append((marginal * axis_values).sum(-1, keepdim=False))
    return torch.stack(results, -1)


def normalized_linspace(length, dtype=None, device=None):
    """Generate a vector with values ranging from -1 to 1.

    Note that the values correspond to the "centre" of each cell, so
    -1 and 1 are always conceptually outside the bounds of the vector.
    For example, if length = 4, the following vector is generated:

    ```text
     [ -0.75, -0.25,  0.25,  0.75 ]
     ^              ^             ^
    -1              0             1
    ```

    Args:
        length: The length of the vector (an int, or a tensor that is moved to
            ``device``/``dtype`` before use)
        dtype: Optional dtype of the generated vector
        device: Optional device of the generated vector

    Returns:
        The generated vector
    """
    if isinstance(length, torch.Tensor):
        length = length.to(device, dtype)
    # Each cell is 2 / length wide; the first centre sits half a cell inside -1.
    step = 2.0 / length
    start = -(length - 1.0) / length
    return torch.arange(length, dtype=dtype, device=device) * step + start


def soft_argmax(heatmaps, normalized_coordinates=True):
    """Compute the expected coordinates (soft-argmax) of each heatmap.

    Args:
        heatmaps (torch.Tensor): Probability maps with two leading dims followed by the
            spatial dims (e.g. [B, C, H, W]).
        normalized_coordinates (bool): If True, coordinates are expressed on the normalized
            (-1, 1) grid from ``normalized_linspace``; otherwise as pixel indices 0..n-1.

    Returns:
        torch.Tensor: Coordinates with the last axis reversed relative to the spatial dims,
        i.e. ordered (x, y, ...).
    """
    dtype, device = heatmaps.dtype, heatmaps.device
    spatial = heatmaps.size()[2:]
    if normalized_coordinates:
        grids = [normalized_linspace(n, dtype=dtype, device=device) for n in spatial]
    else:
        grids = [torch.arange(0, n, dtype=dtype, device=device) for n in spatial]
    expectations = linear_expectation(heatmaps, grids)
    # Reverse the last axis with split/cat rather than `flip(-1)` because aten::flip is not
    # yet supported by the ONNX exporter.
    return torch.cat(tuple(reversed(expectations.split(1, -1))), -1)


def dsnt(heatmaps, **kwargs):
    """Differentiable spatial to numerical transform.

    Thin alias of ``soft_argmax``; any keyword arguments (e.g. ``normalized_coordinates``)
    are forwarded unchanged.

    Args:
        heatmaps (torch.Tensor): Spatial representation of locations

    Returns:
        Numerical coordinates corresponding to the locations in the heatmaps.
    """
    coords = soft_argmax(heatmaps, **kwargs)
    return coords


def sharpen_heatmaps(heatmaps, alpha):
    """Sharpen heatmaps by increasing the contrast between high and low probabilities.

    Each heatmap is raised to the power ``alpha`` and then renormalised so that every
    ``heatmaps[b, c]`` slice sums to one again.

    Example:
        Approximate the mode of heatmaps using the approach described by Equation 1 of
        "FlowCap: 2D Human Pose from Optical Flow" by Romero et al.::

            coords = soft_argmax(sharpen_heatmaps(heatmaps, alpha=6))

    Args:
        heatmaps (torch.Tensor): Heatmaps generated by the model, with two leading dims
            (e.g. [B, C]) followed by the spatial dims.
        alpha (float): Sharpness factor. When ``alpha == 1``, normalised heatmaps will be
            unchanged. Use ``alpha > 1`` to actually sharpen the heatmaps.

    Returns:
        The sharpened heatmaps, with the same shape as ``heatmaps``.
    """
    sharpened_heatmaps = heatmaps ** alpha
    # Total mass per heatmap has shape [B, C]; append one singleton per spatial dim so the
    # division broadcasts over the spatial axes instead of being aligned with them.
    totals = sharpened_heatmaps.flatten(2).sum(-1)
    totals = totals.view(tuple(totals.shape) + (1,) * (heatmaps.dim() - 2))
    return sharpened_heatmaps / totals


def flat_softmax(inp):
    """Compute the softmax with all but the first two tensor dimensions combined.

    Each ``inp[b, c]`` slice is treated as one distribution. Its elements are flattened,
    softmax-normalised so they sum to one, and reshaped back.

    Args:
        inp (torch.Tensor): Tensor with at least three dimensions, e.g. [B, C, H, W].

    Returns:
        torch.Tensor: Tensor with the same shape as ``inp``.
    """
    # `flatten`/`reshape` (unlike `view`) also accept non-contiguous inputs such as
    # transposed or sliced heatmaps, copying only when needed.
    flat = inp.flatten(2)
    flat = torch.nn.functional.softmax(flat, -1)
    return flat.reshape(inp.size())


def euclidean_losses(actual, target):
    """Return the Euclidean distance between predicted and target points.

    Each sample must contain `n` points, each with `d` dimensions. For example,
    in the MPII human pose estimation task n=16 (16 joint locations) and
    d=2 (locations are 2D).

    Args:
        actual (Tensor): Predictions (B x L x D)
        target (Tensor): Ground truth target (B x L x D)

    Returns:
        Tensor: Losses (B x L), the L2 norm over the last dimension.
    """
    assert actual.size() == target.size(), 'input tensors must have the same size'
    difference = actual - target
    return torch.norm(difference, p=2, dim=-1, keepdim=False)


def l1_losses(actual, target):
    """Return the mean absolute error per point.

    Args:
        actual (Tensor): Predictions (B x L x D)
        target (Tensor): Ground truth target (B x L x D)

    Returns:
        Tensor: Losses (B x L), the L1 error averaged over the last dimension.
    """
    assert actual.size() == target.size(), 'input tensors must have the same size'
    per_coordinate = torch.nn.functional.l1_loss(actual, target, reduction='none')
    return per_coordinate.mean(dim=-1)


def mse_losses(actual, target):
    """Return the mean squared error per point.

    Args:
        actual (Tensor): Predictions (B x L x D)
        target (Tensor): Ground truth target (B x L x D)

    Returns:
        Tensor: Losses (B x L), the squared L2 error averaged over the last dimension.
    """
    assert actual.size() == target.size(), 'input tensors must have the same size'
    per_coordinate = torch.nn.functional.mse_loss(actual, target, reduction='none')
    return per_coordinate.mean(dim=-1)


def make_gauss(means, size, sigma, normalize=True):
    """Draw Gaussians.

    This function is differentiable with respect to means.

    Note on ordering: `size` expects [..., depth, height, width], whereas
    `means` expects x, y, z, ...

    Args:
        means: coordinates containing the Gaussian means (units: normalized coordinates);
            the last dimension holds the per-axis means, any leading dims are kept
        size: size of the generated images (units: pixels)
        sigma: standard deviation of the Gaussian (units: pixels), shared by all axes
        normalize: when set to True, the returned Gaussians will be normalized so that
            each one sums to (approximately) one over its spatial dims

    Returns:
        A tensor of shape ``means.shape[:-1] + tuple(size)``. When ``normalize`` is False,
        the values are the unscaled exponentials ``exp(-(x - mu)^2 / (2 sigma^2))`` with no
        1 / (sqrt(2 pi) sigma) factor.
    """

    # Negative dim indices, one per spatial axis, starting from the last (width / x) axis.
    dim_range = range(-1, -(len(size) + 1), -1)
    # Cell-centre grids, reversed so they line up with the (x, y, z, ...) order of `means`.
    coords_list = [normalized_linspace(s, dtype=means.dtype, device=means.device)
                   for s in reversed(size)]

    # PDF = exp(-(x - \mu)^2 / (2 \sigma^2))

    # dists <- (x - \mu)^2
    dists = [(x - mean) ** 2 for x, mean in zip(coords_list, means.split(1, -1))]

    # ks <- -1 / (2 \sigma^2)
    # Convert sigma from pixels to normalized units: one pixel spans 2 / s normalized units.
    stddevs = [2 * sigma / s for s in reversed(size)]
    ks = [-0.5 * (1 / stddev) ** 2 for stddev in stddevs]

    exps = [(dist * k).exp() for k, dist in zip(ks, dists)]

    # Combine dimensions of the Gaussian: each 1-D factor is unsqueezed at every other
    # spatial dim so the product broadcasts into the full separable N-D Gaussian.
    gauss = reduce(mul, [
        reduce(lambda t, d: t.unsqueeze(d), filter(lambda d: d != dim, dim_range), dist)
        for dim, dist in zip(dim_range, exps)
    ])

    if not normalize:
        return gauss

    # Normalize the Gaussians; the tiny epsilon avoids division by zero.
    val_sum = reduce(lambda t, dim: t.sum(dim, keepdim=True), dim_range, gauss) + 1e-24
    return gauss / val_sum


def average_loss(losses, mask=None):
    """Calculate the average of per-location losses.

    Args:
        losses (Tensor): Per-location losses (B x L)
        mask (Tensor, optional): Mask of points to include in the loss calculation
            (B x L), defaults to including everything

    Returns:
        Tensor: Scalar mean loss. Masked-out entries are excluded from both the sum and
        the count, and the count is clamped to at least 1.
    """
    if mask is None:
        denom = losses.numel()
    else:
        assert mask.size() == losses.size(), 'mask must be the same size as losses'
        losses = losses * mask
        denom = mask.sum()

    # Guard against dividing by zero (empty input or an all-zero mask).
    denom = max(denom, 1) if isinstance(denom, int) else denom.clamp(1)
    return losses.sum() / denom


def _kl(p, q, ndims):
    eps = 1e-24
    unsummed_kl = p * ((p + eps).log() - (q + eps).log())
    kl_values = reduce(lambda t, _: t.sum(-1, keepdim=False), range(ndims), unsummed_kl)
    return kl_values


def _js(p, q, ndims):
    """Jensen-Shannon divergence between p and q, summed over the trailing ``ndims`` dims."""
    midpoint = 0.5 * (p + q)
    kl_p = _kl(p, midpoint, ndims)
    kl_q = _kl(q, midpoint, ndims)
    return 0.5 * kl_p + 0.5 * kl_q


def _divergence_reg_losses(heatmaps, mu_t, sigma_t, divergence):
    """Compare each heatmap against a target Gaussian using ``divergence``.

    The target Gaussians are centred at ``mu_t`` (normalized units) with standard
    deviation ``sigma_t`` (pixels), drawn by ``make_gauss`` on the heatmaps' spatial size.
    ``divergence(p, q, ndims)`` must reduce over the trailing ``ndims`` spatial dims.
    """
    ndims = mu_t.size(-1)
    assert heatmaps.dim() == ndims + 2, 'expected heatmaps to be a {}D tensor'.format(ndims + 2)
    assert heatmaps.size()[:-ndims] == mu_t.size()[:-1]

    target = make_gauss(mu_t, heatmaps.size()[2:], sigma_t)
    return divergence(heatmaps, target, ndims)


def kl_reg_losses(heatmaps, mu_t, sigma_t):
    """Kullback-Leibler divergence between each heatmap and its target Gaussian.

    Args:
        heatmaps (torch.Tensor): Heatmaps generated by the model
        mu_t (torch.Tensor): Centers of the target Gaussians (in normalized units)
        sigma_t (float): Standard deviation of the target Gaussians (in pixels)

    Returns:
        Per-location KL divergences.
    """
    return _divergence_reg_losses(heatmaps, mu_t, sigma_t, divergence=_kl)


def js_reg_losses(heatmaps, mu_t, sigma_t):
    """Jensen-Shannon divergence between each heatmap and its target Gaussian.

    Args:
        heatmaps (torch.Tensor): Heatmaps generated by the model
        mu_t (torch.Tensor): Centers of the target Gaussians (in normalized units)
        sigma_t (float): Standard deviation of the target Gaussians (in pixels)

    Returns:
        Per-location JS divergences.
    """
    return _divergence_reg_losses(heatmaps, mu_t, sigma_t, divergence=_js)


def variance_reg_losses(heatmaps, sigma_t):
    """Calculate the loss between heatmap variances and target variance.

    Note that this is slightly different from the version used in the
    DSNT paper. This version uses pixel units for variance, which
    produces losses that are larger by a constant factor.

    Args:
        heatmaps (torch.Tensor): Heatmaps generated by the model
        sigma_t (float): Target standard deviation (in pixels)

    Returns:
        Per-location sum of square errors for variance.
    """
    spatial = list(heatmaps.size()[2:])
    grids = [normalized_linspace(n, dtype=heatmaps.dtype, device=heatmaps.device)
             for n in spatial]

    # First moment: mu = E[X]
    mu = linear_expectation(heatmaps, grids)

    # Second central moment: var = E[(X - mu)^2]
    sq_dev = [(grid - m.squeeze(0)) ** 2 for grid, m in zip(grids, mu.split(1, -1))]
    var = linear_expectation(heatmaps, sq_dev)

    # Normalized -> pixel units: one pixel spans 2 / n normalized units, so scale by (n / 2)^2.
    size = torch.tensor(spatial, dtype=var.dtype, device=var.device)
    pixel_var = var * (size / 2) ** 2
    return ((pixel_var - sigma_t ** 2) ** 2).sum(-1, keepdim=False)


def normalized_to_pixel_coordinates(coords, size):
    """Convert from normalized coordinates to pixel coordinates.

    Args:
        coords: Coordinate tensor, where elements in the last dimension are ordered as (x, y, ...).
        size: Number of pixels in each spatial dimension, ordered as (..., height, width).
            May be a tensor or any sequence of numbers when ``coords`` is a tensor.

    Returns:
        `coords` in pixel coordinates.
    """
    if torch.is_tensor(coords):
        # Build `size` on the coords' dtype/device (so lists/tuples are accepted too) and
        # reverse it to match the (x, y, ...) coordinate order.
        size = torch.as_tensor(size, dtype=coords.dtype, device=coords.device).flip(-1)
    # NOTE(review): for non-tensor coords `size` is used as given (not reversed).
    return 0.5 * ((coords + 1) * size - 1)


def pixel_to_normalized_coordinates(coords, size):
    """Convert from pixel coordinates to normalized coordinates.

    Args:
        coords: Coordinate tensor, where elements in the last dimension are ordered as (x, y, ...).
        size: Number of pixels in each spatial dimension, ordered as (..., height, width).
            May be a tensor or any sequence of numbers when ``coords`` is a tensor.

    Returns:
        `coords` in normalized coordinates.
    """
    if torch.is_tensor(coords):
        # Build `size` on the coords' dtype/device (so lists/tuples are accepted too) and
        # reverse it to match the (x, y, ...) coordinate order.
        size = torch.as_tensor(size, dtype=coords.dtype, device=coords.device).flip(-1)
    # NOTE(review): for non-tensor coords `size` is used as given (not reversed).
    return ((2 * coords + 1) / size) - 1



#################################################### DSET Extension ###########################################################


def linear_expectation_covariance(heatmaps, aux_values):
    """Compute the covariance term E[(Y - mu_y) * (X - mu_x)] of each heatmap.

    Args:
        heatmaps (torch.Tensor): Heatmaps with shape [B, E, H, W].
        aux_values (sequence of torch.Tensor): Two deviation tensors: ``aux_values[0]`` along
            the H axis (last dim H) and ``aux_values[1]`` along the W axis (last dim W). Leading
            dims must broadcast to [B, E]; missing leading dims (e.g. a batch dim of size 1
            dropped by a caller's ``squeeze(0)``) are restored. The sequence is not modified.

    Returns:
        torch.Tensor: Covariances with shape [B, E].
    """
    dev_rows, dev_cols = aux_values[0], aux_values[1]
    # Outer product over (H, W) by broadcasting: [..., H, 1] * [..., 1, W] -> [..., H, W].
    # This avoids materialising repeated copies and never writes into the caller's list.
    products = dev_rows.unsqueeze(-1) * dev_cols.unsqueeze(-2)
    missing = heatmaps.dim() - products.dim()
    if missing > 0:
        products = products.reshape((1,) * missing + tuple(products.shape))
    assert(products.size() == heatmaps.size()),'The size of the heatmaps '+str(heatmaps.size())+' is not equal to those of the aux_values '+str(products.size())
    return (heatmaps * products).sum(dim=(-2, -1))


def compute_heatmaps_statistics(heatmaps, pixel_units=False):
    """Calculate the expectation, variance and covariance of each heatmap.

    Args:
        heatmaps (torch.Tensor): Heatmaps generated by the model with shape [B, E, H, W]
        pixel_units (bool, default = False): If True, all values are returned in pixel units.
            Otherwise the expectation is in normalized coordinates and the variance and
            covariance are in (squared) normalized units.

    Returns:
        Tuple ``(mu, var, cov)`` for each ellipse E within each batch image B:
        ``mu`` with shape [B, E, 2] ordered (x, y), ``var`` with shape [B, E, 2] ordered
        (var_x, var_y), and ``cov`` with shape [B, E].
    """
    spatial_dims = list(heatmaps.size()[2:])
    grids = [normalized_linspace(n, dtype=heatmaps.dtype, device=heatmaps.device)
             for n in spatial_dims]

    # First moment: mu = E[X], ordered like the heatmap axes (y, x).
    mu = linear_expectation(heatmaps, grids)

    # Centred grids (X - mu) per axis.
    centred = [grid - m.squeeze(0) for grid, m in zip(grids, mu.split(1, -1))]

    # Second central moments: var = E[(X - mu)^2], cov = E[(X - mu_x) * (Y - mu_y)].
    var = linear_expectation(heatmaps, [c ** 2 for c in centred])
    cov = linear_expectation_covariance(heatmaps, centred)

    # Reorder (y, x) -> (x, y) with split/cat rather than `flip(-1)` because aten::flip is
    # not yet supported by the ONNX exporter.
    mu = torch.cat(tuple(reversed(mu.split(1, -1))), -1)
    var = torch.cat(tuple(reversed(var.split(1, -1))), -1)

    if not pixel_units:
        return mu, var, cov

    # One pixel spans 2 / n normalized units, so scale offsets by n / 2 per axis.
    size_hw = torch.tensor(spatial_dims, dtype=var.dtype, device=var.device)
    size_wh = torch.tensor(spatial_dims[::-1], dtype=var.dtype, device=var.device)
    var = var * (size_wh / 2) ** 2
    cov = cov * (size_hw[0] / 2) * (size_hw[1] / 2)
    mu = normalized_to_pixel_coordinates(mu, size_hw)
    return mu, var, cov


def construct_covariance_matrices(var, cov):
    """Construct all covariance matrices of a batch.

    Each matrix is ``[[var_x, cov], [cov, var_y]]``. The construction is vectorised, so it
    works for any leading shape (including empty batches) and keeps gradients flowing to
    both inputs.

    Args:
        var (torch.Tensor): Variances computed from the heatmaps with shape [B, E, 2]
        cov (torch.Tensor): Covariances computed from the heatmaps with shape [B, E]

    Returns:
        Tensor containing all the covariances matrices with shape [B, E, 2, 2]
        (dtype of ``var``).
    """
    # Match var's dtype (the per-element assignment this replaces implicitly cast cov).
    cov = cov.to(var.dtype)
    var_x, var_y = var[..., 0], var[..., 1]
    row_0 = torch.stack((var_x, cov), dim=-1)
    row_1 = torch.stack((cov, var_y), dim=-1)
    return torch.stack((row_0, row_1), dim=-2)


def covariance_matrices_to_parameters(sigmas):
    """Compute the axes and rotation from all covariance matrices of a batch.

    Args:
        sigmas (torch.Tensor): Symmetric covariance matrices with shape [B, E, 2, 2]
            (any leading shape is accepted).

    Returns:
        Tensor containing all the extracted parameters (axis_1, axis_2, theta [rads]) with
        shape [B, E, 3]. The axes are the square roots of the eigenvalues in ascending order.
        ``theta`` is the angle of the eigenvector paired with ``axis_1``; the eigenvector sign
        is arbitrary, so ``theta`` is only defined modulo pi.
    """
    # Covariance matrices are symmetric, so the batched symmetric eigensolver applies. It
    # returns real eigenvalues and orthonormal eigenvectors as columns. (`torch.eig`, used
    # previously, has been removed from recent PyTorch releases.)
    eigenvalues, eigenvectors = torch.linalg.eigh(sigmas)
    # abs() guards against tiny negative eigenvalues caused by round-off.
    axes = torch.sqrt(torch.abs(eigenvalues))
    # Orientation of the first eigenvector column (x component, y component).
    theta = torch.atan2(eigenvectors[..., 1, 0], eigenvectors[..., 0, 0])
    return torch.cat((axes, theta.unsqueeze(-1)), dim=-1)
    

def plot_ellipse_from_covariance_matrix(center, sigma, original_heatmap, alpha_heatmap=1.0, heatmap_cmap='gray', color='r'):
    """Plot the ellipse derived from the center and the covariance matrix 'sigma' along with the original heatmap.
    If a batch of centers or a batch of sigmas is given, then only the first sample will be plotted.

    The ellipse semi-axes are the square roots of the eigenvalues of ``sigma`` (via
    ``covariance_matrices_to_parameters``), i.e. a one-standard-deviation contour.

    Args:
        center (torch.Tensor | np.ndarray | sequence): Ellipse center (x, y) with shape [2],
            or a batch with shape [B, E, 2] (only ``center[0, 0]`` is used). Presumably in
            pixel units so that it lines up with the heatmap drawn by ``imshow``.
        sigma (torch.Tensor | np.ndarray): Covariance matrix with shape [2, 2], or a batch
            with shape [B, E, 2, 2] (only ``sigma[0, 0]`` is used).
        original_heatmap (torch.Tensor | np.ndarray): Heatmap with shape [H, W], or a batch
            with shape [B, E, H, W] (only ``[0, 0]`` is used).
        alpha_heatmap (float): Float between 0 and 1 to apply transparency to the plotted heatmap
        heatmap_cmap (str): Color map to be used for the heatmap
        color (str): Matplotlib color of the ellipse outline and the center marker

    Returns:
        None. The ellipse, the center marker and the heatmap are drawn on the current
        matplotlib axes (``plt.gca()``).
    """
    # Matplotlib needs plain CPU data: detach from autograd and move off the GPU first.
    if isinstance(center, torch.Tensor):
        center = center.detach().cpu().numpy()
    center = np.asarray(center)
    if len(center.shape) == 3:
        center = center[0, 0]

    if isinstance(sigma, torch.Tensor):
        sigma = sigma.detach().cpu()
    else:
        sigma = torch.Tensor(np.asarray(sigma))
    if len(sigma.shape) == 4:
        sigma = sigma[0, 0]
    if len(sigma.shape) == 2:
        sigma = sigma.unsqueeze(0).unsqueeze(0)

    params = covariance_matrices_to_parameters(sigma)
    # Convert to Python floats so matplotlib never receives tensors.
    axis_1, axis_2, theta = (float(value) for value in params[0, 0, :3])
    angle_deg = np.rad2deg(theta)
    plot_axes = plt.gca()
    ellipse = Ellipse(xy=center, width=axis_1 * 2, height=axis_2 * 2,
                      angle=angle_deg, edgecolor=color, linestyle='-',
                      linewidth=2, fill=False)
    plot_axes.add_artist(ellipse)

    if isinstance(original_heatmap, torch.Tensor):
        original_heatmap = original_heatmap.detach().cpu().numpy()
    if len(original_heatmap.shape) == 4:
        original_heatmap = original_heatmap[0, 0]

    plt.scatter(*center, c=color)
    plt.imshow(original_heatmap, cmap=heatmap_cmap, alpha=alpha_heatmap)


def gaussian_heatmap(heatmap_size, mean, cov, normalize=True, return_tensor=True):
    """Generate a Gaussian heatmap with specific mean and covariance matrix.

    The density is evaluated at integer pixel positions, so ``mean`` (x, y) and ``cov`` are
    in pixel units.

    Args:
        heatmap_size (torch.Tensor): Heatmap size to be generated [H, W]
        mean (torch.Tensor | np.ndarray | sequence): Mean values (x, y) with shape [2]
        cov (torch.Tensor | np.ndarray): Covariance matrix with shape [2, 2]
        normalize (bool, default = True): Whether to normalize the generated heatmap to sum to one
        return_tensor (bool, default = True): Whether to return PyTorch Tensor instead of Numpy array

    Returns:
        The Gaussian heatmap with shape [H, W] (float64).
    """
    # SciPy works on NumPy data; tensors that require grad or live on the GPU cannot be
    # converted implicitly, so detach and move them to the CPU first.
    if isinstance(mean, torch.Tensor):
        mean = mean.detach().cpu().numpy()
    if isinstance(cov, torch.Tensor):
        cov = cov.detach().cpu().numpy()

    x = np.arange(0, heatmap_size[1], 1)
    y = np.arange(0, heatmap_size[0], 1)
    xx, yy = np.meshgrid(x, y)
    points = np.stack((xx, yy), axis=-1)
    # multivariate_normal.pdf squeezes singleton dimensions (e.g. H == 1 or W == 1),
    # so restore the [H, W] layout explicitly.
    pdf = np.reshape(multivariate_normal.pdf(points, mean=mean, cov=cov), (y.size, x.size))
    if normalize:
        pdf = pdf / pdf.sum()
    if return_tensor:
        pdf = torch.from_numpy(pdf)
    return pdf


def generate_gaussian_heatmaps(heatmap_size, means, sigmas, normalize=True, return_tensor=True):
    """Generate all the Gaussian heatmaps from a batch of means and covariance matrices.

    Args:
        heatmap_size (torch.Tensor): Heatmap size to be generated [H, W]
        means (torch.Tensor | np.ndarray): Means values (x, y, pixel units) with shape [B, E, 2]
        sigmas (torch.Tensor | np.ndarray): Covariances matrices (pixel units) with shape [B, E, 2, 2]
        normalize (bool, default = True): Whether to normalize the generated heatmaps
        return_tensor (bool, default = True): Whether to return PyTorch Tensor instead of Numpy array

    Returns:
        All the Gaussian heatmaps with shape [B, E, H, W], as a torch.Tensor or np.ndarray
        depending on ``return_tensor``.
    """
    # Individual heatmaps are NumPy arrays when return_tensor is False, so they must be
    # stacked with NumPy; torch.stack only accepts tensors.
    stack = torch.stack if return_tensor else np.stack
    batches, locations = means.shape[0:2]

    all_heatmaps = []
    for batch_num in range(batches):
        batch_heatmaps = [
            gaussian_heatmap(heatmap_size, means[batch_num, location_num],
                             sigmas[batch_num, location_num],
                             normalize=normalize, return_tensor=return_tensor)
            for location_num in range(locations)
        ]
        all_heatmaps.append(stack(batch_heatmaps, 0))

    return stack(all_heatmaps, 0)


def normalized_cov_mtx_to_cov_mtx(cov, size):
    """Convert from normalized covariance matrix to pixel covariance matrix.

    One pixel spans ``2 / n`` normalized units along an axis with ``n`` pixels (see
    ``normalized_linspace``), so a normalized offset ``d`` corresponds to ``d * n / 2``
    pixels. Entry (i, j) is therefore scaled by ``(n_i / 2) * (n_j / 2)`` with the axes
    ordered (x, y). This matches the ``pixel_units`` conversion of variances and covariances
    in ``compute_heatmaps_statistics``.

    Args:
        cov (torch.Tensor): Normalized covariance matrix with shape [2, 2] (leading batch
            dims broadcast)
        size (torch.Tensor | sequence): Target size to be scaled [H, W]

    Returns:
        Covariance tensor in pixel units with the same shape as ``cov``
    """
    # Reverse [H, W] -> [W, H] to match the (x, y) order of the covariance entries.
    half_extent = torch.as_tensor(size, dtype=cov.dtype, device=cov.device).flip(0) / 2
    scaling_mtx = half_extent.unsqueeze(-1) * half_extent.unsqueeze(0)
    return cov * scaling_mtx


def cov_mtx_to_normalized_cov_mtx(cov, size):
    """Convert from covariance matrix in pixels to normalized covariance matrix.

    One pixel spans ``2 / n`` normalized units along an axis with ``n`` pixels (see
    ``normalized_linspace``), so entry (i, j) is divided by ``(n_i / 2) * (n_j / 2)`` with
    the axes ordered (x, y).

    Args:
        cov (torch.Tensor): Pixel covariance matrix with shape [2, 2] (leading batch dims
            broadcast)
        size (torch.Tensor | sequence): Target size to be considered for normalization [H, W]

    Returns:
        Covariance tensor in normalized units with the same shape as ``cov``
    """
    # Reverse [H, W] -> [W, H] to match the (x, y) order of the covariance entries.
    half_extent = torch.as_tensor(size, dtype=cov.dtype, device=cov.device).flip(0) / 2
    scaling_mtx = half_extent.unsqueeze(-1) * half_extent.unsqueeze(0)
    return cov / scaling_mtx

