#!/usr/bin/env python
# ******************************************************************************
# Copyright 2023 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
Layer decorators.
"""

import functools
import inspect

import tensorflow as tf

from .quantization_params import get_quantization_params
from .quantizers import WeightQuantizer, AlignedWeightQuantizer, OutputQuantizer
from .recorders import TensorRecorder
from ..tensors import FixedPoint, MAX_BUFFER_BITWIDTH, BASE_BUFFER_BITWIDTH

# Registry mapping a float layer class name (its `__name__`) to the quantized layer class
# that replaces it. Entries are added by the `register_quantize_target` decorator.
_GLOBAL_LAYER_TO_QLAYER = dict()

# Quantized layer classes that cannot rescale their outputs, i.e. that have no output
# quantizer. Entries are added by the `register_no_output_quantizer` decorator.
_GLOBAL_NO_OUTPUT_QUANTIZER = list()

# Quantized layer classes whose inputs must be aligned. Entries are added by the
# `register_aligned_inputs` decorator.
_GLOBAL_ALIGNED_INPUTS = list()


def register_quantize_target(base_layer):
    """Register the decorated class as the quantization target of one or more keras layers.

    This decorator injects the decorated class into the _GLOBAL_LAYER_TO_QLAYER dictionary, keyed
    by the class name of each provided `base_layer`, so that it is registered as the quantization
    target for those layers.

    Registration is all-or-nothing: every base layer is checked before the registry is updated,
    so a rejected registration leaves _GLOBAL_LAYER_TO_QLAYER unchanged.

    Args:
        base_layer (keras.Layer, list, tuple): the origin layer class (or list/tuple of classes)
            that should be quantized as the decorated class

    Returns:
        Callable: a decorator that registers the decorated class and returns it unchanged

    Raises:
        ValueError: if the decorated object is not a class, or if a base layer name is already
            registered (or appears twice in `base_layer`).
    """
    targets = list(base_layer) if isinstance(base_layer, (list, tuple)) else [base_layer]

    def decorator(arg):
        if not inspect.isclass(arg):
            raise ValueError("Can only register class objects with 'register_quantize_target'.")

        # Validate every target before touching the global registry so that a failure on one
        # target does not leave the previous ones registered.
        new_entries = {}
        for target in targets:
            base_class_name = target.__name__
            for registry in (_GLOBAL_LAYER_TO_QLAYER, new_entries):
                if base_class_name in registry:
                    raise ValueError(f"{base_class_name} has already been registered to "
                                     f"{registry[base_class_name]}.")
            new_entries[base_class_name] = arg

        _GLOBAL_LAYER_TO_QLAYER.update(new_entries)
        return arg

    return decorator


def register_no_output_quantizer(arg):
    """Register the decorated class as not able to rescale its outputs.

    The class is appended to _GLOBAL_NO_OUTPUT_QUANTIZER. A class is recorded at most once, so
    registering it again has no further effect.

    Args:
        arg (Cls): the class to register

    Returns:
        Cls: the class itself, unchanged, so this function can be used directly as a class
        decorator

    Raises:
        ValueError: if `arg` is not a class.
    """
    if not inspect.isclass(arg):
        raise ValueError("Can only register class objects with 'register_no_output_quantizer'.")
    # Avoid duplicate entries if the same class is registered more than once
    if arg not in _GLOBAL_NO_OUTPUT_QUANTIZER:
        _GLOBAL_NO_OUTPUT_QUANTIZER.append(arg)
    return arg


def register_aligned_inputs(arg):
    """Register the decorated class as requiring aligned inputs.

    The class is appended to _GLOBAL_ALIGNED_INPUTS. A class is recorded at most once, so
    registering it again has no further effect.

    Args:
        arg (Cls): the class to register

    Returns:
        Cls: the class itself, unchanged, so this function can be used directly as a class
        decorator

    Raises:
        ValueError: if `arg` is not a class.
    """
    if not inspect.isclass(arg):
        raise ValueError("Can only register class objects with 'register_aligned_inputs'.")
    # Avoid duplicate entries if the same class is registered more than once
    if arg not in _GLOBAL_ALIGNED_INPUTS:
        _GLOBAL_ALIGNED_INPUTS.append(arg)
    return arg


def rescale_outputs(call):
    """Decorator to rescale the outputs produced by a layer 'call' function.

    When the layer's `out_quantizer` attribute is not None, the outputs of `call` are passed
    through it before being returned; otherwise they are returned unchanged.

    Args:
        call (Callable): the decorated call function, called as `call(self, inputs)`

    Returns:
        Callable: the decorated function
    """
    # functools.wraps keeps the original call's __name__, __qualname__ and __doc__
    @functools.wraps(call)
    def decorator(self, inputs):
        outputs = call(self, inputs)
        if self.out_quantizer is not None:
            outputs = self.out_quantizer(outputs)
        return outputs
    return decorator


def tensor_inputs(supported):
    """ Decorator to check the input tensors passed to a layer `call`.

    The decorated `call(self, inputs)`:

    - raises a TypeError if `inputs` is not an instance of one of the `supported` types,
    - converts a tf.Tensor input with `FixedPoint(inputs, 8, 0)` (such inputs are assumed to be
      8-bit integers stored as float),
    - when the layer has a non-None `buffer_bitwidth`, may raise it to MAX_BUFFER_BITWIDTH - 1,
      then either expands FixedPoint inputs that are not per-tensor (recording the shift in
      `self.input_shift`; skipped for QuantizedDepthwiseConv2D) or promotes the inputs to
      `buffer_bitwidth`.

    Args:
        supported (list, tuple): list of supported input types

    Returns:
        Callable: a decorator to apply to a layer `call` function

    Raises:
        TypeError: when decorating, if `supported` is not a list or a tuple.
    """
    def _get_weight_quantizer_bitwidth(obj):
        # Bitwidth of the first WeightQuantizer found among `obj` attributes whose name ends
        # with 'quantizer'. Some layers only have an AlignedWeightQuantizer, which is used as
        # a fallback. Returns None when neither is found.
        quantizer_attr = [getattr(obj, name, None) for name in dir(obj)
                          if name.endswith('quantizer')]
        for quantizer_type in (WeightQuantizer, AlignedWeightQuantizer):
            for quant in quantizer_attr:
                if isinstance(quant, quantizer_type):
                    return quant.bitwidth
        return None

    def decorator(call):
        if not isinstance(supported, (list, tuple)):
            raise TypeError(f"'supported' must be a list or a tuple, received {type(supported)}.")
        # isinstance() needs a tuple of types: build it once here rather than on every call
        supported_types = tuple(supported)

        @functools.wraps(call)
        def check_inputs(self, inputs):
            # Raise an error if the inputs are not in the 'supported' types
            if not isinstance(inputs, supported_types):
                raise TypeError(f"{self.__class__.__name__} only accepts {supported} inputs. "
                                f"Receives {type(inputs)} inputs.")

            if isinstance(inputs, tf.Tensor):
                # Assume the inputs are 8 bit integer stored as float, which is the only tf.Tensor
                # inputs that are allowed
                inputs = FixedPoint(inputs, 8, 0)

            if getattr(self, 'buffer_bitwidth', None) is not None:
                # Adapt buffer_bitwidth: excluding the effect of accumulations (sums or negations)
                # because the training regularization tends to center them around zero, worst case
                # is: input_bits * 2 (with expand) + weights_bits + scale_bits,
                # so as long as inputs or weights are not 4bits, buffer_bitwidth must be
                # MAX_BUFFER_BITWIDTH - 1.
                if self.buffer_bitwidth != MAX_BUFFER_BITWIDTH - 1:
                    weight_bits = _get_weight_quantizer_bitwidth(self)
                    if weight_bits and (inputs.value_bits > 4 or weight_bits > 4):
                        self.buffer_bitwidth = MAX_BUFFER_BITWIDTH - 1

                if (isinstance(inputs, FixedPoint) and not inputs.per_tensor
                        and self.__class__.__name__ != 'QuantizedDepthwiseConv2D'):
                    # Expand the inputs to a higher bitwidth to avoid saturation and align them.
                    # Depthwise layers do not require input_shift because each channel is
                    # handled by a single filter.
                    inputs, shift = inputs.expand(self.buffer_bitwidth)
                    if getattr(self, 'input_shift', None) is None:
                        # Add object that will store the shift values.
                        # From Keras documentation, any variable creation taking place in call
                        # should be wrapped with tf.init_scope
                        with tf.init_scope():
                            self.input_shift = TensorRecorder(name="input_shift")
                    self.input_shift(shift)
                else:
                    # Promote inputs to avoid a saturation
                    inputs = inputs.promote(self.buffer_bitwidth)
            return call(self, inputs)
        return check_inputs
    return decorator


def neural_layer_init(separable):
    """ Decorator to initialize a neural layer.

    The decorated `__init__` gains a keyword-only `quant_config` argument. The wrapper:

    1. calls the parent class `__init__` with the same arguments, except `padding_value`,
    2. builds the weight quantizer(s) from `quant_config` (their bitwidth defaults to
       `get_quantization_params().weight_bits`), the optional output quantizer and, when
       `self.use_bias` is set, the bias quantizer; the resulting configuration is stored in
       `self.quant_config`,
    3. sets `self.buffer_bitwidth` to BASE_BUFFER_BITWIDTH - 1,
    4. calls the decorated `__init__` with the original arguments (without `quant_config`).

    The wrapper works on copies, so the dict passed as `quant_config` is not modified by it.

    Args:
        separable (bool): True if the layer has separable weights. It then gets
            `dw_weight_quantizer` (per-tensor) and `pw_weight_quantizer` instead of
            `weight_quantizer`.

    Returns:
        Callable: the decorated function
    """
    def decorator(init):
        def wrapper(self, *args, quant_config=None, **kwargs):
            # First call the parent __init__. The parent is resolved from the class that owns
            # this wrapper: `super(type(self), self)` would resolve back to this very wrapper
            # for any subclass of a decorated layer and recurse forever. Fall back to
            # type(self) if the wrapper cannot be found in the MRO.
            owner = next((cls for cls in type(self).__mro__
                          if cls.__dict__.get("__init__") is wrapper), type(self))
            super_init = super(owner, self).__init__
            # Handle special parameter "padding_value" that must not be passed to super()
            updated_kwargs = kwargs.copy()
            updated_kwargs.pop('padding_value', None)
            super_init(*args, **updated_kwargs)

            # Then start neural layer init. Work on a copy so the defaults filled in below
            # never leak into the caller's dict, which may be shared or reused.
            quant_config = dict(quant_config) if quant_config else dict()
            self.quant_config = quant_config
            default_weight_bits = get_quantization_params().weight_bits

            # Use quant_config to build quantizers
            if separable:
                # Separable layer has two weights quantizers to handle different max values.
                # Copy the depthwise config before updating it so the caller's one is untouched.
                dw_weight_quantizer_cfg = dict(quant_config.get("dw_weight_quantizer",
                                                                {"bitwidth": default_weight_bits}))
                # Separable depthwise weights are quantized per tensor
                dw_weight_quantizer_cfg.update({'axis': "per-tensor"})
                self.quant_config['dw_weight_quantizer'] = dw_weight_quantizer_cfg
                self.dw_weight_quantizer = WeightQuantizer(name="dw_weight_quantizer",
                                                           **dw_weight_quantizer_cfg)
                pw_weight_quantizer_cfg = quant_config.get("pw_weight_quantizer",
                                                           {"bitwidth": default_weight_bits})
                self.quant_config['pw_weight_quantizer'] = pw_weight_quantizer_cfg
                self.pw_weight_quantizer = WeightQuantizer(name="pw_weight_quantizer",
                                                           **pw_weight_quantizer_cfg)
            else:
                weight_quantizer_cfg = quant_config.get("weight_quantizer",
                                                        {"bitwidth": default_weight_bits})
                self.quant_config['weight_quantizer'] = weight_quantizer_cfg
                self.weight_quantizer = WeightQuantizer(name="weight_quantizer",
                                                        **weight_quantizer_cfg)

            # Finalize output and bias quantizers. A missing or falsy "output_quantizer"
            # entry means the layer has no output quantizer.
            out_quant_cfg = quant_config.get("output_quantizer", False)
            if out_quant_cfg:
                self.out_quantizer = OutputQuantizer(name="output_quantizer", **out_quant_cfg)
            else:
                self.out_quantizer = None
            if self.use_bias:
                bias_quantizer_cfg = quant_config.get("bias_quantizer", {})
                self.bias_quantizer = AlignedWeightQuantizer(name="bias_quantizer",
                                                             **bias_quantizer_cfg)
            self.buffer_bitwidth = BASE_BUFFER_BITWIDTH - 1

            # Baseline init
            init(self, *args, **kwargs)
        return wrapper
    return decorator
