from __future__ import annotations

from torch.nn import Module

from lightly_train._models.feature_extractor import FeatureExtractor
from lightly_train._models.package import Package
from lightly_train._models.super_gradients.customizable_detector import (
    CustomizableDetectorFeatureExtractor,
)
from lightly_train.errors import UnknownModelError


class SuperGradientsPackage(Package):
    """Package implementation for models from the `super_gradients` library.

    Every method is a classmethod. `super_gradients` itself is imported lazily
    inside the methods that use it (`list_model_names` and `get_model`).
    """

    name = "super_gradients"

    # Sadly SuperGradients doesn't expose a common interface for all models. We have to
    # define different feature extractors depending on the model types.
    _FEATURE_EXTRACTORS = [CustomizableDetectorFeatureExtractor]

    @classmethod
    def list_model_names(cls) -> list[str]:
        """Return the sorted names of all supported super_gradients models.

        Each name has the form `"<package name>/<architecture name>"`, where the
        architecture names are the keys of `models.ARCHITECTURES`. Architectures
        whose class is not handled by any registered feature extractor are
        skipped.

        Returns:
            Sorted list of model names. Empty if `super_gradients` cannot be
            imported.
        """
        try:
            from super_gradients.training import models  # type: ignore[import-untyped]
        except ImportError:
            # Package not installed: there is nothing to list.
            return []
        return sorted(
            f"{cls.name}/{model_name}"
            for model_name, model_cls in models.ARCHITECTURES.items()
            if cls.is_supported_model_cls(model_cls=model_cls)
        )

    @classmethod
    def is_supported_model(cls, model: Module) -> bool:
        """Return True if the class of `model` is supported.

        Args:
            model: Model instance; only `type(model)` is inspected.
        """
        return cls.is_supported_model_cls(model_cls=type(model))

    @classmethod
    def is_supported_model_cls(cls, model_cls: type[Module]) -> bool:
        """Return True if any registered feature extractor supports `model_cls`.

        Args:
            model_cls: Model class to check against `_FEATURE_EXTRACTORS`.
        """
        # Evaluate the predicate itself instead of yielding the extractor classes
        # and relying on their truthiness.
        return any(
            fe.is_supported_model_cls(model_cls) for fe in cls._FEATURE_EXTRACTORS
        )

    @classmethod
    def get_model(cls, model_name: str) -> Module:
        """Instantiate a super_gradients model by name.

        Args:
            model_name: Name forwarded unchanged to
                `super_gradients.training.models.get`.
                NOTE(review): presumably the architecture name without the
                `super_gradients/` prefix — confirm against callers.

        Returns:
            The model returned by `models.get`, created with `num_classes=10`.

        Raises:
            ValueError: If `super_gradients` cannot be imported. The original
                `ImportError` is attached as the cause.
        """
        try:
            from super_gradients.training import models
        except ImportError as err:
            raise ValueError(
                f"Cannot create model '{model_name}' because '{cls.name}' is not "
                "installed."
            ) from err
        # TODO(Guarin, 07/2024): Expose `num_classes` with `model_args` parameter.
        return models.get(model_name=model_name, num_classes=10)

    @classmethod
    def get_feature_extractor_cls(cls, model: Module) -> type[FeatureExtractor]:
        """Return the first registered feature extractor class supporting `model`.

        Args:
            model: Model instance; the lookup is based on `type(model)`.

        Returns:
            The first entry of `_FEATURE_EXTRACTORS` whose `is_supported_model_cls`
            accepts the model's class.

        Raises:
            UnknownModelError: If no registered feature extractor supports the
                model's class.
        """
        model_cls = type(model)
        for fe in cls._FEATURE_EXTRACTORS:
            if fe.is_supported_model_cls(model_cls=model_cls):
                return fe
        raise UnknownModelError(f"Unknown {cls.name} model: '{type(model)}'")


# Module-level singleton instance of SuperGradientsPackage. Use this shared instance
# whenever possible instead of creating new ones.
SUPER_GRADIENTS_PACKAGE = SuperGradientsPackage()
