from __future__ import annotations

import re
from pathlib import Path

import pytest
import torch
from lightly_train._checkpoint import Checkpoint
from lightly_train._commands import export
from omegaconf import OmegaConf
from torch import Tensor

from .. import helpers


def test_export__torch_state_dict(tmp_path: Path) -> None:
    """Exporting each model part as a state dict round-trips its weights."""
    ckpt_path, ckpt = _get_checkpoint(tmp_path)
    models = ckpt.lightly_train.models
    # Reference state dicts are captured up front, before any export runs.
    expected_by_part = {
        "model": models.model.state_dict(),
        "embedding_model": models.embedding_model.state_dict(),
    }

    for part_name, expected_state in expected_by_part.items():
        exported_path = tmp_path / f"{part_name}.pt"
        export.export(
            out=exported_path,
            checkpoint=ckpt_path,
            part=part_name,
            format="torch_state_dict",
        )
        loaded_state = torch.load(exported_path)
        _assert_state_dict_equal(loaded_state, expected_state)


def test_export__torch_model(tmp_path: Path) -> None:
    """Check that exporting a full model object works as expected.

    For each exportable part, the exported file must unpickle to an instance
    of the same class as the in-checkpoint model and carry identical weights.
    """
    ckpt_path, ckpt = _get_checkpoint(tmp_path)
    model = ckpt.lightly_train.models.model
    embedding_model = ckpt.lightly_train.models.embedding_model
    part_expected = [
        ("model", model),
        ("embedding_model", embedding_model),
    ]

    for part, expected in part_expected:
        out_path = tmp_path / f"{part}.pt"
        export.export(
            out=out_path,
            checkpoint=ckpt_path,
            part=part,
            format="torch_model",
        )
        # The export pickles the whole nn.Module, so loading it needs full
        # unpickling. Since torch 2.6 `torch.load` defaults to
        # `weights_only=True`, which rejects arbitrary module classes.
        # The file was written by this test, so disabling it is safe.
        loaded_model = torch.load(out_path, weights_only=False)
        assert isinstance(loaded_model, type(expected))
        _assert_state_dict_equal(loaded_model.state_dict(), expected.state_dict())


def test_export__invalid_part() -> None:
    """An unknown ``part`` argument is rejected with a descriptive ValueError."""
    expected_message = (
        "Invalid model part: 'invalid_part'. Valid parts are: "
        "['model', 'embedding_model']"
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        export.export(
            out="out.pt",
            checkpoint="checkpoint.pt",
            part="invalid_part",
            format="torch_state_dict",
        )


def test_export__invalid_format() -> None:
    """An unknown ``format`` argument is rejected with a descriptive ValueError."""
    expected_message = (
        "Invalid model format: 'invalid_format'. Valid formats are: "
        "['torch_model', 'torch_state_dict']"
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        export.export(
            out="out.pt",
            checkpoint="checkpoint.pt",
            part="model",
            format="invalid_format",
        )


def test_export_from_config(tmp_path: Path) -> None:
    """Exporting through an OmegaConf config writes the model's state dict."""
    ckpt_path, ckpt = _get_checkpoint(tmp_path)
    out_path = tmp_path / "model.pt"
    reference_model = ckpt.lightly_train.models.model
    config_values = {
        "checkpoint": str(ckpt_path),
        "out": str(out_path),
        "part": "model",
        "format": "torch_state_dict",
    }
    config = OmegaConf.create(config_values)

    export.export_from_config(config=config)

    exported_state = torch.load(out_path)
    _assert_state_dict_equal(exported_state, reference_model.state_dict())


def _assert_state_dict_equal(a: dict[str, Tensor], b: dict[str, Tensor]) -> None:
    assert a.keys() == b.keys()
    for key in a.keys():
        assert torch.allclose(a[key], b[key])


def _get_checkpoint(tmp_path: Path) -> tuple[Path, Checkpoint]:
    """Build a test checkpoint, save it as ``last.ckpt`` in ``tmp_path``.

    Returns:
        The path the checkpoint was saved to and the in-memory checkpoint.
    """
    ckpt = helpers.get_checkpoint()
    saved_path = tmp_path / "last.ckpt"
    ckpt.save(saved_path)
    return saved_path, ckpt
