# %%

from jjy.framework.cuda import xp as np


class OptimizerType:

    def __init__(self, **kwargs):
        if "trainable_param_keys" in kwargs:
            self.trainable_param_keys = kwargs["trainable_param_keys"]
        else:
            self.trainable_param_keys = None


    def get_trainable_param_keys(self, param_keys):
        if self.trainable_param_keys is None:
            return param_keys
        else:
            return self.trainable_param_keys


class SGD(OptimizerType):
    """확률적 경사 하강법（Stochastic Gradient Descent）"""

    def __init__(self, lr=0.01, **kwargs):
        super().__init__(**kwargs)

        self.lr = lr

    def update(self, params, grads):
        for key in self.get_trainable_param_keys(param_keys=params.keys()):
            params[key] -= self.lr * grads[key]


class Momentum(OptimizerType):
    """모멘텀 SGD"""

    def __init__(self, lr=0.01, momentum=0.9, **kwargs):
        super().__init__(**kwargs)
        self.lr = lr
        self.momentum = momentum
        self.v = None

    def update(self, params, grads):
        if self.v is None:
            self.v = {}
            for key, val in params.items():
                self.v[key] = np.zeros_like(val)

        for key in self.get_trainable_param_keys(param_keys=params.keys()):
            self.v[key] = self.momentum * self.v[key] - self.lr * grads[key]
            params[key] += self.v[key]


class Nesterov(OptimizerType):
    """Nesterov's Accelerated Gradient (http://arxiv.org/abs/1212.0901)"""

    # NAG는 모멘텀에서 한 단계 발전한 방법이다. (http://newsight.tistory.com/224)

    def __init__(self, lr=0.01, momentum=0.9, **kwargs):
        super().__init__(**kwargs)
        self.lr = lr
        self.momentum = momentum
        self.v = None

    def update(self, params, grads):
        if self.v is None:
            self.v = {}
            for key, val in params.items():
                self.v[key] = np.zeros_like(val)

        for key in self.get_trainable_param_keys(param_keys=params.keys()):
            self.v[key] *= self.momentum
            self.v[key] -= self.lr * grads[key]
            params[key] += self.momentum * self.momentum * self.v[key]
            params[key] -= (1 + self.momentum) * self.lr * grads[key]


class AdaGrad(OptimizerType):
    """AdaGrad"""

    def __init__(self, lr=0.01, **kwargs):
        super().__init__(**kwargs)
        self.lr = lr
        self.h = None

    def update(self, params, grads):
        if self.h is None:
            self.h = {}
            for key, val in params.items():
                self.h[key] = np.zeros_like(val)

        for key in self.get_trainable_param_keys(param_keys=params.keys()):
            self.h[key] += grads[key] * grads[key]
            params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)


class RMSprop(OptimizerType):
    """RMSprop"""

    def __init__(self, lr=0.01, decay_rate=0.99, **kwargs):
        super().__init__(**kwargs)
        self.lr = lr
        self.decay_rate = decay_rate
        self.h = None

    def update(self, params, grads):
        if self.h is None:
            self.h = {}
            for key, val in params.items():
                self.h[key] = np.zeros_like(val)

        for key in self.get_trainable_param_keys(param_keys=params.keys()):
            self.h[key] *= self.decay_rate
            self.h[key] += (1 - self.decay_rate) * grads[key] * grads[key]
            params[key] -= self.lr * grads[key] / (np.sqrt(self.h[key]) + 1e-7)


class Adam(OptimizerType):
    """Adam (http://arxiv.org/abs/1412.6980v8)"""

    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, **kwargs):
        super().__init__(**kwargs)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.iter = 0
        self.m = None
        self.v = None

    def update(self, params, grads):
        if self.m is None:
            self.m, self.v = {}, {}
            for key, val in params.items():
                self.m[key] = np.zeros_like(val)
                self.v[key] = np.zeros_like(val)

        self.iter += 1
        lr_t = self.lr * np.sqrt(1.0 - self.beta2 ** self.iter) / (1.0 - self.beta1 ** self.iter)

        for key in self.get_trainable_param_keys(param_keys=params.keys()):
            # self.m[key] = self.beta1*self.m[key] + (1-self.beta1)*grads[key]
            # self.v[key] = self.beta2*self.v[key] + (1-self.beta2)*(grads[key]**2)
            self.m[key] += (1 - self.beta1) * (grads[key] - self.m[key])
            self.v[key] += (1 - self.beta2) * (grads[key] ** 2 - self.v[key])

            params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)

            # unbias_m += (1 - self.beta1) * (grads[key] - self.m[key]) # correct bias
            # unbisa_b += (1 - self.beta2) * (grads[key]*grads[key] - self.v[key]) # correct bias
            # params[key] += self.lr * unbias_m / (np.sqrt(unbisa_b) + 1e-7)
