#!/usr/bin/env python
# ******************************************************************************
# Copyright 2022 Brainchip Holdings Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
"""
QuantizedSeparableConv2D layer definition.
"""

import tensorflow as tf
from keras.layers import SeparableConv2D
from keras.utils import conv_utils
from keras import backend

from ..tensors import FixedPoint, MAX_BUFFER_BITWIDTH, QFloat
from .layers import deserialize_quant_object, Calibrable, CalibrableVariable


# Names exported by this module when it is star-imported.
__all__ = ["QuantizedSeparableConv2D", "SeparableConv2DTranspose",
           "QuantizedSeparableConv2DTranspose"]


@tf.keras.utils.register_keras_serializable()
class QuantizedSeparableConv2D(Calibrable, SeparableConv2D):
    """ A separable convolutional layer that operates on quantized inputs and weights.

    The layer chains a depthwise convolution and a pointwise convolution (unit strides, 'VALID'
    padding). The depthwise and pointwise kernels are quantized by two distinct quantizers, an
    optional intermediate quantizer can be applied between the two operations and an optional
    output quantizer can be applied to the result.

    Args:
        *args: positional arguments forwarded to the SeparableConv2D constructor.
        quant_config (dict, optional): quantization configuration. The entries read by this
            layer are "output_quantizer", "dw_weight_quantizer", "pw_weight_quantizer",
            "bias_quantizer" (only when the layer uses a bias), "intermediate_quantizer" and
            "buffer_bitwidth". Defaults to None, which is handled as an empty configuration.
        **kwargs: keyword arguments forwarded to the SeparableConv2D constructor.

    Raises:
        ValueError: if the 'dilation_rate' keyword argument is given and is not 1, if the
            'depth_multiplier' keyword argument is given and is not 1, or if the intermediate
            quantizer is not 'per-tensor'.
    """

    def __init__(self, *args, quant_config=None, **kwargs):
        # Messages are built from adjacent literals so that they do not embed the source
        # indentation (which a backslash continuation inside a literal would do).
        if 'dilation_rate' in kwargs:
            if kwargs['dilation_rate'] not in [1, [1, 1], (1, 1)]:
                raise ValueError("Keyword argument 'dilation_rate' is not supported in "
                                 "QuantizedSeparableConv2D.")
        if 'depth_multiplier' in kwargs:
            if kwargs['depth_multiplier'] != 1:
                raise ValueError("Keyword argument 'depth_multiplier' is not supported in "
                                 "QuantizedSeparableConv2D.")

        super().__init__(*args, **kwargs)
        # A None sentinel is used instead of a mutable default so that layers created without
        # a configuration never share the same dictionary.
        self.quant_config = {} if quant_config is None else quant_config

        # The output quantizer may be None: call() skips it in that case
        self.out_quantizer = deserialize_quant_object(
            self.quant_config, "output_quantizer", False)

        # Separable layer has two weights quantizers to handle different max values
        self.dw_weight_quantizer = deserialize_quant_object(
            self.quant_config, "dw_weight_quantizer", True)
        self.pw_weight_quantizer = deserialize_quant_object(
            self.quant_config, "pw_weight_quantizer", True)

        if self.use_bias:
            self.bias_quantizer = deserialize_quant_object(
                self.quant_config, "bias_quantizer", True)

        # NOTE(review): one bit is removed from the configured bitwidth, presumably to account
        # for the sign bit - confirm against FixedPoint.promote.
        self.buffer_bitwidth = self.quant_config.get("buffer_bitwidth", MAX_BUFFER_BITWIDTH) - 1
        assert self.buffer_bitwidth > 0, "The buffer_bitwidth must be a strictly positive integer."
        # The intermediate quantizer may be None: call() skips it in that case
        self.intermediate_quantizer = deserialize_quant_object(
            self.quant_config, "intermediate_quantizer", False)

        # Same None test as in call(), so that a quantizer that is applied is always validated
        if self.intermediate_quantizer is not None:
            intermediate_quantizer_axis = self.intermediate_quantizer.get_config()['axis']
            if intermediate_quantizer_axis != "per-tensor":
                raise ValueError("Only supporting 'per-tensor' intermediate quantizer. Received "
                                 f"{intermediate_quantizer_axis}.")
        # Add objects that will store the shift values.
        self.input_shift = CalibrableVariable()
        if self.use_bias:
            # Both shifts are only recorded by the bias alignment in call()
            self.bias_shift = CalibrableVariable()
            self.output_shift = CalibrableVariable()

    def call(self, inputs, training=None):
        """Applies the quantized depthwise then pointwise convolutions to the inputs.

        Args:
            inputs (FixedPoint or tf.Tensor): the layer inputs. A tf.Tensor is converted with
                FixedPoint.quantize before being processed.
            training: forwarded to every quantizer call. Defaults to None.

        Returns:
            the outputs of the pointwise convolution, with the quantized bias added when the
            layer uses a bias, passed through the output quantizer when there is one.

        Raises:
            TypeError: if the inputs are neither a FixedPoint nor a tf.Tensor.
        """
        # raise an error if the inputs are not FixedPoint or tf.Tensor
        if not isinstance(inputs, (FixedPoint, tf.Tensor)):
            raise TypeError("QuantizedSeparableConv2D only accepts FixedPoint or tf.Tensor "
                            f"inputs. Receives {type(inputs)} inputs.")

        # Quantize the weights
        depthwise_kernel = self.dw_weight_quantizer(self.depthwise_kernel, training)
        pointwise_kernel = self.pw_weight_quantizer(self.pointwise_kernel, training)

        if isinstance(inputs, tf.Tensor):
            # Assume the inputs are integer stored as float, which is the only tf.Tensor
            # inputs that are allowed
            inputs = FixedPoint.quantize(inputs, 0, 8)

        # Promote the inputs to the buffer bitwidth, align them and record the applied shift
        inputs, shift = inputs.promote(self.buffer_bitwidth).align()
        self.input_shift(shift)
        dw_outputs_q = backend.depthwise_conv2d(
            inputs,
            depthwise_kernel,
            strides=self.strides,
            padding=self.padding,
            dilation_rate=self.dilation_rate,
            data_format=self.data_format)

        if self.intermediate_quantizer is not None:
            # Re-quantize the depthwise outputs, then promote them again for the pointwise step
            dw_outputs_q = self.intermediate_quantizer(dw_outputs_q, training)
            dw_outputs_q = dw_outputs_q.promote(self.buffer_bitwidth)

        # Pointwise operation
        outputs = tf.nn.convolution(
            dw_outputs_q,
            pointwise_kernel,
            strides=[1, 1, 1, 1],
            padding='VALID',
            data_format=conv_utils.convert_data_format(self.data_format, ndim=4))

        if self.use_bias:
            bias = self.bias_quantizer(self.bias, training).promote(self.buffer_bitwidth)

            # Align intermediate outputs and biases before adding them
            outputs, shift = outputs.align(bias)
            self.output_shift(shift)
            bias, shift = bias.align(outputs)
            self.bias_shift(shift)
            outputs = tf.add(outputs, bias)

        if self.out_quantizer is not None:
            outputs = self.out_quantizer(outputs, training)
        return outputs

    def get_config(self):
        """Returns the layer configuration, extended with its quantization configuration.

        Returns:
            dict: the parent configuration with an additional "quant_config" entry.
        """
        config = super().get_config()
        config["quant_config"] = self.quant_config
        return config


@tf.keras.utils.register_keras_serializable()
class SeparableConv2DTranspose(SeparableConv2D):
    """ A transposed separable convolutional layer.

    A transposed depthwise convolution is applied to the inputs, one channel at a time, and is
    followed by a standard pointwise operation.

    The 'dilation_rate', 'depth_multiplier', 'strides' and 'padding' keyword arguments are
    validated when they are given: accepted values are respectively 1, 1, 2 and 'same'.

    Args:
        *args: positional arguments forwarded to the SeparableConv2D constructor.
        **kwargs: keyword arguments forwarded to the SeparableConv2D constructor.

    Raises:
        ValueError: if one of the validated keyword arguments is given with another value than
            the accepted one.
    """

    def __init__(self, *args, **kwargs):
        # Each lookup falls back on an accepted value, so that a keyword argument that was not
        # given never triggers an error.
        dilation_rate = kwargs.get('dilation_rate', 1)
        if dilation_rate not in [1, [1, 1], (1, 1)]:
            raise ValueError("Keyword argument 'dilation_rate' is not supported in "
                             "SeparableConv2DTranspose.")

        depth_multiplier = kwargs.get('depth_multiplier', 1)
        if depth_multiplier != 1:
            raise ValueError("Keyword argument 'depth_multiplier' is not supported in "
                             "SeparableConv2DTranspose.")

        # Only a stride of 2 is accepted: the standard separable layer covers stride 1 and
        # greater strides are not supported.
        strides = kwargs.get('strides', 2)
        if strides not in [2, [2, 2], (2, 2)]:
            raise ValueError(f"Only supported stride is 2. Received {strides}.")

        # Likewise, 'same' is the only accepted padding
        padding = kwargs.get('padding', 'same')
        if padding != 'same':
            raise ValueError(f"Only supported padding is 'same'. Received {padding}.")

        super().__init__(*args, **kwargs)

    def call(self, inputs):
        """Applies the transposed depthwise convolution then the pointwise convolution.

        Args:
            inputs: the layer inputs, handled as a rank-4 tensor whose last axis holds the
                channels.

        Returns:
            the outputs of the pointwise convolution, with the bias added when the layer uses
            a bias.
        """
        # Dynamic output shape of a single channel: height first, then width
        dynamic_shape = tf.shape(inputs)
        out_sizes = [
            conv_utils.deconv_output_length(dynamic_shape[axis + 1],
                                            self.kernel_size[axis],
                                            padding=self.padding,
                                            stride=self.strides[axis],
                                            dilation=self.dilation_rate[axis])
            for axis in range(2)]
        channel_output_shape = tf.stack((dynamic_shape[0], out_sizes[0], out_sizes[1], 1))

        # tf.vectorized_map unpacks its elements on dimension 0, so channels are moved first for
        # both inputs and kernels. expand_dims then restores a channel dimension of size 1, which
        # is what conv2d_transpose expects when processing a single channel.
        per_channel_inputs = tf.expand_dims(tf.transpose(inputs, (3, 0, 1, 2)), -1)
        per_channel_kernels = tf.expand_dims(
            tf.transpose(self.depthwise_kernel, (2, 0, 1, 3)), -2)

        def transpose_single_channel(elems):
            # elems holds the (inputs, kernel) pair of one channel
            return backend.conv2d_transpose(elems[0],
                                            elems[1],
                                            output_shape=channel_output_shape,
                                            strides=self.strides,
                                            padding=self.padding)

        stacked = tf.vectorized_map(transpose_single_channel,
                                    (per_channel_inputs, per_channel_kernels))

        # Drop the unit channel dimension and move the channels back to the last axis
        stacked = tf.squeeze(stacked, axis=-1)
        depthwise_outputs = tf.transpose(stacked, (1, 2, 3, 0))

        # Pointwise operation
        pointwise_outputs = tf.nn.convolution(
            depthwise_outputs,
            self.pointwise_kernel,
            strides=[1, 1, 1, 1],
            padding='VALID')

        if not self.use_bias:
            return pointwise_outputs
        return tf.add(pointwise_outputs, self.bias)


@tf.keras.utils.register_keras_serializable()
class QuantizedSeparableConv2DTranspose(Calibrable, SeparableConv2DTranspose):
    """ A transposed separable convolutional layer that operates on quantized inputs and weights.

    A transposed depthwise convolution is applied on the integer values of the inputs with the
    quantized depthwise kernel, one channel at a time. The result is wrapped in a FixedPoint,
    optionally re-quantized by an intermediate quantizer, then goes through a pointwise
    convolution with the quantized pointwise kernel. An optional output quantizer can be applied
    to the result.

    Args:
        *args: positional arguments forwarded to the SeparableConv2DTranspose constructor.
        quant_config (dict, optional): quantization configuration. The entries read by this
            layer are "output_quantizer", "dw_weight_quantizer", "pw_weight_quantizer",
            "bias_quantizer" (only when the layer uses a bias), "intermediate_quantizer" and
            "buffer_bitwidth". Defaults to None, which is handled as an empty configuration.
        **kwargs: keyword arguments forwarded to the SeparableConv2DTranspose constructor.

    Raises:
        ValueError: if the "dw_weight_quantizer" entry defines an "axis" that is not
            'per-tensor', or if the intermediate quantizer is not 'per-tensor'.
    """

    def __init__(self, *args, quant_config=None, **kwargs):
        super().__init__(*args, **kwargs)
        # A None sentinel is used instead of a mutable default so that layers created without
        # a configuration never share the same dictionary.
        self.quant_config = {} if quant_config is None else quant_config

        # The output quantizer may be None: call() skips it in that case
        self.out_quantizer = deserialize_quant_object(self.quant_config, "output_quantizer", False)

        # Transposed Separable layer has two weights quantizers to handle different max values
        self.dw_weight_quantizer = deserialize_quant_object(
            self.quant_config, "dw_weight_quantizer", True)
        self.pw_weight_quantizer = deserialize_quant_object(
            self.quant_config, "pw_weight_quantizer", True)

        # Depthwise quantizer must be per-tensor. The check is made on the raw configuration
        # entry and only applies when that entry explicitly defines an "axis".
        if "axis" in self.quant_config["dw_weight_quantizer"]:
            if self.quant_config["dw_weight_quantizer"]["axis"] != "per-tensor":
                raise ValueError("Only supporting 'per-tensor' depthwise quantizer. Received "
                                 f"{self.quant_config['dw_weight_quantizer']['axis']}.")

        if self.use_bias:
            self.bias_quantizer = deserialize_quant_object(
                self.quant_config, "bias_quantizer", True)

        # NOTE(review): one bit is removed from the configured bitwidth, presumably to account
        # for the sign bit - confirm against FixedPoint.promote.
        self.buffer_bitwidth = self.quant_config.get("buffer_bitwidth", MAX_BUFFER_BITWIDTH) - 1
        assert self.buffer_bitwidth > 0, "The buffer_bitwidth must be a strictly positive integer."
        # The intermediate quantizer may be None: call() skips it in that case
        self.intermediate_quantizer = deserialize_quant_object(
            self.quant_config, "intermediate_quantizer", False)

        # Same None test as in call(), so that a quantizer that is applied is always validated
        if self.intermediate_quantizer is not None:
            intermediate_quantizer_axis = self.intermediate_quantizer.get_config()['axis']
            if intermediate_quantizer_axis != "per-tensor":
                raise ValueError("Only supporting 'per-tensor' intermediate quantizer. Received "
                                 f"{intermediate_quantizer_axis}.")

        # Add objects that will store the shift values.
        self.input_shift = CalibrableVariable()
        self.pw_shift = CalibrableVariable()
        self.bias_shift = CalibrableVariable()

    def call(self, inputs, training=None):
        """Applies the quantized transposed depthwise then pointwise convolutions to the inputs.

        Args:
            inputs (FixedPoint or tf.Tensor): the layer inputs, handled as rank-4 with the
                channels on the last axis. A tf.Tensor is converted with FixedPoint.quantize
                before being processed.
            training: forwarded to every quantizer call. Defaults to None.

        Returns:
            the outputs of the pointwise convolution, with the quantized bias added when the
            layer uses a bias, passed through the output quantizer when there is one.

        Raises:
            TypeError: if the inputs are neither a FixedPoint nor a tf.Tensor.
        """
        # raise an error if the inputs are not FixedPoint or tf.Tensor
        if not isinstance(inputs, (FixedPoint, tf.Tensor)):
            raise TypeError("QuantizedSeparableConv2DTranspose only accepts FixedPoint or "
                            f"tf.Tensor inputs. Receives {type(inputs)} inputs.")

        if isinstance(inputs, tf.Tensor):
            # Assume the inputs are integer stored as float, which is the only tf.Tensor
            # inputs that are allowed
            inputs = FixedPoint.quantize(inputs, 0, 8)

        # Promote inputs to avoid saturation, then align them
        inputs, shift = inputs.promote(self.buffer_bitwidth).align()
        self.input_shift(shift)

        # Infer the dynamic output shape
        inputs_shape = tf.shape(inputs)
        out_height = conv_utils.deconv_output_length(inputs_shape[1],
                                                     self.kernel_size[0],
                                                     padding=self.padding,
                                                     stride=self.strides[0],
                                                     dilation=self.dilation_rate[0])
        out_width = conv_utils.deconv_output_length(inputs_shape[2],
                                                    self.kernel_size[1],
                                                    padding=self.padding,
                                                    stride=self.strides[1],
                                                    dilation=self.dilation_rate[1])
        # The last dimension is 1 because channels are processed one at a time
        output_shape = tf.stack((inputs_shape[0], out_height, out_width, 1))

        # Inputs and kernels must be transposed to have their channel dimension first because the
        # tf.vectorized_map call that follows will unpack them on dimension 0. The channel dimension
        # is virtually restored using expand_dims so that elements have the appropriate shape for
        # the conv2d_transpose call (with a channel dimension of 1 which is expected in the
        # depthwise process).
        inputs_channel_first = tf.transpose(inputs, (3, 0, 1, 2))
        inputs_channel_first = tf.expand_dims(inputs_channel_first, -1)
        kernel_channel_first = tf.transpose(self.depthwise_kernel, (2, 0, 1, 3))
        kernel_channel_first = tf.expand_dims(kernel_channel_first, -2)

        # Quantize the depthwise kernels
        depthwise_kernel = self.dw_weight_quantizer(kernel_channel_first, training)

        # Perform the depthwise operation on values using conv2d_transpose on each channel
        dw_values = tf.vectorized_map(
            lambda x: backend.conv2d_transpose(x[0],
                                               x[1],
                                               output_shape=output_shape,
                                               strides=self.strides,
                                               padding=self.padding),
            (inputs_channel_first.values, depthwise_kernel.values))
        # Drop the unit channel dimension and move the channels back to the last axis
        dw_values = tf.transpose(tf.squeeze(dw_values, axis=-1), (1, 2, 3, 0))

        if isinstance(depthwise_kernel, QFloat):
            # Multiply by the scale
            dw_values *= depthwise_kernel.scales.values

        # Build a new FixedPoint: its fractional bits are the sum of the inputs fractional bits
        # and of the kernel (or kernel scales) fractional bits
        if isinstance(depthwise_kernel, FixedPoint):
            filters_frac_bits = depthwise_kernel.frac_bits
        else:
            filters_frac_bits = depthwise_kernel.scales.frac_bits
        dw_outputs = FixedPoint(dw_values, inputs.frac_bits + filters_frac_bits, inputs.value_bits)

        if self.intermediate_quantizer is not None:
            # Re-quantize the depthwise outputs, then promote them again for the pointwise step
            dw_outputs_q = self.intermediate_quantizer(dw_outputs, training)
            dw_outputs_q = dw_outputs_q.promote(self.buffer_bitwidth)
        else:
            dw_outputs_q = dw_outputs

        # Quantize the pointwise kernel
        pointwise_kernel = self.pw_weight_quantizer(self.pointwise_kernel, training)

        # Pointwise operation
        outputs = tf.nn.convolution(
            dw_outputs_q,
            pointwise_kernel,
            strides=[1, 1, 1, 1],
            padding='VALID')

        if self.use_bias:
            # Quantize biases
            bias = self.bias_quantizer(self.bias, training)
            # Align intermediate outputs and biases before adding them
            outputs, shift = outputs.align(bias)
            self.pw_shift(shift)
            bias, shift = bias.promote(self.buffer_bitwidth).align(outputs)
            self.bias_shift(shift)
            outputs = tf.add(outputs, bias)

        if self.out_quantizer is not None:
            outputs = self.out_quantizer(outputs, training)
        return outputs

    def get_config(self):
        """Returns the layer configuration, extended with its quantization configuration.

        Returns:
            dict: the parent configuration with an additional "quant_config" entry.
        """
        config = super().get_config()
        config["quant_config"] = self.quant_config
        return config
