import unittest

import numpy as np
from sklearn.datasets import load_iris
from sklearn.utils import (
    check_random_state,
    extmath,
)

from skmatter.metrics import (
    global_reconstruction_distortion,
    global_reconstruction_error,
    local_reconstruction_error,
    pointwise_local_reconstruction_error,
)


class ReconstructionMeasuresTests(unittest.TestCase):
    """Sanity checks for the reconstruction measures in ``skmatter.metrics``.

    Fixtures are built once in ``setUpClass`` from the first 20 iris samples:

    * ``features_small`` -- iris columns 0 and 1, shape ``(20, 2)``.
    * ``features_large`` -- iris columns 0, 1, 0, 1, shape ``(20, 4)``: the
      small set with each column duplicated, so both sets span the same
      space and either should reconstruct the other with ~zero error.
    * ``features_rotated_small`` -- ``features_small`` multiplied by a random
      orthonormal matrix, i.e. a rotation/reflection of the small set.
    """

    @classmethod
    def setUpClass(cls):
        features = load_iris().data
        cls.features_small = features[:20, [0, 1]]
        cls.features_large = features[:20, [0, 1, 0, 1]]
        # Tolerance for measures that are expected to vanish.
        cls.eps = 1e-5
        # Number of neighbours passed to the local reconstruction measures.
        cls.n_local_points = 15

        # Fixed seed so the random rotation is reproducible across runs.
        random_state = check_random_state(0)
        n_features = cls.features_small.shape[1]
        # An orthonormal basis for the range of the (square) identity matrix
        # is itself a square orthonormal matrix: a random rotation/reflection.
        random_orthonormal_mat = extmath.randomized_range_finder(
            np.eye(n_features),
            size=n_features,
            n_iter=10,
            random_state=random_state,
        )
        cls.features_rotated_small = cls.features_small @ random_orthonormal_mat

    def _assert_near_zero(self, value, measure_name):
        """Assert ``abs(value) < self.eps``, naming ``measure_name`` on failure."""
        self.assertLess(
            abs(value),
            self.eps,
            f"{measure_name} {value} surpasses threshold for zero {self.eps}",
        )

    # -- global reconstruction error (GRE) ---------------------------------

    def test_global_reconstruction_error_identity(self):
        # GRE of a set of features onto itself is within a threshold of zero.
        gfre_val = global_reconstruction_error(self.features_large, self.features_large)
        self._assert_near_zero(gfre_val, "global_reconstruction_error")

    def test_global_reconstruction_error_small_to_large(self):
        # GRE of a small set of features onto a larger set spanning the same
        # space is within a threshold of zero.
        gfre_val = global_reconstruction_error(self.features_small, self.features_large)
        self._assert_near_zero(gfre_val, "global_reconstruction_error")

    def test_global_reconstruction_error_large_to_small(self):
        # GRE of a large set of features onto a smaller set spanning the same
        # space is within a threshold of zero.
        gfre_val = global_reconstruction_error(self.features_large, self.features_small)
        self._assert_near_zero(gfre_val, "global_reconstruction_error")

    # -- global reconstruction distortion (GRD) ----------------------------

    def test_global_reconstruction_distortion_identity(self):
        # GRD of a set of features onto itself is within a threshold of zero.
        gfrd_val = global_reconstruction_distortion(
            self.features_large, self.features_large
        )
        self._assert_near_zero(gfrd_val, "global_reconstruction_distortion")

    def test_global_reconstruction_distortion_small_to_large(self):
        # Smoke test: GRD of a small set of features onto a larger set must run
        # and return a finite value; no zero threshold is asserted here.
        gfrd_val = global_reconstruction_distortion(
            self.features_small, self.features_large
        )
        self.assertTrue(
            np.isfinite(gfrd_val),
            f"global_reconstruction_distortion returned non-finite value {gfrd_val}",
        )

    def test_global_reconstruction_distortion_large_to_small(self):
        # Smoke test: GRD of a large set of features onto a smaller set must run
        # and return a finite value; no zero threshold is asserted here.
        gfrd_val = global_reconstruction_distortion(
            self.features_large, self.features_small
        )
        self.assertTrue(
            np.isfinite(gfrd_val),
            f"global_reconstruction_distortion returned non-finite value {gfrd_val}",
        )

    def test_global_reconstruction_distortion_small_to_rotated_small(self):
        # GRD of a small set of features onto a rotation of itself is within a
        # threshold of zero.
        gfrd_val = global_reconstruction_distortion(
            self.features_small, self.features_rotated_small
        )
        self._assert_near_zero(gfrd_val, "global_reconstruction_distortion")

    # -- local reconstruction error (LFRE) ---------------------------------

    def test_local_reconstruction_error_identity(self):
        # LFRE of a set of features onto itself is within a threshold of zero.
        lfre_val = local_reconstruction_error(
            self.features_large, self.features_large, self.n_local_points
        )
        self._assert_near_zero(lfre_val, "local_reconstruction_error")

    def test_local_reconstruction_error_small_to_large(self):
        # LFRE of a small set of features onto a larger set spanning the same
        # space is within a threshold of zero.
        lfre_val = local_reconstruction_error(
            self.features_small, self.features_large, self.n_local_points
        )
        self._assert_near_zero(lfre_val, "local_reconstruction_error")

    def test_local_reconstruction_error_large_to_small(self):
        # LFRE of a large set of features onto a smaller set spanning the same
        # space is within a threshold of zero.
        lfre_val = local_reconstruction_error(
            self.features_large, self.features_small, self.n_local_points
        )
        self._assert_near_zero(lfre_val, "local_reconstruction_error")

    def test_local_reconstruction_error_train_idx(self):
        # Pointwise LFRE with a manual train_idx (first quarter of the samples)
        # is checked to return one value per remaining (test) sample.
        n_samples = len(self.features_large)
        train_idx = np.arange(n_samples // 4)
        lfre_val = pointwise_local_reconstruction_error(
            self.features_large,
            self.features_large,
            self.n_local_points,
            train_idx=train_idx,
        )
        test_size = n_samples - len(train_idx)
        self.assertEqual(
            len(lfre_val),
            test_size,
            f"size of pointwise LFRE  {len(lfre_val)} differs from expected test set size {test_size}",
        )

    def test_local_reconstruction_error_test_idx(self):
        # Pointwise LFRE with a manual test_idx (first quarter of the samples)
        # is checked to return exactly one value per requested test sample.
        n_samples = len(self.features_large)
        test_idx = np.arange(n_samples // 4)
        lfre_val = pointwise_local_reconstruction_error(
            self.features_large,
            self.features_large,
            self.n_local_points,
            test_idx=test_idx,
        )
        test_size = len(test_idx)
        self.assertEqual(
            len(lfre_val),
            test_size,
            f"size of pointwise LFRE  {len(lfre_val)} differs from expected test set size {test_size}",
        )


if __name__ == "__main__":
    # Executing this module directly hands control to unittest's standard
    # command-line runner, which discovers and runs the TestCase above.
    unittest.main()
