#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Lindbjerg"
__doc__ = r"""

           Created on 21/02/2020
           """

from typing import Union

import numpy
import torch

__all__ = [
    "standardise",
    "minus_one_one_unnormalise",
    "minus_one_one_normalise",
    "normalize_row",
]

from warg import Number


def normalize_row(x):
    """Scale ``x`` so every slice along its last axis has unit norm.

    The norm is obtained from ``numpy.linalg.norm`` over ``axis=-1`` with
    ``keepdims=True`` so that it broadcasts against ``x`` in the division.
    There is no guard against a zero norm.

    :param x: array-like input whose last axis holds the row entries
    :return: ``x`` divided by its per-row norm"""
    row_norms = numpy.linalg.norm(x, axis=-1, keepdims=True)
    return x / row_norms


def minus_one_one_unnormalise(
    x: Union[Number, torch.Tensor, numpy.ndarray], low: float = 0, high: float = 1
) -> Union[Number, torch.Tensor, numpy.ndarray]:
    """Affinely map ``x`` from the interval ``[-1, 1]`` onto ``[low, high]``.

    ``-1`` maps to ``low``, ``0`` to the midpoint ``(high + low) / 2`` and ``1``
    to ``high``. Values outside ``[-1, 1]`` are extrapolated, not clipped.

    :param x: value(s) expressed on the ``[-1, 1]`` scale
    :param low: lower bound of the target interval
    :param high: upper bound of the target interval
    :return: ``x`` rescaled to ``[low, high]``; a scalar in gives a scalar out,
        a tensor/array in gives a tensor/array out"""
    half_range = (high - low) / 2.0  # scale: half the width of the target interval
    midpoint = (high + low) / 2.0  # offset: centre of the target interval
    return half_range * x + midpoint


def minus_one_one_normalise(
    x: Union[Number, torch.Tensor, numpy.ndarray], low: float = 0, high: float = 1
) -> Union[Number, torch.Tensor, numpy.ndarray]:
    """Affinely map ``x`` from the interval ``[low, high]`` onto ``[-1, 1]``.

    ``low`` maps to ``-1``, the midpoint ``(high + low) / 2`` to ``0`` and
    ``high`` to ``1``. Values outside ``[low, high]`` are extrapolated, not
    clipped. ``high`` must differ from ``low``; no check is made, so equal
    bounds divide by zero.

    :param x: value(s) expressed on the ``[low, high]`` scale
    :param low: lower bound of the source interval
    :param high: upper bound of the source interval
    :return: ``x`` rescaled to ``[-1, 1]``; a scalar in gives a scalar out,
        a tensor/array in gives a tensor/array out"""
    inverse_half_range = 2.0 / (high - low)  # reciprocal of half the interval width
    midpoint = (high + low) / 2.0  # centre of the source interval
    return inverse_half_range * (x - midpoint)


def standardise(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return a copy of ``x`` shifted to zero mean and scaled by its standard deviation.

    The mean and standard deviation are taken over all elements of ``x``
    (``x.mean()`` / ``.std()`` without a ``dim`` argument). The input tensor is
    left untouched: the result is a new tensor, so the function is also usable
    on leaf tensors that require grad, where in-place arithmetic is rejected.

    :param x: tensor to standardise
    :param eps: small constant added to the standard deviation so that a
        constant input does not divide by zero
    :return: new tensor ``(x - mean) / (std + eps)`` with the shape of ``x``"""
    centred = x - x.mean()  # out-of-place: do not mutate the caller's tensor
    return centred / (centred.std() + eps)


if __name__ == "__main__":
    # Smoke-run standardise on a handful of differently shaped inputs.
    demo_inputs = (
        torch.ones(10),
        torch.ones((10, 1)),
        torch.ones((1, 10)),
        torch.diag(torch.ones(3)),
        torch.ones((1, 10)) * torch.rand((1, 10)),
        torch.rand((1, 10)),
    )
    for demo_input in demo_inputs:
        print(standardise(demo_input))

    # Map between [low, high] and [-1, 1], finishing with a round trip.
    print(minus_one_one_normalise(7, 0, 10))
    print(minus_one_one_unnormalise(0.4, 0, 10))
    round_trip_low, round_trip_high = 3, 4
    unnormalised = minus_one_one_unnormalise(3.4, round_trip_low, round_trip_high)
    print(minus_one_one_normalise(unnormalised, round_trip_low, round_trip_high))
