# AUTOGENERATED! DO NOT EDIT! File to edit: ../00_core.ipynb.

# %% auto 0
__all__ = ['compute_mean_and_covariance', 'compute_mahalanobis_distance', 'OODMetric']

# %% ../00_core.ipynb 3
from fastcore.utils import *
import numpy as np

# %% ../00_core.ipynb 4
def compute_mean_and_covariance(
    embdedding: np.ndarray, # (n_sample, n_dim) n_sample - sample size of training set, n_dim - dimension of the embedding
    labels: np.ndarray, # (n_sample, ) n_sample - sample size of training set
) -> Tuple[np.ndarray, np.ndarray]: # Class means of dimension (n_class, n_dim), one row per unique label in sorted order, and the shared covariance matrix of dimension (n_dim, n_dim)
    """Computes class-specific means and a shared covariance matrix.

    For every unique value in `labels` the mean of the matching rows of
    `embdedding` is computed. The shared covariance is the sum of the
    per-class scatter matrices (each class centred on its own mean) divided
    by the total number of samples.

    Raises `ValueError` if `labels` is not one-dimensional with one entry per
    row of `embdedding`, or if there are no samples at all.
    """

    embdedding = np.asarray(embdedding)
    labels = np.asarray(labels)

    n_sample, n_dim = embdedding.shape[0], embdedding.shape[1]
    if labels.ndim != 1 or labels.shape[0] != n_sample:
        # Without this check a length-1 `labels` would silently broadcast
        # against every row instead of failing.
        raise ValueError(
            f"labels must have shape ({n_sample},) to match embdedding, got {labels.shape}"
        )
    if n_sample == 0:
        raise ValueError("need at least one sample to estimate means and covariance")

    scatter = np.zeros((n_dim, n_dim))
    means = []

    for class_id in np.unique(labels):
        member_mask = labels == class_id
        # Select the rows of this class instead of zeroing the others out:
        # multiplying by a 0/1 mask turns inf/nan entries of *other* classes
        # into nan (inf * 0), and makes every class pay for all n_sample rows.
        members = embdedding[member_mask]
        mean = members.sum(axis=0) / member_mask.sum()
        centered = members - mean
        scatter += centered.T @ centered
        means.append(mean)

    return np.stack(means), scatter / n_sample

# %% ../00_core.ipynb 5
def compute_mahalanobis_distance(
    embdedding: np.ndarray, # Embdedding of dimension (n_sample, n_dim)
    means: np.ndarray, # A matrix of size (num_classes, n_dim), where the ith row corresponds to the mean of the fitted Gaussian distribution for the i-th class.
    covariance: np.ndarray # The shared covariance matrix of the size (n_dim, n_dim)
) -> np.ndarray: # A matrix of size (n_sample, n_class) where the (i, j) element corresponds to the squared Mahalanobis distance between i-th sample to the j-th class Gaussian.
    r"""Computes the squared Mahalanobis distance between the input and the fitted Gaussians.

    The Mahalanobis distance (Mahalanobis, 1936) is defined as

    $$distance(x, \mu, \Sigma) = \sqrt{(x-\mu)^T \Sigma^{-1} (x-\mu)}$$

    where `x` is a vector, `\mu` is the mean vector for a Gaussian, and `\Sigma` is
    the covariance matrix. The distance is computed for all examples in `embdedding`,
    and across all classes in `means`.

    Note that this function returns the *squared* distance, i.e. the expression
    under the square root. The pseudo-inverse of `covariance` is used, so a
    singular covariance matrix does not raise.
    """

    embdedding = np.asarray(embdedding)
    covariance_inv = np.linalg.pinv(covariance)

    per_class = []
    for mean in means:
        # One vectorised pass over all samples per class keeps the temporary
        # at (n_sample, n_dim) instead of calling einsum once per pair.
        diff = embdedding - mean
        per_class.append(np.einsum("nd,nd->n", diff @ covariance_inv, diff))

    # Columns are classes, rows are samples; an empty batch yields (0, n_class).
    return np.stack(per_class, axis=1)

# %% ../00_core.ipynb 6
class OODMetric:
    """Holds the Gaussian fits needed to score a batch of embeddings for OOD-ness.

    Construction fits two models on the training data: a class-conditional one
    (one mean per label, shared covariance) stored in `means` / `covariance`,
    and a class-independent "background" one stored in `means_bg` /
    `covariance_bg`.
    """

    def __init__(self,
                 train_embdedding: np.ndarray, # An array of size (n_sample, n_dim) where n_sample is the sample size of training set, n_dim is the dimension of the embedding.
                 train_labels: np.ndarray # An array of size (n_train_sample, )
                ):
        # Class-conditional fit: one mean per distinct label.
        class_means, shared_covariance = compute_mean_and_covariance(train_embdedding, train_labels)
        self.means = class_means
        self.covariance = shared_covariance

        # Background fit: replacing every label with zero collapses the data
        # into a single class, giving one mean and the overall covariance.
        single_class_labels = np.zeros_like(train_labels)
        background_means, background_covariance = compute_mean_and_covariance(train_embdedding, single_class_labels)
        self.means_bg = background_means
        self.covariance_bg = background_covariance

# %% ../00_core.ipynb 7
@patch
def compute_rmd(
    self:OODMetric,
    embdedding: np.ndarray # An array of size (n_sample, n_dim), where n_sample is the sample size of the test set, and n_dim is the size of the embeddings.
) -> np.ndarray:  # An array of size (n_sample, ) where the ith element corresponds to the ood score of the ith data point.
    """Scores each embedding with the relative Mahalanobis distance.

    The score is the smallest distance to any class-conditional Gaussian minus
    the distance to the single background Gaussian, both as returned by
    `compute_mahalanobis_distance`.
    """

    # (n_sample, n_class) distances to the class-conditional fits.
    class_distances = compute_mahalanobis_distance(embdedding, self.means, self.covariance)
    # (n_sample, 1) distances to the background fit; it has a single "class".
    background_distances = compute_mahalanobis_distance(embdedding, self.means_bg, self.covariance_bg)

    nearest_class_distance = np.min(class_distances, axis=-1)
    return nearest_class_distance - background_distances[:, 0]
