from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Any

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint as _ModelCheckpoint

from lightly_train._checkpoint import (
    CHECKPOINT_LIGHTLY_TRAIN_KEY,
    CheckpointLightlyTrain,
    CheckpointLightlyTrainModels,
)
from lightly_train.types import PathLike


class ModelCheckpoint(_ModelCheckpoint):
    """Lightning ``ModelCheckpoint`` that also persists lightly_train metadata.

    On every checkpoint save, a ``CheckpointLightlyTrain`` built from the
    wrapped ``models`` is serialized with ``to_dict()`` and stored in the
    checkpoint dict under ``CHECKPOINT_LIGHTLY_TRAIN_KEY``. On load, the
    models are restored from the checkpoint dict. If that fails with a
    ``KeyError``, a warning is emitted instead of raising.

    Args:
        models: The lightly_train models to record in each saved checkpoint.
        dirpath, filename, monitor, verbose, save_last, save_top_k,
        save_weights_only, mode, auto_insert_metric_name, every_n_train_steps,
        train_time_interval, every_n_epochs, save_on_train_epoch_end,
        enable_version_counter: Forwarded unchanged to
            ``pytorch_lightning.callbacks.ModelCheckpoint``.
    """

    def __init__(
        self,
        models: CheckpointLightlyTrainModels,
        dirpath: None | PathLike = None,
        filename: None | str = None,
        monitor: None | str = None,
        verbose: bool = False,
        # Note: the type of save_last depends on the version of pytorch_lightning.
        # Only later versions also allow Literal["link"] as type.
        save_last: None | bool = None,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: None | int = None,
        train_time_interval: None | timedelta = None,
        every_n_epochs: None | int = None,
        save_on_train_epoch_end: None | bool = None,
        enable_version_counter: bool = True,
    ):
        super().__init__(
            dirpath=dirpath,
            filename=filename,
            monitor=monitor,
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_weights_only=save_weights_only,
            mode=mode,
            auto_insert_metric_name=auto_insert_metric_name,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            every_n_epochs=every_n_epochs,
            save_on_train_epoch_end=save_on_train_epoch_end,
            enable_version_counter=enable_version_counter,
        )
        self._models = models

    def on_save_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]
    ) -> None:
        """Add the serialized lightly_train entry to ``checkpoint`` in place.

        The parent hook runs first. The entry is then written under
        ``CHECKPOINT_LIGHTLY_TRAIN_KEY``, overwriting any existing value there.
        """
        super().on_save_checkpoint(trainer, pl_module, checkpoint)
        checkpoint[CHECKPOINT_LIGHTLY_TRAIN_KEY] = CheckpointLightlyTrain.from_now(
            models=self._models
        ).to_dict()

    def on_load_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]
    ) -> None:
        """Restore ``self._models`` from the lightly_train entry of ``checkpoint``.

        This is best-effort. If parsing the checkpoint dict raises
        ``KeyError``, a warning is issued and the current models are kept.
        That case is presumably a checkpoint written without the lightly_train
        entry (TODO confirm).
        """
        super().on_load_checkpoint(trainer, pl_module, checkpoint)
        try:
            # Only the parsing step is expected to raise KeyError. Keep the
            # try body limited to it so unrelated errors are not masked.
            lightly_train_checkpoint = CheckpointLightlyTrain.from_checkpoint_dict(
                checkpoint
            )
        except KeyError as ex:
            warnings.warn(
                f"Could not restore lightly_train models from checkpoint: {ex}"
            )
        else:
            self._models = lightly_train_checkpoint.models
