from __future__ import annotations

from typing import Any

from lightly.models.utils import get_weight_decay_parameters
from torch.optim import Optimizer

from lightly_train._optim.adamw_args import AdamWArgs
from lightly_train._optim.optimizer_args import OptimizerArgs
from lightly_train._optim.optimizer_type import OptimizerType
from lightly_train._optim.trainable_modules import TrainableModules

# Registry of supported optimizer argument classes, keyed by the value each
# class returns from its `type()` method (called on the class itself;
# presumably an OptimizerType member — confirm in the args classes).
# To support a new optimizer, add its args class to the tuple below.
OPTIM_TYPE_TO_ARGS = {args_cls.type(): args_cls for args_cls in (AdamWArgs,)}


def get_optimizer_type(
    optim_type: OptimizerType | str,
) -> OptimizerType:
    """Convert ``optim_type`` into an ``OptimizerType`` member.

    Args:
        optim_type: An ``OptimizerType`` member, or a value accepted by the
            ``OptimizerType(...)`` constructor (e.g. the string value of a
            member).

    Returns:
        The corresponding ``OptimizerType`` member.

    Raises:
        ValueError: If ``optim_type`` does not correspond to any
            ``OptimizerType``. The message lists all valid values.
    """
    try:
        return OptimizerType(optim_type)
    except ValueError:
        valid_types = [t.value for t in OptimizerType]
        # `from None` suppresses the enum's own ValueError as implicit context:
        # this message already names the bad input and the valid choices, so
        # the chained "During handling of the above exception" trace is noise.
        raise ValueError(
            f"Invalid optimizer type: '{optim_type}'. Valid types are: "
            f"{valid_types}"
        ) from None


def get_optimizer(
    optim_args: OptimizerArgs,
    trainable_modules: TrainableModules,
    lr_scale: float,
) -> Optimizer:
    """Build an optimizer with separate weight-decay / no-weight-decay groups.

    The parameters of ``trainable_modules.modules`` are split by
    ``get_weight_decay_parameters`` into a decay list and a no-decay list.
    Every parameter of every module in
    ``trainable_modules.modules_no_weight_decay`` (when it is not ``None``) is
    appended to the no-decay list.

    Two param groups are then passed to ``optim_args.get_optimizer``:

    - ``"params"``: the decay parameters. No ``weight_decay`` key is set, so
      presumably the optimizer's configured default applies. Confirm in
      ``OptimizerArgs.get_optimizer``.
    - ``"params_no_weight_decay"``: the no-decay parameters, with
      ``weight_decay`` forced to ``0.0``. This group is only added when it is
      non-empty.

    NOTE(review): assumes no parameter is reachable both from ``modules`` and
    from ``modules_no_weight_decay``. torch optimizers reject a parameter that
    appears in more than one param group.

    Args:
        optim_args: Optimizer configuration whose ``get_optimizer`` builds the
            final optimizer.
        trainable_modules: Container of the modules whose parameters are
            optimized.
        lr_scale: Forwarded unchanged to ``optim_args.get_optimizer``.

    Returns:
        The optimizer returned by ``optim_args.get_optimizer``.
    """
    decay, no_decay = get_weight_decay_parameters(
        modules=trainable_modules.modules
    )

    forced_no_decay_modules = trainable_modules.modules_no_weight_decay
    if forced_no_decay_modules is not None:
        # Flatten module -> parameters while preserving module order.
        no_decay.extend(
            param
            for module in forced_no_decay_modules
            for param in module.parameters()
        )

    param_groups: list[dict[str, Any]] = [{"name": "params", "params": decay}]
    if no_decay:
        no_decay_group: dict[str, Any] = {
            "name": "params_no_weight_decay",
            "params": no_decay,
            "weight_decay": 0.0,
        }
        param_groups.append(no_decay_group)

    return optim_args.get_optimizer(params=param_groups, lr_scale=lr_scale)
