"""This module provides the dataloader."""
import os
from typing import Optional, Tuple

import pandas as pd
import numpy as np
from numpy import typing as npt

from simba_ml.prediction.time_series.config.mixed_data_pipeline import (
    mixed_data_config,
)
from simba_ml.prediction import preprocessing
from simba_ml.prediction.time_series.data_loader import window_generator, splits


class MixedDataLoader:
    """Loads and preprocesses synthetic and observed time-series data.

    The data is read and preprocessed lazily: the first access to any of the
    properties triggers `prepare_data`, and the results are cached on the
    instance afterwards.

    Attributes:
        config: the data configuration.
        X_test: the input of the test data.
        y_test: the labels for the test data.
        train_sets: dict of train sets, one for each configured ratio of
            synthetic to observed data.
    """

    config: mixed_data_config.DataConfig
    __X_test: Optional[npt.NDArray[np.float64]] = None
    __y_test: Optional[npt.NDArray[np.float64]] = None
    __train_sets: Optional[dict[float, list[npt.NDArray[np.float64]]]] = None

    def __init__(self, config: mixed_data_config.DataConfig) -> None:
        """Inits the DataLoader.

        Args:
            config: the data configuration.
        """
        self.config = config

    def load_data(self) -> Tuple[list[pd.DataFrame], list[pd.DataFrame]]:
        """Loads the synthetic and the observed data.

        Each configured path is appended to the current working directory
        by plain string concatenation before the CSV files are read.

        Returns:
            A tuple `(synthetic, observed)` of lists of dataframes. A list is
            empty if the corresponding path in the config is None.
        """
        # NOTE(review): plain concatenation, so the configured paths presumably
        # have to start with a path separator -- confirm against the config.
        synthetic = (
            []
            if self.config.synthetic is None
            else preprocessing.read_dataframes_from_csvs(
                os.getcwd() + self.config.synthetic
            )
        )
        observed = (
            []
            if self.config.observed is None
            else preprocessing.read_dataframes_from_csvs(
                os.getcwd() + self.config.observed
            )
        )
        return synthetic, observed

    def prepare_data(self) -> None:
        """Preprocesses the data and caches the train and test sets.

        Splits synthetic and observed data into train and test parts, mixes
        the train parts once per configured ratio and builds the windowed
        test set from the observed test part only. Does nothing if the data
        has already been prepared.
        """
        if self.__X_test is not None:  # pragma: no cover
            return  # pragma: no cover

        synthetic_data, observed_data = self.load_data()

        # The test part of the synthetic data is discarded on purpose:
        # only observed data is used for testing.
        synthetic_train, _ = splits.train_test_split(
            data=synthetic_data,
            test_split=self.config.test_split,
            input_length=self.config.time_series.input_length,
            split_axis=self.config.split_axis,
        )
        observed_train, observed_test = splits.train_test_split(
            data=observed_data,
            test_split=self.config.test_split,
            input_length=self.config.time_series.input_length,
            split_axis=self.config.split_axis,
        )

        train_sets = {
            ratio: preprocessing.convert_dataframe_to_numpy(
                preprocessing.mix_data(
                    synthetic_data=synthetic_train,
                    observed_data=observed_train,
                    ratio=ratio,
                )
            )
            for ratio in self.config.ratios
        }

        test = preprocessing.convert_dataframe_to_numpy(observed_test)
        X_test, y_test = window_generator.create_window_dataset(
            test,
            self.config.time_series.input_length,
            self.config.time_series.output_length,
        )

        # Publish all cached results together, only after every step has
        # succeeded, so a failure midway never leaves a partially filled cache.
        self.__train_sets = train_sets
        self.__X_test, self.__y_test = X_test, y_test

    # sourcery skip: snake-case-functions
    @property
    def X_test(self) -> npt.NDArray[np.float64]:
        """The input of the test dataset.

        Prepares the data on first access.

        Returns:
            The input of the test dataset.
        """
        if self.__X_test is None:
            self.prepare_data()
            return self.X_test
        return self.__X_test

    @property
    def y_test(self) -> npt.NDArray[np.float64]:
        """The output of the test dataset.

        Prepares the data on first access.

        Returns:
            The output of the test dataset.
        """
        if self.__y_test is None:
            self.prepare_data()
            return self.y_test
        return self.__y_test

    @property
    def train_sets(
        self,
    ) -> dict[float, list[npt.NDArray[np.float64]]]:
        """Dict of train sets.

        One set for each ratio. Prepares the data on first access.

        Returns:
            A dict containing the train sets, keyed by ratio.
        """
        # Compare against None rather than relying on truthiness: an empty
        # dict (no ratios configured) is a valid prepared state, and treating
        # it as "not prepared" would recurse forever.
        if self.__train_sets is None:
            self.prepare_data()
            return self.train_sets
        return self.__train_sets
