from __future__ import annotations

import subprocess
from pathlib import Path

import numpy as np
from lightly_train._checkpoint import (
    Checkpoint,
    CheckpointLightlyTrain,
    CheckpointLightlyTrainModels,
)
from lightly_train._commands import extract_video_frames
from lightly_train._methods.method import Method
from lightly_train._methods.simclr import SimCLR, SimCLRArgs
from lightly_train._models import (
    package_helpers as feature_extractor_api,
)
from lightly_train._models.embedding_model import EmbeddingModel
from lightly_train._models.feature_extractor import FeatureExtractor
from PIL import Image
from torch import Tensor
from torch.nn import AdaptiveAvgPool2d, Conv2d, Module


class DummyCustomModel(Module):
    """Tiny dummy convolutional model.

    A single 2x2 convolution maps 3 input channels to ``feature_dim`` output
    channels; a global average pool reduces any feature map to 1x1 spatially.
    The model exposes ``num_features``, ``forward_features`` and
    ``forward_pool`` as separate methods.
    """

    def __init__(self, feature_dim: int = 2):
        super().__init__()
        # Reported back by num_features().
        self.feature_dim = feature_dim
        # The attribute names below become state_dict keys; keep them stable.
        self.conv = Conv2d(
            in_channels=3,
            out_channels=feature_dim,
            kernel_size=2,
        )
        self.global_pool = AdaptiveAvgPool2d(output_size=(1, 1))

    def num_features(self) -> int:
        """Return the channel count of the maps produced by ``forward_features``."""
        return self.feature_dim

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the 2x2 convolution over ``x`` and return the feature map."""
        features = self.conv(x)
        return features

    def forward_pool(self, x: Tensor) -> Tensor:
        """Average-pool ``x`` down to a 1x1 spatial size."""
        pooled = self.global_pool(x)
        return pooled


def get_model() -> Module:
    """Return a new ``DummyCustomModel`` with its default feature dimension."""
    model = DummyCustomModel()
    return model


def get_feature_extractor(model: Module | None = None) -> FeatureExtractor:
    """Wrap ``model`` in the feature extractor class selected for it.

    Args:
        model: Model to wrap. When ``None``, a fresh ``get_model()`` instance
            is used.

    Returns:
        An instance of the class returned by
        ``feature_extractor_api.get_feature_extractor_cls`` for the model.
    """
    backbone = get_model() if model is None else model
    extractor_cls = feature_extractor_api.get_feature_extractor_cls(model=backbone)
    return extractor_cls(model=backbone)


def get_embedding_model(model: Module | None = None) -> EmbeddingModel:
    """Build an ``EmbeddingModel`` around the feature extractor for ``model``.

    ``model`` is forwarded to ``get_feature_extractor``, which falls back to a
    fresh default model when it is ``None``.
    """
    feature_extractor = get_feature_extractor(model=model)
    return EmbeddingModel(feature_extractor=feature_extractor)


def get_method(model: Module | None = None) -> Method:
    """Build a ``SimCLR`` method with default args and a per-device batch size of 2.

    ``model`` is forwarded to ``get_embedding_model``; a fresh default model
    is created when it is ``None``.
    """
    method_args = SimCLRArgs()
    embedding_model = get_embedding_model(model=model)
    return SimCLR(
        method_args=method_args,
        embedding_model=embedding_model,
        batch_size_per_device=2,
    )


def get_checkpoint() -> Checkpoint:
    """Assemble a ``Checkpoint`` for a default dummy model and SimCLR method.

    One ``model`` instance backs both the checkpoint's ``embedding_model`` and
    the method whose ``state_dict`` is stored. Each of the two is wrapped in
    its own ``EmbeddingModel``.
    """
    model = get_model()
    # Keep this construction order: each call may initialize new modules and
    # therefore consume random state.
    embedding_model = get_embedding_model(model=model)
    method = get_method(model=model)
    state_dict = method.state_dict()
    checkpoint_models = CheckpointLightlyTrainModels(
        model=model, embedding_model=embedding_model
    )
    return Checkpoint(
        state_dict=state_dict,
        lightly_train=CheckpointLightlyTrain.from_now(models=checkpoint_models),
    )


def create_image(path: Path, width: int = 128, height: int = 128) -> None:
    """Write a random RGB image that is ``width`` x ``height`` pixels to ``path``.

    Args:
        path: Destination file. The image format is chosen from the suffix.
        width: Image width in pixels.
        height: Image height in pixels.
    """
    # Image arrays are indexed (rows, columns, channels) == (height, width, 3).
    # Passing (width, height, 3) would transpose non-square images.
    img_np = np.random.uniform(0, 255, size=(height, width, 3))
    img = Image.fromarray(img_np.astype(np.uint8))
    img.save(path)


def create_images(image_dir: Path, n: int = 10) -> None:
    """Create ``image_dir`` (including parents) and fill it with ``n`` random images.

    The files are named ``0.png`` through ``{n-1}.png`` and use
    ``create_image``'s default size.
    """
    image_dir.mkdir(parents=True, exist_ok=True)
    for index in range(n):
        image_path = image_dir / f"{index}.png"
        create_image(path=image_path)


def create_video(video_path: Path, n_frames: int = 10) -> None:
    """Encode ``n_frames`` random PNG frames into an H.264 video at ``video_path``.

    The frames are first written to a directory named after the video's stem,
    placed next to the video file. They are then fed to ``ffmpeg`` at one
    frame per second.

    Raises:
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero status.
    """
    extract_video_frames.assert_ffmpeg_is_installed()
    frame_dir = video_path.parent / video_path.stem
    frame_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): frame_dir is not removed after encoding, so it remains
    # beside the video. Confirm that callers expect this.
    create_images(image_dir=frame_dir, n=n_frames)
    # printf-style pattern matching the numbered frames written above.
    frame_pattern = str(frame_dir / "%d.png")
    input_args = ["-framerate", "1", "-i", frame_pattern]
    output_args = [
        "-c:v",
        "libx264",
        "-vf",
        "fps=1",
        "-pix_fmt",
        "yuv420p",
        str(video_path),
    ]
    cmd = ["ffmpeg", *input_args, *output_args]
    # stdout/stderr are captured (presumably to keep ffmpeg logs out of the
    # console); check=True still raises on failure.
    subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)


def create_videos(
    videos_dir: Path, n_videos: int = 4, n_frames_per_video: int = 10
) -> None:
    """Create ``n_videos`` random videos named ``video_{i}.mp4`` in ``videos_dir``.

    Each video has ``n_frames_per_video`` frames and is produced by
    ``create_video``. ``videos_dir`` is created, with parents, if it is missing.
    """
    # Check for ffmpeg up front, before anything is written to disk.
    extract_video_frames.assert_ffmpeg_is_installed()
    videos_dir.mkdir(parents=True, exist_ok=True)
    video_paths = (videos_dir / f"video_{i}.mp4" for i in range(n_videos))
    for video_path in video_paths:
        create_video(video_path=video_path, n_frames=n_frames_per_video)
