import torch
from torch import Tensor

from lightly_train.types import Transform

try:
    from torchvision.transforms import v2 as torchvision_transforms

    _TRANSFORMS_V2 = True

except ImportError:
    from torchvision import transforms as torchvision_transforms

    _TRANSFORMS_V2 = False


def ToTensor() -> Transform[Tensor]:
    """Return a transform that converts an image to a ``torch.float32`` tensor.

    If the ``torchvision.transforms.v2`` module was imported and it provides
    both ``ToImage`` and ``ToDtype``, this returns
    ``Compose([ToImage(), ToDtype(dtype=torch.float32, scale=True)])``.
    That is the replacement recommended by torchvision for the deprecated v2
    ``ToTensor``. With ``scale=True``, integer-valued inputs (e.g. uint8
    images) are rescaled into the float range.

    Otherwise, the function falls back to ``ToTensor()`` from whichever
    transforms module was imported. That is either the legacy
    ``torchvision.transforms`` or a v2 module without the newer transforms.

    Returns:
        A callable transform producing float32 tensors.
    """
    T = torchvision_transforms
    if _TRANSFORMS_V2 and hasattr(T, "ToImage") and hasattr(T, "ToDtype"):
        # transforms.v2.ToTensor is deprecated and will be removed in the future.
        # ToImage + ToDtype(scale=True) is the new recommended way to convert a
        # PIL Image to a tensor since torchvision v0.16. The hasattr checks guard
        # against v2 module versions that do not expose these transforms.
        # See also https://github.com/pytorch/vision/blame/33e47d88265b2d57c2644aad1425be4fccd64605/torchvision/transforms/v2/_deprecated.py#L19
        return T.Compose([T.ToImage(), T.ToDtype(dtype=torch.float32, scale=True)])
    return T.ToTensor()
