from collections.abc import Callable, Iterable

import torch

from ...tensorlist import TensorList

from ...core import OptimizerModule, _Chainable


class Alpha(OptimizerModule):
    """Scales the update by a fixed multiplier ``alpha``.

    The multiplier is stored under the ``alpha`` group key (not ``lr``), so
    learning rate schedulers won't pick it up.

    Args:
        alpha: multiplier applied to the update. Defaults to 1e-3.
    """
    def __init__(self, alpha = 1e-3):
        super().__init__({'alpha': alpha})

    @torch.no_grad
    def _update(self, vars, ascent):
        step_size = self.get_group_key('alpha')
        # in-place scaling: the incoming update object itself is modified and returned
        ascent *= step_size
        return ascent

class Clone(OptimizerModule):
    """Returns a copy of the update (``ascent.clone()``).

    Useful in front of modules that modify the update in-place, when the
    original update must be kept intact.
    """
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        duplicate = ascent.clone()
        return duplicate

class Identity(OptimizerModule):
    """Passes the update through unchanged.

    Any positional or keyword arguments given to the constructor are accepted
    and ignored.
    """
    def __init__(self, *_args, **_kwargs):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        unchanged = ascent
        return unchanged

class Lambda(OptimizerModule):
    """Transforms the update with a user-supplied function.

    The function is called (under ``torch.no_grad``) with the current update,
    a TensorList. Whatever it returns becomes the new update.

    Args:
        f (Callable[[TensorList], TensorList]): transformation to apply.
    """
    def __init__(self, f: Callable[[TensorList], TensorList]):
        super().__init__({})
        self.f = f

    @torch.no_grad()
    def _update(self, vars, ascent):
        transformed = self.f(ascent)
        return transformed

class Grad(OptimizerModule):
    """Replaces the update with the gradient. Useful inside chains.

    The gradient comes from ``vars.maybe_compute_grad_`` (presumably evaluated
    only if not already available). It is also stored back as ``vars.ascent``.
    """
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        gradient = vars.maybe_compute_grad_(self.get_params())
        # keep the shared state in sync with the value this module returns
        vars.ascent = gradient
        return gradient

class Zeros(OptimizerModule):
    """Replaces the update with zeros built by ``ascent.zeros_like()``."""
    def __init__(self):
        super().__init__({})

    @torch.no_grad
    def _update(self, vars, ascent):
        zeroed = ascent.zeros_like()
        return zeroed

class Fill(OptimizerModule):
    """Replaces the update with a constant, returning ``ascent.fill(value)``.

    Args:
        value: the fill constant, stored under the ``value`` group key.
    """
    def __init__(self, value):
        super().__init__(dict(value = value))

    @torch.no_grad
    def _update(self, vars, ascent):
        constant = self.get_group_key('value')
        return ascent.fill(constant)


class GradToUpdate(OptimizerModule):
    """Sets the gradient and the parameters' ``.grad`` attributes to the current update.

    Despite the name, data flows from the update *into* the gradient, via
    ``vars.set_grad_``. The update itself is returned unchanged.
    """
    def __init__(self):
        super().__init__({})

    # no_grad like every other _update in this file, so that any arithmetic
    # performed inside set_grad_ is never recorded by autograd.
    @torch.no_grad
    def _update(self, vars, ascent):
        vars.set_grad_(ascent, self.get_params())
        return ascent

class MakeClosure(OptimizerModule):
    """Replaces the closure's backward pass with the update produced by ``modules``.

    The wrapped closure behaves as follows:

    - With ``backward=False``, it calls the original closure as ``closure(False)``.
    - With ``backward=True``, it copies the state captured at step time and
      runs the child ``modules`` on that copy. It writes their update to the
      parameters' ``.grad`` via ``set_grad_`` and returns the copy's
      ``get_loss()``.

    Args:
        modules (_Chainable): modules whose output is used as the gradient;
            registered as the child named ``'modules'``.
    """
    def __init__(self, modules: _Chainable):
        super().__init__({})
        self._set_child_('modules', modules)

    def step(self, vars):
        if vars.closure is None:
            raise ValueError("MakeClosure requires a closure")

        trainable = self.get_params()
        base_closure = vars.closure
        # snapshot taken before vars.closure is replaced, so it still holds base_closure.
        # NOTE(review): the snapshot is taken once per step. If vars.copy keeps an
        # already-computed loss/gradient, re-evaluating at new parameters may
        # reuse stale values — confirm against the Vars.copy/get_loss semantics.
        snapshot = vars.copy(True)

        def closure_with_module_grads(backward = True):
            if not backward:
                return base_closure(False)

            working_state = snapshot.copy(True)
            module_update = self.children['modules'].return_ascent(working_state)
            trainable.set_grad_(module_update)
            return working_state.get_loss()

        vars.closure = closure_with_module_grads # type:ignore
        return self._update_params_or_step_with_next(vars)

