from typing import Any, Dict
import torch.nn as nn
import numpy as np
from abc import abstractmethod
from numpy import inf
import torch
from pathlib import Path
from datetime import datetime
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
from collections import OrderedDict
import os
import shutil

from .logger import TensorboardWriter, get_logger
from . import model as model_module


class BaseModel(nn.Module):
    """
    Base class for all models

    Subclasses implement :meth:`forward`. Converting a model to ``str`` gives
    the regular module description followed by the number of trainable
    parameters.
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic
        :param inputs: positional inputs of the concrete model
        :return: Model output
        :raises NotImplementedError: always, until a subclass overrides it
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters
        :return: the parent class' string form plus a "Trainable parameters" line
        """
        # numel() is an exact Python int for every shape. np.prod() of an empty
        # shape (a 0-dim parameter) is the float 1.0, which used to turn the
        # whole count into a float such as "7.0".
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + "\nTrainable parameters: {}".format(params)


class BaseTrainer:
    """
    Base class for all trainers

    Owns the epoch loop, metric monitoring with early stopping, and checkpoint
    saving / resuming. Subclasses implement :meth:`_train_epoch`.
    """

    def __init__(
        self,
        model,
        criterion,
        metric_ftns,
        optimizer,
        config,
        uuid,
        resume=None,
    ):
        """
        :param model: model to train; its ``state_dict`` is what gets checkpointed
        :param criterion: stored as-is for use by subclasses
        :param metric_ftns: stored as-is for use by subclasses
        :param optimizer: optimizer whose state is saved with every checkpoint
        :param config: mapping with a ``"trainer"`` section providing ``epochs``,
            ``save_period``, ``verbosity``, ``tensorboard`` and optionally
            ``monitor`` (``"off"`` or ``"<min|max> <metric>"``) and ``early_stop``;
            ``config["optimizer"]["type"]`` is read when resuming
        :param uuid: sub-directory of ``heads/`` under which this run is stored
        :param resume: optional checkpoint path to resume from
        """
        self.config = config
        # NOTE(review): this assignment is overwritten a few lines below; the
        # call is kept in case get_logger has side effects -- confirm and drop.
        self.logger = get_logger("train")

        save_dir = Path("heads")
        self.run_id = datetime.now().strftime(r"%m%d_%H%M%S")

        # heads/<uuid>/runs/<run_id>/ holds the checkpoints, logs/ below it
        # holds the visualization writer output.
        self.checkpoint_dir = save_dir / uuid / "runs" / self.run_id
        self.log_dir = self.checkpoint_dir / "logs"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.logger = get_logger("trainer", config["trainer"]["verbosity"])

        self.model = model
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config["trainer"]
        self.epochs = cfg_trainer["epochs"]
        self.save_period = cfg_trainer["save_period"]
        self.monitor = cfg_trainer.get("monitor", "off")

        # configuration to monitor model performance and save best
        # early_stop is always defined so it can be read regardless of the mode.
        self.early_stop = inf
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ["min", "max"]

            self.mnt_best = inf if self.mnt_mode == "min" else -inf
            self.early_stop = cfg_trainer.get("early_stop", inf)
            if self.early_stop <= 0:
                # a non-positive patience means "never stop early"
                self.early_stop = inf

        self.start_epoch = 1

        # setup visualization writer instance
        self.writer = TensorboardWriter(
            self.log_dir, self.logger, cfg_trainer["tensorboard"]
        )

        if resume is not None:
            self._resume_checkpoint(resume)

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch
        :param epoch: Current epoch number
        :return: mapping of logged values for the epoch; it is merged into the
            epoch log and is where the monitored metric is looked up
        """
        raise NotImplementedError

    def train(self):
        """
        Full training logic

        Runs epochs ``start_epoch`` .. ``epochs`` (inclusive), logs the values
        returned by :meth:`_train_epoch`, tracks the monitored metric, stops
        early once it has not improved for more than ``early_stop`` epochs, and
        saves a checkpoint every ``save_period`` epochs.
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {"epoch": epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != "off":
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    # (a tie with the best value so far counts as an improvement)
                    improved = (
                        self.mnt_mode == "min" and log[self.mnt_metric] <= self.mnt_best
                    ) or (
                        self.mnt_mode == "max" and log[self.mnt_metric] >= self.mnt_best
                    )
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric
                        )
                    )
                    self.mnt_mode = "off"
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop)
                    )
                    break

            # NOTE(review): an improvement on an epoch that is not a multiple of
            # save_period is not written to model_best.pth -- confirm this is
            # the intended trade-off.
            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints
        :param epoch: current epoch number
        :param save_best: if True, also write the same state to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config,
        }
        filename = str(self.checkpoint_dir / "checkpoint-epoch{}.pth".format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / "model_best.pth")
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    @staticmethod
    def _match_state_dict_keys(model, state_dict):
        """
        Rename checkpoint keys so they line up with a ``model.``-wrapped model
        :param model: model the weights are going to be loaded into
        :param state_dict: mapping of parameter names to tensors
        :return: ``state_dict`` itself when nothing has to be renamed, otherwise
            a new ``OrderedDict`` with ``"model."`` prepended where needed
        """
        model_keys = list(model.state_dict().keys())
        # Only models whose first entry lives under "model." are treated as
        # wrapped (when using CustomModel); a parameterless model has no keys.
        if not model_keys or not model_keys[0].startswith("model."):
            return state_dict

        known_keys = set(model_keys)
        renamed = OrderedDict()
        for key, value in state_dict.items():
            # Keys that already name a model entry (checkpoint written from the
            # wrapped model itself) must not be prefixed a second time.
            name = key if key in known_keys else "model." + key
            renamed[name] = value
        return renamed

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        Restores the epoch counter, best monitored value, model weights and,
        when the optimizer type is unchanged, the optimizer state. ``epochs`` is
        extended by the number of epochs already recorded in the checkpoint.

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        # NOTE(review): torch.load is pickle-based; only resume from checkpoint
        # files that come from a trusted source.
        checkpoint = torch.load(resume_path, map_location=torch.device("cpu"))
        trained_epochs = checkpoint.get("epoch", 0)
        self.start_epoch = trained_epochs + 1
        self.epochs += trained_epochs
        # Keep the value chosen in __init__ when the checkpoint has no record;
        # falling back to 0 would block every later "min" improvement.
        self.mnt_best = checkpoint.get("monitor_best", self.mnt_best)

        # load architecture params from checkpoint.
        # A file without a "state_dict" entry is treated as a bare state dict.
        state_dict = checkpoint.get("state_dict", checkpoint)
        state_dict = self._match_state_dict_keys(self.model, state_dict)
        self.model.load_state_dict(state_dict)

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint.get("config", None) is not None:
            if (
                checkpoint["config"]["optimizer"]["type"]
                != self.config["optimizer"]["type"]
            ):
                self.logger.warning(
                    "Warning: Optimizer type given in config file is different from that of checkpoint. "
                    "Optimizer parameters not being resumed."
                )
            else:
                self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders

    Optionally holds out part of the dataset for validation: this loader then
    iterates the training part, and :meth:`split_validation` returns a loader
    over the held-out part.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        shuffle,
        validation_split,
        num_workers,
        collate_fn=default_collate,
    ):
        """
        :param dataset: dataset to load from; its length is the sample count
        :param batch_size: passed through to ``DataLoader``
        :param shuffle: passed through to ``DataLoader``; forced to ``False``
            when a validation split is made, because samplers are used instead
        :param validation_split: ``0.0`` for no split, an integer number of
            validation samples, or a fraction of the dataset
        :param num_workers: passed through to ``DataLoader``
        :param collate_fn: passed through to ``DataLoader``
        """
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        # Must run before init_kwargs is built: it may switch self.shuffle off.
        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            "dataset": dataset,
            "batch_size": batch_size,
            "shuffle": self.shuffle,
            "collate_fn": collate_fn,
            "num_workers": num_workers,
        }
        super().__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        """
        Build the train / validation samplers for the requested split
        :param split: ``0.0`` for no split, an integer sample count, or a
            fraction of the dataset
        :return: ``(train_sampler, valid_sampler)``, or ``(None, None)`` when no
            split is requested
        """
        if split == 0.0:
            return None, None

        idx_full = np.arange(self.n_samples)

        # A private, fixed-seed generator gives the same permutation as
        # np.random.seed(0) followed by np.random.shuffle, but leaves the
        # process-wide NumPy random state untouched.
        np.random.RandomState(0).shuffle(idx_full)

        # NumPy integers are counts as well, not fractions.
        if isinstance(split, (int, np.integer)):
            assert split > 0
            assert (
                split < self.n_samples
            ), "validation set size is configured to be larger than entire dataset."
            len_valid = split
        else:
            len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        self.shuffle = False
        # from here on n_samples counts the training part only
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        """
        Build the loader for the held-out validation samples
        :return: a ``DataLoader`` over the validation sampler with the same
            settings as this loader, or ``None`` when no split was made
        """
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)


class BaseInference:
    """
    Base class for inference wrappers

    Builds the model named in the config, optionally loads weights, puts the
    model in eval mode on the available device, and runs
    ``preprocess -> model -> postprocess`` in :meth:`predict`. Subclasses
    implement :meth:`preprocess` and :meth:`postprocess`.
    """

    def __init__(self, config: Dict, logger, uuid: str = None) -> None:
        """
        :param config: mapping with a ``"Model"`` section (``type`` and optional
            ``args``) and an ``"Inference"`` section (``weights`` and optional
            ``out_activation_layer``)
        :param logger: logger used for progress messages
        :param uuid: sub-directory of ``heads/`` holding ``current.pth``
        """
        self.config = config
        model_config = self.config["Model"]
        self.infer_config = self.config["Inference"]

        self.model = getattr(model_module, model_config["type"])(
            **model_config.get("args", {})
        )

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # loading pretrained weights
        if self.infer_config["weights"]:
            current_path = f"heads/{uuid}/current.pth"
            if not os.path.exists(current_path):
                # Seed current.pth from the configured default weights; the
                # target directory may not exist yet, which copyfile rejects.
                os.makedirs(os.path.dirname(current_path), exist_ok=True)
                shutil.copyfile(self.infer_config["weights"], current_path)
                logger.info(
                    "Loading default checkpoint: {} ...".format(
                        self.infer_config["weights"]
                    )
                )

            self._load_checkpoint(current_path)
            logger.info("Checkpoint loaded.")

        # Setting the device
        self.model = self.model.to(self.device)
        self.model.eval()

        # Get Intermediate Activations
        self.out_activation_layer = self.infer_config.get("out_activation_layer", None)
        if self.out_activation_layer:
            self.activation = {}
            getattr(self.model, self.out_activation_layer).register_forward_hook(
                self.get_activation(self.out_activation_layer)
            )

    @torch.no_grad()
    def predict(self, data: Any) -> Any:
        """
        Run one prediction
        :param data: raw input, handed to :meth:`preprocess`
        :return: result of :meth:`postprocess` applied to the model output, or
            to the captured activation when ``out_activation_layer`` is set
        """
        data = self.preprocess(data)
        data = data.to(self.device)
        output = self.model(data)
        if self.out_activation_layer:
            # the forward hook stored this layer's output during the call above
            output = self.activation[self.out_activation_layer]
        output = self.postprocess(output)
        return output

    def preprocess(self, data: Any) -> Any:
        """
        Convert raw input into what the model consumes
        :param data: raw input given to :meth:`predict`
        :return: object with a ``.to(device)`` method, as used by :meth:`predict`
        :raises NotImplementedError: until a subclass overrides it
        """
        raise NotImplementedError

    def postprocess(self, data: Any) -> Any:
        """
        Convert the model output into the value returned by :meth:`predict`
        :param data: model output or captured activation
        :raises NotImplementedError: until a subclass overrides it
        """
        raise NotImplementedError

    def load_weights(self, weights: str) -> None:
        """
        Replace the model weights with those stored in a checkpoint file
        :param weights: path of the checkpoint to load
        """
        self._load_checkpoint(weights)

    def _load_checkpoint(self, path: str) -> None:
        """
        Read a checkpoint from disk and load it into the model
        :param path: path of the checkpoint file
        """
        # NOTE(review): torch.load is pickle-based; only load checkpoint files
        # that come from a trusted source.
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        # A file without a "state_dict" entry is treated as a bare state dict.
        state_dict = checkpoint.get("state_dict", checkpoint)
        self.model.load_state_dict(
            self._match_state_dict_keys(self.model, state_dict)
        )

    @staticmethod
    def _match_state_dict_keys(model, state_dict):
        """
        Rename checkpoint keys so they line up with a ``model.``-wrapped model
        :param model: model the weights are going to be loaded into
        :param state_dict: mapping of parameter names to tensors
        :return: ``state_dict`` itself when nothing has to be renamed, otherwise
            a new ``OrderedDict`` with ``"model."`` prepended where needed
        """
        model_keys = list(model.state_dict().keys())
        # Only models whose first entry lives under "model." are treated as
        # wrapped; a parameterless model has no keys to inspect.
        if not model_keys or not model_keys[0].startswith("model."):
            return state_dict

        known_keys = set(model_keys)
        renamed = OrderedDict()
        for key, value in state_dict.items():
            # Keys that already name a model entry must not be prefixed twice.
            name = key if key in known_keys else "model." + key
            renamed[name] = value
        return renamed

    def get_activation(self, name):
        """
        Create a forward hook that records a layer's output
        :param name: key under which the detached output is stored in
            ``self.activation``
        :return: hook suitable for ``register_forward_hook``
        """

        def hook(model, input, output):
            self.activation[name] = output.detach()

        return hook
