from typing import List, Optional

import torch
import torch.nn as nn
from model import ETSFormer
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler


class ETSformerModel(nn.Module):
    """
    Probabilistic forecasting network built around an ``ETSFormer`` backbone.

    The module assembles the network input from lagged (scaled) target values,
    embedded static categorical features, static real features, the log of the
    scale and the time features, runs it through ``ETSFormer`` and projects
    the result onto the parameters of ``distr_output``.

    Parameters
    ----------
    freq
        Frequency string; only used to derive the default lags when
        ``lags_seq`` is not given.
    context_length
        Number of time steps fed to the ETSFormer.
    prediction_length
        Number of time steps to forecast.
    num_feat_dynamic_real
        Size of the dynamic real (time) feature dimension, counted in the
        total feature size.
    num_feat_static_real
        Size of the static real feature dimension.
    num_feat_static_cat
        Number of static categorical features (stored, not otherwise used in
        this class).
    cardinality
        Cardinality of each static categorical feature, passed to the
        ``FeatureEmbedder``.
    k_largest_amplitudes
        Forwarded to ``ETSFormer`` as ``K``.
    embed_kernel_size
        Forwarded to ``ETSFormer`` as ``embed_kernel_size``.
    nhead
        Forwarded to ``ETSFormer`` as ``n_heads``.
    num_layers
        Forwarded to ``ETSFormer`` as ``layers``.
    dropout
        Forwarded to ``ETSFormer`` as ``dropout``.
    input_size
        Dimensionality of the target (1 for univariate input).
    embedding_dimension
        Embedding size per categorical feature. Defaults to
        ``min(50, (cat + 1) // 2)`` for every entry of ``cardinality``.
    distr_output
        Distribution head used to project the network output and to build the
        output distribution.
    lags_seq
        Lag indices to use; defaults to ``get_lags_for_frequency(freq)``.
    scaling
        If true a ``MeanScaler`` is applied to the context window, otherwise a
        ``NOPScaler``.
    num_parallel_samples
        Default number of sample paths drawn per series in ``forward``.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # ETSformer arguments
        k_largest_amplitudes: int,
        embed_kernel_size: int,
        nhead: int,
        num_layers: int,
        dropout: float = 0.1,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()

        self.input_size = input_size

        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        # Fall back to a size heuristic per categorical feature when no
        # explicit embedding sizes are supplied.
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.context_length = context_length
        self.prediction_length = prediction_length
        # Same quantity as ``_past_length``; kept as a public attribute.
        self.history_length = self._past_length
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size: one slot per lag and target dimension plus all
        # static/dynamic features
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # ETSformer enc-decoder
        self.etsformer = ETSFormer(
            time_features=d_model,
            d_model=d_model,
            embed_kernel_size=embed_kernel_size,
            K=k_largest_amplitudes,
            layers=num_layers,
            n_heads=nhead,
            dropout=dropout,
        )

    @property
    def _number_of_features(self) -> int:
        """Size of the feature vector concatenated to the lagged targets."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + self.input_size  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Number of past steps required: context window plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.
        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.
        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A non-positive lag must slice up to the end of the sequence;
            # ``-0`` would otherwise produce an empty slice.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """
        Assert that the given tensors are consistent with the model config.

        Checks that ``prior_input`` and ``inputs`` have the same rank, that
        their trailing dimension matches ``input_size`` (or that they are 2-D
        when ``input_size == 1``), and that ``features``, when given, has
        ``_number_of_features`` entries along dimension 2.

        Not called from within this class.
        """
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """
        Build the tensor fed to the ETSFormer from raw batch fields.

        The target is divided by a scale computed on the context window,
        expanded into its lagged copies and concatenated with static and time
        features. When ``future_target`` is given, the future part is
        appended to targets and time features (``future_time_feat`` must then
        be provided as well).

        Parameters
        ----------
        feat_static_cat
            Static categorical features, passed to the embedder.
        feat_static_real
            Static real features.
        past_time_feat
            Time features of the past window; only the last
            ``context_length`` steps are used.
            (assumed to have ``_past_length`` steps along dim 1 - TODO confirm)
        past_target
            Past target values; must have ``_past_length`` steps along dim 1.
        past_observed_values
            Observed-value indicator matching ``past_target``.
        future_time_feat
            Time features of the prediction window (optional).
        future_target
            Target values of the prediction window (optional).

        Returns
        -------
        tuple
            ``(transformer_inputs, scale, static_feat)``.
        """
        has_future = future_target is not None

        # time feature: keep only the context window of the past
        past_context_time_feat = past_time_feat[
            :, self._past_length - self.context_length:, ...
        ]
        time_feat = (
            torch.cat((past_context_time_feat, future_time_feat), dim=1)
            if has_future
            else past_context_time_feat
        )

        # target: the scale is computed on the context window only
        context = past_target[:, -self.context_length:]
        observed_context = past_observed_values[:, -self.context_length:]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if has_future
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if has_future
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if has_future
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, log_scale),
            dim=1,
        )
        # repeat the static features along the time axis
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        # flatten the trailing (target, lag) dimensions into one feature axis
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs: torch.Tensor):
        """
        Run the ETSFormer on the context part of ``transformer_inputs`` and
        project the forecast onto the distribution parameters.

        Only the first ``context_length`` steps are used as encoder input; the
        ETSFormer is asked for ``prediction_length`` forecast steps.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]

        dec_output = self.etsformer(
            enc_input, num_steps_forecast=self.prediction_length
        )

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """
        Build the output distribution from projected parameters.

        Parameters
        ----------
        params
            Iterable of parameter tensors as produced by ``param_proj``.
        scale
            Optional scale forwarded to ``distr_output.distribution``.
        trailing_n
            If given, only the last ``trailing_n`` steps (dim 1) of every
            parameter tensor are used.
        """
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Draw sample paths for the prediction window.

        Parameters
        ----------
        feat_static_cat, feat_static_real, past_time_feat, past_target,
        past_observed_values
            Forwarded to ``create_network_inputs``.
        future_time_feat
            Accepted for interface compatibility; not used by this method.
        num_parallel_samples
            Number of sample paths per series. Defaults to
            ``self.num_parallel_samples`` when ``None``.

        Returns
        -------
        Tensor
            Samples reshaped to
            ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, _ = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        dec_out = self.etsformer(
            encoder_inputs, num_steps_forecast=self.prediction_length
        )
        params = self.param_proj(dec_out)

        # Repeat every series ``num_parallel_samples`` times so one call to
        # ``sample`` yields all paths. The resolved local value is used so a
        # per-call override is honoured.
        repeated_params = [
            s.repeat_interleave(repeats=num_parallel_samples, dim=0)
            for s in params
        ]

        repeated_scale = scale.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        distr = self.output_distribution(repeated_params, scale=repeated_scale)

        # Future samples
        samples = distr.sample()

        return samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )
