import torch
import pytest

from piq import MSID
from piq.feature_extractors import InceptionV3


class TestDataset(torch.utils.data.Dataset):
    """Small in-memory dataset of random tensors used to drive ``MSID.compute_feats``.

    Holds 15 samples. Each item is a dict with an ``'images'`` tensor and a ``'mask'``
    tensor, both of shape (3, 256, 256) with values drawn by ``torch.rand``.
    """

    # The class name starts with "Test" and it defines __init__, so without this flag
    # pytest tries to collect it as a test class and emits a PytestCollectionWarning.
    __test__ = False

    def __init__(self):
        self.data = torch.rand(15, 3, 256, 256)
        self.mask = torch.rand(15, 3, 256, 256)

    def __getitem__(self, index):
        """Return the sample at ``index`` as ``{'images': ..., 'mask': ...}``."""
        x = self.data[index]
        y = self.mask[index]

        return {'images': x, 'mask': y}

    def __len__(self):
        """Number of samples (the size of the first dimension of ``data``)."""
        return len(self.data)


@pytest.fixture(scope='module')
def x() -> torch.Tensor:
    """Module-scoped tensor of shape (3, 3, 256, 256) with values from ``torch.rand``."""
    shape = (3, 3, 256, 256)
    return torch.rand(shape)


@pytest.fixture(scope='module')
def y() -> torch.Tensor:
    """Module-scoped tensor of shape (3, 3, 256, 256) with values from ``torch.rand``."""
    batch, channels, height, width = 3, 3, 256, 256
    return torch.rand(batch, channels, height, width)


@pytest.fixture(scope='module')
def features_y_normal() -> torch.Tensor:
    """Module-scoped (1000, 20) feature matrix with entries from ``torch.rand``.

    NOTE(review): despite the "normal" in the name, ``torch.rand`` samples uniformly
    from [0, 1), not from a normal distribution; confirm which was intended.
    """
    num_samples, num_features = 1000, 20
    return torch.rand((num_samples, num_features))


@pytest.fixture(scope='module')
def features_x_normal() -> torch.Tensor:
    """Module-scoped (1000, 20) feature matrix with entries from ``torch.rand``.

    NOTE(review): despite the "normal" in the name, ``torch.rand`` samples uniformly
    from [0, 1), not from a normal distribution; confirm which was intended.
    """
    size = (1000, 20)
    return torch.rand(size)


@pytest.fixture(scope='module')
def features_x_beta() -> torch.Tensor:
    """Module-scoped (1000, 20) feature matrix sampled from a Beta(2, 2) distribution."""
    concentration = torch.tensor([2.0])
    beta_dist = torch.distributions.Beta(concentration, concentration)
    # The distribution has batch shape (1,), so samples come out as (1000, 20, 1);
    # squeeze drops that trailing singleton axis.
    samples = beta_dist.sample(torch.Size([1000, 20]))
    return samples.squeeze()


@pytest.fixture(scope='module')
def features_x_constant() -> torch.Tensor:
    """Module-scoped (1000, 20) feature matrix in which every entry equals 1."""
    return torch.full((1000, 20), 1.0)


# ================== Test class: `MSID` ==================
def test_fails_for_different_dimensions(features_y_normal) -> None:
    """MSID raises AssertionError when the two feature sets have different widths (20 vs 21)."""
    # Local name deliberately differs from the `features_x_normal` fixture to avoid shadowing it.
    wider_features = torch.rand(1000, 21)
    msid = MSID()
    with pytest.raises(AssertionError):
        msid(features_y_normal, wider_features)


def test_compute_msid_works_for_different_number_of_images_in_stack(features_y_normal) -> None:
    """MSID accepts feature sets with different numbers of rows (1000 vs 1001)."""
    # Local name deliberately differs from the `features_x_normal` fixture to avoid shadowing it.
    one_extra_row = torch.rand(1001, 20)
    msid = MSID()
    try:
        msid(features_y_normal, one_extra_row)
    except Exception as e:
        pytest.fail(f"Unexpected error occurred: {e}")


def test_initialization() -> None:
    """Constructing MSID without arguments must not raise."""
    try:
        _ = MSID()
    except Exception as e:
        pytest.fail(f"Unexpected error occurred: {e}")


# NOTE(review): the skip reason is vague; re-check whether this test is stable before removing it.
@pytest.mark.skip(reason="Sometimes it doesn't work.")
def test_msid_is_smaller_for_equal_tensors(features_y_normal, features_x_normal, features_x_constant) -> None:
    """MSID between two same-distribution samples should not exceed MSID against a constant sample.

    ``features_y_normal`` is scored against ``features_x_normal`` (same generator) and against
    ``features_x_constant`` (all ones); the first score must be <= the second.
    """
    metric = MSID()
    measure = metric(features_y_normal, features_x_normal)
    # Reference score against the degenerate constant features.
    measure_constant = metric(features_y_normal, features_x_constant)
    assert measure <= measure_constant, \
        f'MSID should be smaller for samples from the same distribution, got {measure} and {measure_constant}'


def test_forward(features_y_normal, features_x_normal) -> None:
    """Calling an MSID instance on the two feature fixtures completes without raising."""
    try:
        MSID()(features_y_normal, features_x_normal)
    except Exception as e:
        pytest.fail(f"Unexpected error occurred: {e}")


def _compute_feats_on(device: str) -> None:
    """Run ``MSID.compute_feats`` with InceptionV3 over a ``TestDataset`` loader on ``device``.

    Any exception raised during setup or feature extraction is reported as a test
    failure through ``pytest.fail``.
    """
    try:
        dataset = TestDataset()
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=3,
            num_workers=2,
        )
        metric = MSID()
        model = InceptionV3()
        metric.compute_feats(loader, model, device=device)
    except Exception as e:
        pytest.fail(f"Unexpected error occurred: {e}")


def test_compute_feats_cpu() -> None:
    """Feature extraction for MSID runs on CPU without raising."""
    _compute_feats_on('cpu')


@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test on GPU if there is no GPU.')
def test_compute_feats_cuda() -> None:
    """Feature extraction for MSID runs on CUDA without raising (skipped without a GPU)."""
    _compute_feats_on('cuda')
