from torch import Tensor
from torchvision.models import ResNet
from torchvision.models._utils import IntermediateLayerGetter

from lightly_train._models.torchvision.torchvision import (
    TorchvisionFeatureExtractor,
)


class ResNetFeatureExtractor(TorchvisionFeatureExtractor):
    """Feature extractor backed by a torchvision ``ResNet``.

    ``forward_features`` returns the activation produced by the model's
    ``layer4``. ``forward_pool`` applies the wrapped model's own ``avgpool``
    module. The classification head ``fc`` is not stored; only its input
    width is read to report ``feature_dim``.
    """

    # Presumably consulted by the base class to decide which models this
    # extractor applies to (by type and by name pattern) — confirm in
    # TorchvisionFeatureExtractor.
    _torchvision_models = [ResNet]
    _torchvision_model_name_pattern = r"resnet.*"

    def __init__(self, model: ResNet) -> None:
        super().__init__()
        # Truncate the network after ``layer4`` and expose that layer's output
        # under the key "out". NOTE: the attribute names ``_features`` and
        # ``_pool`` are submodule names, so they show up in state_dict keys;
        # renaming them would break checkpoint compatibility.
        backbone = IntermediateLayerGetter(
            model=model,
            return_layers={"layer4": "out"},
        )
        self._features = backbone
        self._pool = model.avgpool
        # In torchvision's ResNet the linear head consumes layer4's channels,
        # so its input width is the feature dimension. Assumes ``model.fc``
        # exposes ``in_features`` (e.g. ``nn.Linear``), not ``nn.Identity``.
        head_in_features = model.fc.in_features
        self._feature_dim = head_in_features

    @property
    def feature_dim(self) -> int:
        """Channel count of the tensor returned by ``forward_features``."""
        dim = self._feature_dim
        return dim

    def forward_features(self, x: Tensor) -> Tensor:
        """Run ``x`` through the truncated backbone and return layer4's output."""
        layer_outputs = self._features(x)
        return layer_outputs["out"]

    def forward_pool(self, x: Tensor) -> Tensor:
        """Apply the ResNet's ``avgpool`` module to ``x``."""
        pooled = self._pool(x)
        return pooled
