from __future__ import annotations

from torch import Tensor
from torch.nn import Conv2d, Identity, Module

from lightly_train._models.feature_extractor import FeatureExtractor


# EmbeddingModel is not combined into a single class with FeatureExtractor to keep
# implementing new extractors as simple as possible.
#
# Note that in the future we might want to support feature extractors that also generate
# features from intermediate layers. For this, we'll have to add support for multiple
# embedding heads with different dimensions.
class EmbeddingModel(Module):
    """Feature extractor followed by an optional 1x1 convolution embedding head.

    The extractor is run through either its ``forward`` (when ``pool`` is True) or
    its ``forward_features`` (when ``pool`` is False) method. The result is then
    passed through ``embed_head``, which is an ``Identity`` when no ``embed_dim``
    was requested.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        embed_dim: None | int = None,
        pool: bool = True,
    ):
        """A model that extracts features from input data and maps them to an embedding
        space.

        Args:
            feature_extractor:
                The feature extractor that extracts features from input data.
            embed_dim:
                The dimensionality of the embedding space. If None, the output of the
                feature extractor is used as the embedding.
            pool:
                Whether to apply the pooling layer of the feature extractor. If False,
                the features are embedded and returned without pooling.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        # A 1x1 convolution only remaps the channel dimension, so it can be applied
        # to pooled and unpooled feature maps alike.
        self.embed_head = (
            Identity()
            if embed_dim is None
            else Conv2d(
                in_channels=self.feature_extractor.feature_dim,
                out_channels=embed_dim,
                kernel_size=1,
            )
        )
        self.pool = pool

    @property
    def embed_dim(self) -> int:
        """Number of channels of the embeddings returned by ``forward``.

        Equals the feature extractor's ``feature_dim`` when no embedding head is
        configured, otherwise the output channels of the embedding head.
        """
        if isinstance(self.embed_head, Identity):
            return self.feature_extractor.feature_dim
        else:
            return self.embed_head.out_channels

    def _forward_feature_extractor(self, x: Tensor) -> Tensor:
        """Run the feature extractor with or without pooling.

        The method to call is resolved on every invocation instead of being cached
        as a bound method in ``__init__``. This way, changing ``self.pool`` or
        replacing ``self.feature_extractor`` after construction takes effect instead
        of being silently ignored.
        """
        if self.pool:
            return self.feature_extractor.forward(x)
        return self.feature_extractor.forward_features(x)

    def forward(self, x: Tensor) -> Tensor:
        """Extract features from input image and map them to an embedding space.

        Args:
            x: Input images with shape (B, C, H_in, W_in).

        Returns:
            Embeddings with shape (B, embed_dim, H_out, W_out). H_out and W_out depend
            on the pooling layer of the feature extractor and are 1 in most cases.
        """
        x = self._forward_feature_extractor(x)
        x = self.embed_head(x)
        return x
