import numpy as np
from sklearn.metrics import accuracy_score
from numba.experimental import jitclass


class Loss(object):
    """Base interface shared by the loss functions in this module.

    Subclasses are expected to override ``loss`` and ``gradient``. ``acc``
    has a default implementation that always reports 0.
    """

    def loss(self, y_true, y_pred):
        """Compute the loss for the given targets and predictions.

        Raises:
            NotImplementedError: always; the base class has no loss formula.
        """
        # Raise (rather than return) the exception so a missing override
        # fails loudly instead of handing an exception object to the caller.
        raise NotImplementedError()

    def gradient(self, y, y_pred):
        """Compute the gradient of the loss with respect to ``y_pred``.

        Raises:
            NotImplementedError: always; the base class has no gradient.
        """
        raise NotImplementedError()

    def acc(self, y, y_pred):
        """Return an accuracy score; the base implementation returns 0."""
        return 0


@jitclass()
class SquareLoss(Loss):
    """Half sum-of-squares loss: ``sum(0.5 * (y - y_pred) ** 2)``."""

    def __init__(self):
        pass

    def loss(self, y, y_pred):
        """Return the summed value of ``0.5 * (y - y_pred) ** 2``."""
        return np.sum(0.5 * np.power((y - y_pred), 2))

    def gradient(self, y, y_pred):
        """Return ``-(y - y_pred)``, element-wise."""
        return -(y - y_pred)

    def hess(self, y, y_pred):
        """Return a 1-D integer array of ones with ``len(y)`` entries.

        ``y_pred`` is accepted for signature symmetry but is not used.
        """
        # Allocate the constant vector directly instead of building a
        # throwaway Python list one element at a time and converting it.
        return np.ones(len(y), dtype=np.int64)


@jitclass()
class CrossEntropy(Loss):
    """Element-wise binary cross-entropy loss with clipped probabilities."""

    def __init__(self):
        pass

    def loss(self, y, p):
        """Return ``-y * log(p) - (1 - y) * log(1 - p)`` element-wise.

        The result is not reduced; it has the broadcast shape of the inputs.
        """
        # Keep p strictly inside (0, 1) so neither logarithm sees 0.
        p = np.clip(p, 1e-15, 1 - 1e-15)
        return - y * np.log(p) - (1 - y) * np.log(1 - p)

    def acc(self, y, p):
        """Return the fraction of rows whose argmax agrees between y and p.

        Both arguments are reduced with ``argmax`` along axis 1, so they must
        be at least 2-D (presumably one-hot targets and per-class scores;
        confirm against callers).
        """
        true_labels = np.argmax(y, axis=1)
        pred_labels = np.argmax(p, axis=1)
        # Computed with NumPy only: this method belongs to a jitclass, and a
        # call into scikit-learn's accuracy_score cannot be compiled by numba.
        # The mean of the 0/1 match indicator is the same normalized accuracy.
        matches = (true_labels == pred_labels).astype(np.float64)
        return np.mean(matches)

    def gradient(self, y, p):
        """Return ``-(y / p) + (1 - y) / (1 - p)`` element-wise."""
        # Same clipping as in ``loss``; here it prevents division by zero.
        p = np.clip(p, 1e-15, 1 - 1e-15)
        return - (y / p) + (1 - y) / (1 - p)


@jitclass()
class MSE(Loss):
    """Mean squared error between targets and predictions."""

    def __init__(self):
        pass

    def loss(self, y, y_pred):
        """Return the mean of ``(y - y_pred) ** 2`` over all elements."""
        residual = y - y_pred
        return np.mean(residual ** 2)

    def gradient(self, y, y_pred):
        """Return ``-(2 / len(y)) * (y - y_pred)`` element-wise.

        The scale uses ``len(y)``, i.e. the length of the first axis, whereas
        ``loss`` averages over every element.
        """
        scale = 2 / len(y)
        residual = y - y_pred
        return -(scale * residual)