from typing import Literal
from collections.abc import Mapping, Callable
from functools import partial
import numpy as np
import torch

import nlopt
from ...core import TensorListOptimizer, _ClosureType
from ...tensorlist import TensorList

# Algorithm names accepted as the string form of ``NLOptOptimizer``'s ``algorithm`` argument.
# A string is upper-cased and resolved to the integer constant of the same name on the
# ``nlopt`` module via ``getattr``; the trailing comment on each entry names that constant.
_ALGOS_LITERAL = Literal[
    "GN_DIRECT",  # = _nlopt.GN_DIRECT
    "GN_DIRECT_L",  # = _nlopt.GN_DIRECT_L
    "GN_DIRECT_L_RAND",  # = _nlopt.GN_DIRECT_L_RAND
    "GN_DIRECT_NOSCAL",  # = _nlopt.GN_DIRECT_NOSCAL
    "GN_DIRECT_L_NOSCAL",  # = _nlopt.GN_DIRECT_L_NOSCAL
    "GN_DIRECT_L_RAND_NOSCAL",  # = _nlopt.GN_DIRECT_L_RAND_NOSCAL
    "GN_ORIG_DIRECT",  # = _nlopt.GN_ORIG_DIRECT
    "GN_ORIG_DIRECT_L",  # = _nlopt.GN_ORIG_DIRECT_L
    "GD_STOGO",  # = _nlopt.GD_STOGO
    "GD_STOGO_RAND",  # = _nlopt.GD_STOGO_RAND
    "LD_LBFGS_NOCEDAL",  # = _nlopt.LD_LBFGS_NOCEDAL
    "LD_LBFGS",  # = _nlopt.LD_LBFGS
    "LN_PRAXIS",  # = _nlopt.LN_PRAXIS
    "LD_VAR1",  # = _nlopt.LD_VAR1
    "LD_VAR2",  # = _nlopt.LD_VAR2
    "LD_TNEWTON",  # = _nlopt.LD_TNEWTON
    "LD_TNEWTON_RESTART",  # = _nlopt.LD_TNEWTON_RESTART
    "LD_TNEWTON_PRECOND",  # = _nlopt.LD_TNEWTON_PRECOND
    "LD_TNEWTON_PRECOND_RESTART",  # = _nlopt.LD_TNEWTON_PRECOND_RESTART
    "GN_CRS2_LM",  # = _nlopt.GN_CRS2_LM
    "GN_MLSL",  # = _nlopt.GN_MLSL
    "GD_MLSL",  # = _nlopt.GD_MLSL
    "GN_MLSL_LDS",  # = _nlopt.GN_MLSL_LDS
    "GD_MLSL_LDS",  # = _nlopt.GD_MLSL_LDS
    "LD_MMA",  # = _nlopt.LD_MMA
    "LN_COBYLA",  # = _nlopt.LN_COBYLA
    "LN_NEWUOA",  # = _nlopt.LN_NEWUOA
    "LN_NEWUOA_BOUND",  # = _nlopt.LN_NEWUOA_BOUND
    "LN_NELDERMEAD",  # = _nlopt.LN_NELDERMEAD
    "LN_SBPLX",  # = _nlopt.LN_SBPLX
    "LN_AUGLAG",  # = _nlopt.LN_AUGLAG
    "LD_AUGLAG",  # = _nlopt.LD_AUGLAG
    "LN_AUGLAG_EQ",  # = _nlopt.LN_AUGLAG_EQ
    "LD_AUGLAG_EQ",  # = _nlopt.LD_AUGLAG_EQ
    "LN_BOBYQA",  # = _nlopt.LN_BOBYQA
    "GN_ISRES",  # = _nlopt.GN_ISRES
    "AUGLAG",  # = _nlopt.AUGLAG
    "AUGLAG_EQ",  # = _nlopt.AUGLAG_EQ
    "G_MLSL",  # = _nlopt.G_MLSL
    "G_MLSL_LDS",  # = _nlopt.G_MLSL_LDS
    "LD_SLSQP",  # = _nlopt.LD_SLSQP
    "LD_CCSAQ",  # = _nlopt.LD_CCSAQ
    "GN_ESCH",  # = _nlopt.GN_ESCH
    "GN_AGS",  # = _nlopt.GN_AGS
]

def _ensure_float(x):
    """Convert a scalar-like value to a builtin ``float``.

    Args:
        x: a single-element ``torch.Tensor`` or ``numpy.ndarray``, or any object
            accepted by ``float()``.

    Returns:
        float: the value of ``x`` as a Python float.
    """
    if isinstance(x, torch.Tensor):
        # detach + cpu so tensors that track gradients or live on another device convert cleanly
        x = x.detach().cpu().item()
    elif isinstance(x, np.ndarray):
        x = x.item()
    # ``.item()`` yields int/bool for integer/boolean data, so always coerce at the end
    return float(x)

def _ensure_tensor(x):
    """Convert ``x`` to a ``torch.Tensor``.

    Args:
        x: a ``numpy.ndarray`` or anything accepted by ``torch.tensor``.

    Returns:
        torch.Tensor: for numpy input, the result of ``torch.from_numpy`` (dtype is kept,
        memory is shared with ``x`` unless a copy had to be made); otherwise a new
        float32 tensor built from ``x``.
    """
    if isinstance(x, np.ndarray):
        if not x.flags.writeable:
            # torch wants a writeable buffer; flip the flag in place when numpy allows it
            try:
                x.setflags(write=True)
            except ValueError:
                # numpy refuses for views over memory it cannot make writeable
                # (e.g. a read-only base array or a bytes buffer) - use a private copy
                x = x.copy()
        return torch.from_numpy(x)
    return torch.tensor(x, dtype=torch.float32)

# Positive infinity; NLOptOptimizer.step uses +/-inf as the bound for parameters whose lb/ub is None.
inf = float('inf')
class NLOptOptimizer(TensorListOptimizer):
    """Use nlopt as pytorch optimizer, with gradient supplied by pytorch autograd.
    Note that this performs full minimization on each step,
    so usually you would want to perform a single step, although performing multiple steps will refine the
    solution.

    Some algorithms are buggy with numpy>=2.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        algorithm (int | _ALGOS_LITERAL): optimization algorithm from https://nlopt.readthedocs.io/en/latest/NLopt_Algorithms/
        maxeval (int | None):
            maximum allowed function evaluations, set to None to disable. But some stopping criterion
            must be set otherwise nlopt will run forever.
        lb (float | None, optional): optional lower bounds, some algorithms require this. Defaults to None.
        ub (float | None, optional): optional upper bounds, some algorithms require this. Defaults to None.
        stopval (float | None, optional): stop minimizing when an objective value ≤ stopval is found. Defaults to None.
        ftol_rel (float | None, optional): set relative tolerance on function value. Defaults to None.
        ftol_abs (float | None, optional): set absolute tolerance on function value. Defaults to None.
        xtol_rel (float | None, optional): set relative tolerance on optimization parameters. Defaults to None.
        xtol_abs (float | None, optional): set absolute tolerances on optimization parameters. Defaults to None.
        maxtime (float | None, optional): stop when the optimization time (in seconds) exceeds maxtime. Defaults to None.
    """
    def __init__(
        self,
        params,
        algorithm: int | _ALGOS_LITERAL,
        maxeval: int | None,
        lb: float | None = None,
        ub: float | None = None,
        stopval: float | None = None,
        ftol_rel: float | None = None,
        ftol_abs: float | None = None,
        xtol_rel: float | None = None,
        xtol_abs: float | None = None,
        maxtime: float | None = None,
    ):
        # bounds are stored as group defaults, so they are read back per parameter in `step`
        defaults = dict(lb=lb, ub=ub)
        super().__init__(params, defaults)

        # the nlopt.opt object is rebuilt on every `step`; None until the first step
        self.opt: nlopt.opt | None = None
        # string names are resolved to the integer constant of the same name on the nlopt module
        if isinstance(algorithm, str): algorithm = getattr(nlopt, algorithm.upper())
        self.algorithm: int = algorithm # type:ignore
        # not assigned anywhere else in this class
        self.algorithm_name: str | None = None

        # stopping criteria; each one is only forwarded to nlopt when it is not None
        self.maxeval = maxeval; self.stopval = stopval
        self.ftol_rel = ftol_rel; self.ftol_abs = ftol_abs
        self.xtol_rel = xtol_rel; self.xtol_abs = xtol_abs
        self.maxtime = maxtime

        self._last_loss = None

    def _f(self, x: np.ndarray, grad: np.ndarray, closure: _ClosureType, params: TensorList):
        """Objective callback given to nlopt.

        Writes the candidate point ``x`` into ``params``, evaluates ``closure`` and returns the loss
        as a float. When nlopt passes a non-empty ``grad`` array, it is filled in place with the
        flattened parameter gradients.

        Args:
            x: candidate point as a flat array.
            grad: output array for the gradient; has size 0 when no gradient is requested.
            closure: closure that evaluates and returns the loss; called with no arguments when a
                gradient is needed and with ``False`` otherwise.
            params: parameters that ``x`` is unflattened into.

        Returns:
            float: loss at ``x`` (also stored in ``self._last_loss``).
        """
        params.from_vec_(_ensure_tensor(x).to(params[0], copy=False))
        if grad.size > 0:
            # `step` runs under no_grad, so autograd has to be re-enabled for the closure
            with torch.enable_grad(): loss = closure()
            self._last_loss = _ensure_float(loss)
            # nlopt expects the gradient to be written into the array it passed in
            grad[:] = params.ensure_grad_().grad.to_vec().reshape(grad.shape).detach().cpu().numpy()
            return self._last_loss

        self._last_loss = _ensure_float(closure(False))
        return self._last_loss

    @torch.no_grad()
    def step(self, closure: _ClosureType): # pylint: disable = signature-differs
        """Run one full nlopt minimization starting from the current parameter values.

        Args:
            closure: closure that evaluates and returns the loss (see ``_f`` for how it is called).

        Returns:
            float: objective value at the point nlopt returned, which is also the point
            written back into the parameters.
        """
        params = self.get_params()

        # make bounds: one entry per scalar element, missing bounds become +/-inf
        lb, ub = self.get_group_keys('lb', 'ub', cls=list)
        lower = []
        upper = []
        for p, l, u in zip(params, lb, ub):
            if l is None: l = -inf
            if u is None: u = inf
            lower.extend([l] * p.numel())
            upper.extend([u] * p.numel())

        x0 = params.to_vec().detach().cpu().numpy()

        self.opt = nlopt.opt(self.algorithm, x0.size)
        self.opt.set_min_objective(partial(self._f, closure = closure, params = params))
        self.opt.set_lower_bounds(lower)
        self.opt.set_upper_bounds(upper)

        if self.maxeval is not None: self.opt.set_maxeval(self.maxeval)
        if self.stopval is not None: self.opt.set_stopval(self.stopval)
        if self.ftol_rel is not None: self.opt.set_ftol_rel(self.ftol_rel)
        if self.ftol_abs is not None: self.opt.set_ftol_abs(self.ftol_abs)
        if self.xtol_rel is not None: self.opt.set_xtol_rel(self.xtol_rel)
        if self.xtol_abs is not None: self.opt.set_xtol_abs(self.xtol_abs)
        if self.maxtime is not None: self.opt.set_maxtime(self.maxtime)

        x = self.opt.optimize(x0)
        params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
        # `_f` leaves the loss of the *last evaluated* point in `_last_loss`, which is generally
        # not the returned optimum; report the objective value that belongs to `x` instead
        self._last_loss = float(self.opt.last_optimum_value())
        return self._last_loss