from __future__ import annotations

from itertools import chain

from torch.nn import Module

from lightly_train._models.custom.custom_package import CUSTOM_PACKAGE
from lightly_train._models.feature_extractor import FeatureExtractor
from lightly_train._models.package import Package
from lightly_train._models.super_gradients.super_gradients_package import (
    SUPER_GRADIENTS_PACKAGE,
)
from lightly_train._models.timm.timm_package import TIMM_PACKAGE
from lightly_train._models.torchvision.torchvision_package import TORCHVISION_PACKAGE
from lightly_train.errors import UnknownModelError


def list_packages() -> list[Package]:
    """Return all supported packages in the order they should be checked.

    A fresh list is built on every call, so callers may mutate it freely.
    """
    named_packages = [SUPER_GRADIENTS_PACKAGE, TIMM_PACKAGE, TORCHVISION_PACKAGE]
    # The custom package is appended last on purpose: a model should first be
    # matched against one of the named packages, and only fall back to the
    # custom package as a last resort.
    return [*named_packages, CUSTOM_PACKAGE]


def get_package(package_name: str) -> Package:
    """Get a package by name.

    Args:
        package_name: Name of the package to look up, e.g. the part before the
            '/' in a 'package/model' identifier.

    Returns:
        The package whose ``name`` equals ``package_name``.

    Raises:
        ValueError: If no supported package has this name. The custom package
            is never returned by this function.
    """
    # Don't include custom package. It should never be fetched by name.
    packages = {p.name: p for p in list_packages() if p != CUSTOM_PACKAGE}
    try:
        return packages[package_name]
    except KeyError:
        # `from None` suppresses the internal KeyError context so users only
        # see the actionable ValueError, not a chained "During handling of the
        # above exception..." traceback.
        raise ValueError(
            f"Unknown package name: '{package_name}'. Supported packages are "
            f"{list(packages)}."
        ) from None


def list_model_names() -> list[str]:
    """Return the names of all models across all packages, sorted.

    Each name is in 'package_name/model_name' format.
    """
    per_package_names = (package.list_model_names() for package in list_packages())
    return sorted(chain.from_iterable(per_package_names))


def get_model(model: str | Module) -> Module:
    """Resolve ``model`` to a model instance.

    Module instances are returned unchanged. Strings are parsed as
    'package/model' and instantiated through the matching package.
    """
    if not isinstance(model, Module):
        package_name, model_name = _parse_model_name(model=model)
        resolved_package = get_package(package_name=package_name)
        return resolved_package.get_model(model_name)
    return model


def get_feature_extractor_cls(model: Module) -> type[FeatureExtractor]:
    """Return the feature extractor class for ``model``.

    Packages are checked in ``list_packages()`` order. The first package that
    reports ``model`` as supported provides the feature extractor class.

    Raises:
        UnknownModelError: If no package supports ``model``.
    """
    supporting_package = next(
        (pkg for pkg in list_packages() if pkg.is_supported_model(model)),
        None,
    )
    if supporting_package is None:
        raise UnknownModelError(f"Unknown model: '{model.__class__.__name__}'")
    return supporting_package.get_feature_extractor_cls(model)


def _parse_model_name(model: str) -> tuple[str, str]:
    parts = model.split("/")
    if len(parts) != 2:
        raise ValueError(
            "Model name has incorrect format. Should be 'package/model' but is "
            f"'{model}'"
        )
    package_name = parts[0]
    model_name = parts[1]
    return package_name, model_name
