import pytest

try:
    from super_gradients.training import models  # type: ignore[import-untyped]
    from super_gradients.training.models import (  # type: ignore[import-untyped]
        YoloNAS_S,
    )
except ImportError:
    # We do not use pytest.importorskip on module level because it makes mypy unhappy.
    pytest.skip("super_gradients is not installed", allow_module_level=True)

from lightly_train._models.super_gradients.customizable_detector import (
    CustomizableDetectorFeatureExtractor,
)
from lightly_train._models.super_gradients.super_gradients_package import (
    SuperGradientsPackage,
)


class TestSuperGradientsPackage:
    """Tests for the ``SuperGradientsPackage`` model-package integration.

    Covers model-name listing, support detection, model construction, and
    feature-extractor class lookup. Only YOLO-NAS-S is expected to be supported;
    YOLOX-N is used as the unsupported counterexample.
    """

    # Class count passed to ``models.get`` when building models here. The tests
    # only inspect model types, never the number of classes.
    _NUM_CLASSES = 2

    def test_list_model_names(self) -> None:
        """YOLO-NAS-S is listed under the ``super_gradients/`` prefix; YOLOX-N is not."""
        listed = SuperGradientsPackage.list_model_names()
        assert "super_gradients/yolo_nas_s" in listed
        assert "super_gradients/yolox_n" not in listed

    @pytest.mark.parametrize(
        ("model_name", "is_supported"),
        [
            ("yolo_nas_s", True),
            ("yolox_n", False),
        ],
    )
    def test_is_supported_model(self, model_name: str, is_supported: bool) -> None:
        """``is_supported_model`` returns exactly ``True``/``False`` per model type.

        The identity comparison (``is``) deliberately requires a real bool rather
        than any truthy/falsy value.
        """
        detector = models.get(model_name, num_classes=self._NUM_CLASSES)
        verdict = SuperGradientsPackage.is_supported_model(detector)
        assert verdict is is_supported

    def test_get_model(self) -> None:
        """``get_model`` accepts the bare (unprefixed) name and builds a ``YoloNAS_S``."""
        loaded = SuperGradientsPackage.get_model("yolo_nas_s")
        assert isinstance(loaded, YoloNAS_S)

    def test_get_feature_extractor_cls(self) -> None:
        """A YOLO-NAS-S model maps to ``CustomizableDetectorFeatureExtractor``."""
        detector = models.get("yolo_nas_s", num_classes=self._NUM_CLASSES)
        extractor_cls = SuperGradientsPackage.get_feature_extractor_cls(detector)
        assert extractor_cls is CustomizableDetectorFeatureExtractor
