# Copyright 2019 The FastEstimator Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Union

import tensorflow as tf
import tensorflow_addons as tfa
import torch

from fastestimator.backend.get_lr import get_lr


def set_lr(model: Union[tf.keras.Model, torch.nn.Module], lr: float, weight_decay: Optional[float] = None):
    """Set the learning rate of a given `model` generated by `fe.build`.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")  # m.optimizer.lr == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.lr == 0.8
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")  # m.optimizer.param_groups[-1]['lr'] == 0.001
    fe.backend.set_lr(m, lr=0.8)  # m.optimizer.param_groups[-1]['lr'] == 0.8
    ```

    For TensorFlow models whose current optimizer (or its `inner_optimizer`) is a
    `tfa.optimizers.DecoupledWeightDecayExtension` (e.g. SGDW or AdamW), the weight decay is updated together with the
    learning rate. If `weight_decay` is not given, the existing weight decay is rescaled so that the ratio
    weight_decay / lr stays constant. That ratio is remembered on the optimizer, so it survives the learning rate
    passing through exactly 0 (e.g. a cosine schedule decaying to `min_lr=0`) instead of yielding NaN/inf or a
    permanently zero weight decay.

    Args:
        model: A neural network instance to modify.
        lr: The learning rate to assign to the `model`.
        weight_decay: The weight decay parameter, this is only relevant when using `tfa.DecoupledWeightDecayExtension`.
            If None, the current weight decay is rescaled proportionally to the change in learning rate.

    Raises:
        AssertionError: If `model` was not generated by `fe.build`.
        ValueError: If `model` is an unacceptable data type.
    """
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "set_lr only accept models from fe.build"
    if isinstance(model, tf.keras.Model):
        optimizer = model.current_optimizer
        # when using decoupled weight decay like SGDW or AdamW, weight decay factor needs to change together with lr
        # see https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/DecoupledWeightDecayExtension for detail
        if isinstance(optimizer, tfa.optimizers.DecoupledWeightDecayExtension):
            wd_optimizer = optimizer
        elif hasattr(optimizer, "inner_optimizer") and isinstance(optimizer.inner_optimizer,
                                                                  tfa.optimizers.DecoupledWeightDecayExtension):
            # wrapper optimizer exposing `inner_optimizer`: operate on the object that actually owns `weight_decay`
            wd_optimizer = optimizer.inner_optimizer
        else:
            wd_optimizer = None
        if wd_optimizer is not None:
            if weight_decay is None:
                # must be read before the new lr is written below
                current_lr = get_lr(model)
                current_wd = tf.keras.backend.get_value(wd_optimizer.weight_decay)
                if current_lr != 0:
                    # remember wd/lr so the ratio can be restored after the lr has been set to exactly 0
                    wd_optimizer._fe_wd_lr_ratio = current_wd / current_lr
                    weight_decay = current_wd * lr / current_lr
                else:
                    # the ratio cannot be recovered from a zero lr; fall back to the last known ratio (if any)
                    ratio = getattr(wd_optimizer, "_fe_wd_lr_ratio", None)
                    if ratio is not None:
                        weight_decay = ratio * lr
            elif lr != 0:
                # an explicit weight decay defines the new ratio for future proportional updates
                wd_optimizer._fe_wd_lr_ratio = weight_decay / lr
            # weight_decay stays None only if lr was 0 and no ratio is known: leave the decay untouched
            # rather than writing NaN/inf into it
            if weight_decay is not None:
                tf.keras.backend.set_value(wd_optimizer.weight_decay, weight_decay)
        tf.keras.backend.set_value(optimizer.lr, lr)
    elif isinstance(model, torch.nn.Module):
        for param_group in model.current_optimizer.param_groups:
            param_group['lr'] = lr
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
