# Copyright 2020 The FastEstimator Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import unittest

import numpy as np
import tensorflow as tf
import torch

import fastestimator as fe


class TestCategoricalCrossEntropy(unittest.TestCase):
    """Check fe.backend.categorical_crossentropy against precomputed values for TF and PyTorch inputs.

    The TF and PyTorch fixtures built in setUpClass hold identical numbers, so both backends are compared against
    the same class-level expected constants.
    """
    # Per-sample loss of the probability inputs, i.e. -log of the true-class probability: -log(0.8), -log(0.9),
    # -log(0.7). The mean over the three samples is _EXPECTED_MEAN.
    _EXPECTED_MEAN = 0.22839302
    _EXPECTED_PER_SAMPLE = np.array([0.22314353, 0.10536055, 0.35667497])
    # Per-sample loss multiplied by the true class's weight: class 1 -> 2.0, class 2 -> 3.0, and class 0 (not in
    # the weight mapping) -> 1.0.
    _EXPECTED_PER_SAMPLE_WEIGHTED = np.array([0.44628706, 0.10536055, 1.07002491])
    # With from_logits=True the same y_pred values are fed as logits; these numbers equal the cross-entropy of
    # softmax(y_pred) against y_true.
    _EXPECTED_LOGITS_MEAN = 0.69182307
    _EXPECTED_LOGITS_PER_SAMPLE = np.array([0.6897267, 0.6177929, 0.7679496])
    _EXPECTED_LOGITS_PER_SAMPLE_WEIGHTED = np.array([1.3794534, 0.6177929, 2.3038488])

    @classmethod
    def setUpClass(cls):
        # One-hot labels for three samples with true classes 1, 0 and 2.
        cls.tf_true = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
        cls.tf_pred = tf.constant([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
        # Class weights as a TF lookup table; keys not present (class 0) map to default_value=1.0.
        cls.tf_weights = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(tf.constant([1, 2]), tf.constant([2.0, 3.0])), default_value=1.0)
        cls.torch_true = torch.tensor([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
        cls.torch_pred = torch.tensor([[0.1, 0.8, 0.1], [0.9, 0.05, 0.05], [0.1, 0.2, 0.7]])
        # Same weights as a plain dict for the PyTorch code path; class 0 is intentionally absent.
        cls.torch_weights = {1: 2.0, 2: 3.0}

    @staticmethod
    def _assert_close(actual, expected):
        """Assert `actual` matches `expected` elementwise, reporting the mismatching values on failure.

        Uses the same tolerance as np.allclose's defaults (rtol=1e-5, atol=1e-8, NaN never equal) so float32 losses
        are accepted. np.testing.assert_allclose's own defaults (rtol=1e-7, atol=0) would be too strict here.
        """
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8, equal_nan=False)

    def test_categorical_crossentropy_average_loss_true_tf(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.tf_pred, y_true=self.tf_true).numpy()
        self._assert_close(loss, self._EXPECTED_MEAN)

    def test_categorical_crossentropy_average_loss_false_tf(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.tf_pred, y_true=self.tf_true, average_loss=False).numpy()
        self._assert_close(loss, self._EXPECTED_PER_SAMPLE)

    def test_categorical_crossentropy_average_loss_false_weights_tf(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.tf_pred,
                                                   y_true=self.tf_true,
                                                   average_loss=False,
                                                   class_weights=self.tf_weights).numpy()
        self._assert_close(loss, self._EXPECTED_PER_SAMPLE_WEIGHTED)

    def test_categorical_crossentropy_from_logits_average_loss_true_tf(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.tf_pred,
                                                   y_true=self.tf_true,
                                                   average_loss=True,
                                                   from_logits=True).numpy()
        self._assert_close(loss, self._EXPECTED_LOGITS_MEAN)

    def test_categorical_crossentropy_from_logits_average_loss_false_tf(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.tf_pred,
                                                   y_true=self.tf_true,
                                                   average_loss=False,
                                                   from_logits=True).numpy()
        self._assert_close(loss, self._EXPECTED_LOGITS_PER_SAMPLE)

    def test_categorical_crossentropy_from_logits_average_loss_false_weights_tf(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.tf_pred,
                                                   y_true=self.tf_true,
                                                   average_loss=False,
                                                   from_logits=True,
                                                   class_weights=self.tf_weights).numpy()
        self._assert_close(loss, self._EXPECTED_LOGITS_PER_SAMPLE_WEIGHTED)

    def test_categorical_crossentropy_average_loss_true_torch(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.torch_pred, y_true=self.torch_true).numpy()
        self._assert_close(loss, self._EXPECTED_MEAN)

    def test_categorical_crossentropy_average_loss_false_torch(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.torch_pred, y_true=self.torch_true,
                                                   average_loss=False).numpy()
        self._assert_close(loss, self._EXPECTED_PER_SAMPLE)

    def test_categorical_crossentropy_average_loss_false_weights_torch(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.torch_pred,
                                                   y_true=self.torch_true,
                                                   average_loss=False,
                                                   class_weights=self.torch_weights).numpy()
        self._assert_close(loss, self._EXPECTED_PER_SAMPLE_WEIGHTED)

    def test_categorical_crossentropy_from_logits_average_loss_true_torch(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.torch_pred,
                                                   y_true=self.torch_true,
                                                   average_loss=True,
                                                   from_logits=True).numpy()
        self._assert_close(loss, self._EXPECTED_LOGITS_MEAN)

    def test_categorical_crossentropy_from_logits_average_loss_false_torch(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.torch_pred,
                                                   y_true=self.torch_true,
                                                   average_loss=False,
                                                   from_logits=True).numpy()
        self._assert_close(loss, self._EXPECTED_LOGITS_PER_SAMPLE)

    def test_categorical_crossentropy_from_logits_average_loss_false_weights_torch(self):
        loss = fe.backend.categorical_crossentropy(y_pred=self.torch_pred,
                                                   y_true=self.torch_true,
                                                   average_loss=False,
                                                   from_logits=True,
                                                   class_weights=self.torch_weights).numpy()
        self._assert_close(loss, self._EXPECTED_LOGITS_PER_SAMPLE_WEIGHTED)
