import unittest
import xarray as xr
import numpy as np

import deepsensor.tensorflow as deepsensor
from deepsensor.active_learning.acquisition_fns import MeanVariance
from deepsensor.active_learning.algorithms import GreedyAlgorithm

from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import DataProcessor
from deepsensor.data.task import append_obs_to_task
from deepsensor.errors import TaskSetIndexError, GriddedDataError
from deepsensor.model.convnp import ConvNP


# from deepsensor.active_learning.acquisition_fns import


class TestConcatTasks(unittest.TestCase):
    """Tests for ``append_obs_to_task`` (appending new context observations to a task)."""

    @classmethod
    def setUpClass(cls):
        """Build the dataset, processor, task loader and model once for all tests.

        unittest instantiates one TestCase per test method, so building these in
        ``__init__`` would repeat the (expensive) dataset load and model
        construction for every test. It's safe to share data between tests
        because the TaskLoader does not modify data.
        """
        super().setUpClass()
        ds_raw = xr.tutorial.open_dataset("air_temperature")
        cls.ds_raw = ds_raw
        cls.data_processor = DataProcessor(
            x1_name="lat",
            x2_name="lon",
            x1_map=(ds_raw["lat"].min(), ds_raw["lat"].max()),
            x2_map=(ds_raw["lon"].min(), ds_raw["lon"].max()),
        )
        ds = cls.data_processor(ds_raw)
        cls.task_loader = TaskLoader(ds, ds)
        cls.model = ConvNP(
            cls.data_processor,
            cls.task_loader,
            unet_channels=(5, 5, 5),
            verbose=False,
        )
        cls.task = cls.task_loader("2014-12-31")

    def test_concat_obs_to_task_shapes(self):
        """Appending N observations to a 10-point context set gives 10 + N points."""
        ctx_idx = 0  # Context set index to add new observations to

        # Sample 10 context observations
        task = self.task_loader("2014-12-31", context_sampling=10)

        # Each case appends to the same 10-point `task`, so the expected totals
        # below assume append_obs_to_task returns a new task rather than
        # mutating `task` in place.
        cases = [
            # (description, X_new shape, Y_new shape, expected number of obs)
            ("1 context observation", (2, 1), (1, 1), 11),
            ("1 context observation with flattened obs dim", (2,), (1,), 11),
            ("5 context observations", (2, 5), (1, 5), 15),
        ]
        for desc, x_shape, y_shape, n_expected in cases:
            with self.subTest(desc):
                X_new = np.random.randn(*x_shape)
                Y_new = np.random.randn(*y_shape)
                new_task = append_obs_to_task(task, X_new, Y_new, ctx_idx)
                self.assertEqual(new_task["X_c"][ctx_idx].shape, (2, n_expected))
                self.assertEqual(new_task["Y_c"][ctx_idx].shape, (1, n_expected))

    def test_concat_obs_to_task_wrong_context_index(self):
        """An out-of-range context set index raises TaskSetIndexError."""
        # Sample 10 context observations
        task = self.task_loader("2014-12-31", context_sampling=10)

        ctx_idx = 1  # Wrong context set index

        # 1 context observation
        X_new = np.random.randn(2, 1)
        Y_new = np.random.randn(1, 1)

        with self.assertRaises(TaskSetIndexError):
            _ = append_obs_to_task(task, X_new, Y_new, ctx_idx)

    def test_concat_obs_to_task_fails_for_gridded_data(self):
        """Appending to a gridded context set raises GriddedDataError."""
        ctx_idx = 0  # Context set index to add new observations to

        # Sample context observations on a grid
        task = self.task_loader("2014-12-31", context_sampling="all")

        # Confirm that context observations are gridded with tuple for coordinates.
        # Use the unittest assertion rather than a bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertIsInstance(task["X_c"][ctx_idx], tuple)

        # 1 context observation
        X_new = np.random.randn(2, 1)
        Y_new = np.random.randn(1, 1)

        with self.assertRaises(GriddedDataError):
            _ = append_obs_to_task(task, X_new, Y_new, ctx_idx)


class TestActiveLearning(unittest.TestCase):
    """Tests for the active learning ``GreedyAlgorithm`` constructor validation."""

    @classmethod
    def setUpClass(cls):
        """Build the dataset, processor, task loader and model once for all tests.

        unittest instantiates one TestCase per test method, so building these in
        ``__init__`` would repeat the (expensive) dataset load and model
        construction for every test. It's safe to share data between tests
        because the TaskLoader does not modify data.
        """
        super().setUpClass()
        ds_raw = xr.tutorial.open_dataset("air_temperature")
        cls.ds_raw = ds_raw
        cls.data_processor = DataProcessor(
            x1_name="lat",
            x2_name="lon",
            x1_map=(ds_raw["lat"].min(), ds_raw["lat"].max()),
            x2_map=(ds_raw["lon"].min(), ds_raw["lon"].max()),
        )
        ds = cls.data_processor(ds_raw)
        cls.task_loader = TaskLoader(ds, ds)
        cls.model = ConvNP(
            cls.data_processor,
            cls.task_loader,
            unet_channels=(5, 5, 5),
            verbose=False,
        )
        cls.task = cls.task_loader("2014-12-31")

    def test_wrong_n_new_sensors(self):
        """Invalid ``N_new_context`` values raise ValueError at construction."""
        invalid_values = [
            ("negative", -1),
            ("> number of search points", 10_000),
        ]
        for desc, n_new_context in invalid_values:
            with self.subTest(desc, N_new_context=n_new_context):
                with self.assertRaises(ValueError):
                    GreedyAlgorithm(
                        model=self.model,
                        X_t=self.ds_raw,
                        X_s=self.ds_raw,
                        N_new_context=n_new_context,
                    )
