"""weight normalization layer from tensorflow addons
code is hightly borrowed from
https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/layers/wrappers.py

Differences:
    initialization between multi gpu
"""
from __future__ import absolute_import, division, print_function

import logging

import tensorflow as tf


class WeightNormalization(tf.keras.layers.Wrapper):
    """Reparameterize a layer by decoupling its weight magnitude and direction.

    The wrapped layer's kernel (the recurrent kernel for RNN layers) is
    recomputed on every call as ``g * v / ||v||``, where ``v`` is the
    original kernel variable and ``g`` is a per-output-unit scale created by
    this wrapper.

    Note:
        This speeds up convergence by improving the
        conditioning of the optimization problem.
        Weight Normalization: A Simple Reparameterization to Accelerate
        Training of Deep Neural Networks: https://arxiv.org/abs/1602.07868
        Tim Salimans, Diederik P. Kingma (2016)
        WeightNormalization wrapper works for keras and tf layers.

    Examples:

        >>> net = WeightNormalization(
        ...     tf.keras.layers.Conv2D(2, 2, activation='relu'),
        ...     input_shape=(32, 32, 3),
        ...     data_init=True)(x)
        >>> net = WeightNormalization(
        ...     tf.keras.layers.Conv2D(16, 5, activation='relu'),
        ...     data_init=True)(net)
        >>> net = WeightNormalization(
        ...     tf.keras.layers.Dense(120, activation='relu'),
        ...     data_init=True)(net)
        >>> net = WeightNormalization(
        ...     tf.keras.layers.Dense(n_classes),
        ...     data_init=True)(net)

    Args:
      layer (tf.keras.layers.Layer): a layer instance.
      data_init (bool): If `True` use data dependent variable initialization.
      **kwargs: forwarded unchanged to the `tf.keras.layers.Wrapper`
        constructor.

    Raises:
      ValueError: At build time, if the wrapped layer (or the cell of a
        wrapped RNN) does not have a `kernel` attribute.
    """

    def __init__(
        self, layer: tf.keras.layers.Layer, data_init: bool = True, **kwargs
    ):
        super().__init__(layer, **kwargs)
        self.data_init = data_init
        self._track_trackable(layer, name="layer")
        self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)

        if self.data_init and self.is_rnn:
            logging.warning(
                "WeightNormalization: Using `data_init=True` with RNNs "
                "is advised against by the paper. Use `data_init=False`."
            )

    def build(self, input_shape):
        """Build the wrapped layer and create the variables of this wrapper.

        Creates the scale ``g``, aliases the wrapped kernel as ``v``, creates
        the boolean ``_initialized`` flag and, when ``data_init`` is set, an
        untrainable clone of the wrapped layer used for the data-dependent
        initialization.

        Args:
          input_shape: shape of the input, convertible to `tf.TensorShape`.

        Raises:
          ValueError: If the wrapped layer has no `kernel` attribute.
        """
        input_shape = tf.TensorShape(input_shape)
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if not self.layer.built:
            self.layer.build(input_shape)

        # For RNN layers the weights live on the cell, not on the layer.
        kernel_layer = self.layer.cell if self.is_rnn else self.layer

        if not hasattr(kernel_layer, "kernel"):
            raise ValueError(
                "`WeightNormalization` must wrap a layer that"
                " contains a `kernel` for weights"
            )

        # For RNNs it is the recurrent kernel that gets normalized.
        if self.is_rnn:
            kernel = kernel_layer.recurrent_kernel
        else:
            kernel = kernel_layer.kernel

        # The kernel's filter or unit dimension is -1
        self.layer_depth = int(kernel.shape[-1])
        # Norms are taken over every kernel axis except the last one.
        self.kernel_norm_axes = list(range(kernel.shape.rank - 1))

        self.g = self.add_weight(
            name="g",
            shape=(self.layer_depth,),
            initializer="ones",
            dtype=kernel.dtype,
            trainable=True,
        )
        self.v = kernel

        # Scalar flag that `call` flips to True once `g` has been initialized.
        self._initialized = self.add_weight(
            name="initialized",
            shape=None,
            initializer="zeros",
            dtype=tf.dtypes.bool,
            trainable=False,
        )

        if self.data_init:
            # Used for data initialization in self._data_dep_init.
            with tf.name_scope("data_dep_init"):
                layer_config = tf.keras.layers.serialize(self.layer)
                layer_config["config"]["trainable"] = False

                self._naked_clone_layer = tf.keras.layers.deserialize(layer_config)
                self._naked_clone_layer.build(input_shape)
                self._naked_clone_layer.set_weights(self.layer.get_weights())
                # The statistics are taken on the clone's output with its
                # activation removed (non-RNN layers only).
                if not self.is_rnn:
                    self._naked_clone_layer.activation = None

        self.built = True

    def call(self, inputs):
        """Apply the wrapped layer using the weight-normalized kernel.

        While ``_initialized`` is False, ``g`` is initialized first (see
        `_initialize_weights`); afterwards it is read as is.

        Args:
          inputs: input tensor(s) forwarded to the wrapped layer.

        Returns:
          The output of the wrapped layer.
        """

        def _do_nothing():
            return tf.identity(self.g)

        def _update_weights():
            # Ensure we read `self.g` after _update_weights.
            with tf.control_dependencies(self._initialize_weights(inputs)):
                return tf.identity(self.g)

        g = tf.cond(self._initialized, _do_nothing, _update_weights)

        with tf.name_scope("compute_weights"):
            # Replace kernel by normalized weight variable.
            kernel = tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * g

            if self.is_rnn:
                self.layer.cell.recurrent_kernel = kernel
                update_kernel = tf.identity(self.layer.cell.recurrent_kernel)
            else:
                self.layer.kernel = kernel
                update_kernel = tf.identity(self.layer.kernel)

            # Ensure we calculate result after updating kernel.
            with tf.control_dependencies([update_kernel]):
                outputs = self.layer(inputs)
                return outputs

    def compute_output_shape(self, input_shape):
        """Return the wrapped layer's output shape as a `tf.TensorShape`."""
        return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())

    def _initialize_weights(self, inputs):
        """Initialize weight g.

        The initial value of g is taken either from the norm of v, or from
        statistics of the layer output on `inputs` if self.data_init is True.

        Args:
          inputs: input tensor(s), used only for the data-dependent init.

        Returns:
          list: the assign ops/tensors for the initialization, the last one
          setting ``_initialized`` to True.
        """
        with tf.control_dependencies(
            [
                tf.debugging.assert_equal(  # pylint: disable=bad-continuation
                    self._initialized, False, message="The layer has been initialized."
                )
            ]
        ):
            if self.data_init:
                assign_tensors = self._data_dep_init(inputs)
            else:
                assign_tensors = self._init_norm()
            assign_tensors.append(self._initialized.assign(True))
            return assign_tensors

    def _init_norm(self):
        """Set the weight g with the norm of the weight vector.

        Returns:
          list: a single-element list holding the result of assigning `g`.
        """
        with tf.name_scope("init_norm"):
            # Flatten every axis but the last so each column is one unit.
            v_flat = tf.reshape(self.v, [-1, self.layer_depth])
            v_norm = tf.linalg.norm(v_flat, axis=0)
            g_tensor = self.g.assign(tf.reshape(v_norm, (self.layer_depth,)))
            return [g_tensor]

    def _calculate_moments(self, x, axes, keep_dims=True):
        """Calculate the mean and variance of `x`, aggregated over replicas.

        In a replica context the per-replica sums are all-reduced so that the
        moments cover the global batch. In a cross-replica context the
        moments are computed locally from `x`.

        Sources:
            https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/layers/normalization_v2.py#L165-L204

        Args:
          x: tensor to compute the moments of. float16 input is processed in
            float32 and the results are cast back to float16.
          axes (list[int]): axes to reduce over. The replica branch treats
            axis 0 as the batch axis and expects it to be `axes[0]`.
          keep_dims (bool): if False, the reduced axes are squeezed out of
            the results.

        Returns:
          tuple: `(mean, variance)` tensors.
        """
        y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
        replica_ctx = tf.distribute.get_replica_context()
        if replica_ctx:
            # define each GPU's variables
            local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
            local_squared_sum = tf.reduce_sum(
                tf.math.square(y), axis=axes, keepdims=True
            )
            # Cast to y's dtype so the divisions below do not mix dtypes.
            batch_size = tf.cast(tf.shape(y)[0], y.dtype)

            # reduce all GPUs' variables
            y_sum, y_squared_sum, global_batch_size = replica_ctx.all_reduce(
                tf.distribute.ReduceOp.SUM, [local_sum, local_squared_sum, batch_size]
            )

            # Number of reduced elements per example: the product of the
            # sizes of the reduced non-batch axes (1 when there are none).
            y_shape = tf.shape(y)
            axes_vals = [y_shape[axis] for axis in axes[1:]]
            multiplier = tf.cast(tf.reduce_prod(axes_vals), y.dtype)
            multiplier = multiplier * global_batch_size
            mean = y_sum / multiplier
            y_squared_mean = y_squared_sum / multiplier
            # var = E(x^2) - E(x)^2; clamp at zero because rounding can make
            # the difference slightly negative, and the caller takes a sqrt.
            variance = tf.maximum(y_squared_mean - tf.square(mean), 0.0)
        else:
            mean = tf.reduce_mean(y, axes, keepdims=True)
            variance = tf.reduce_mean(
                tf.math.squared_difference(y, tf.stop_gradient(mean)),
                axes,
                keepdims=True,
            )
        if not keep_dims:
            mean = tf.squeeze(mean, axes)
            variance = tf.squeeze(variance, axes)
        if x.dtype == tf.float16:
            return (tf.cast(mean, tf.float16), tf.cast(variance, tf.float16))
        else:
            return (mean, variance)

    def _data_dep_init(self, inputs):
        """Data dependent initialization.

        Runs `inputs` through the untrainable clone built in `build`, takes
        the moments of its output over every axis but the last, rescales `g`
        by the inverse standard deviation and, if the wrapped layer has a
        bias, sets it to ``-mean * scale``.

        Args:
          inputs: input tensor(s) to derive the statistics from.

        Returns:
          list: the result of assigning `g`, followed by the result of
          assigning the bias when the wrapped layer has one.
        """
        with tf.name_scope("data_dep_init"):
            # Generate data dependent init values
            x_init = self._naked_clone_layer(inputs)
            data_norm_axes = list(range(x_init.shape.rank - 1))

            m_init, v_init = self._calculate_moments(
                x_init, data_norm_axes, keep_dims=False
            )
            # The small epsilon guards against division by zero.
            scale_init = 1.0 / tf.math.sqrt(v_init + 1e-10)

            # RNNs have fused kernels that are tiled
            # Repeat scale_init to match the shape of fused kernel
            # Note: This is only to support the operation,
            # the paper advises against RNN+data_dep_init
            if scale_init.shape[0] != self.g.shape[0]:
                rep = int(self.g.shape[0] / scale_init.shape[0])
                scale_init = tf.tile(scale_init, [rep])

            # Assign data dependent init values
            g_tensor = self.g.assign(self.g * scale_init)
            if hasattr(self.layer, "bias") and self.layer.bias is not None:
                bias_tensor = self.layer.bias.assign(-m_init * scale_init)
                return [g_tensor, bias_tensor]
            else:
                return [g_tensor]

    def get_config(self):
        """Return the base wrapper config extended with `data_init`."""
        config = {"data_init": self.data_init}
        base_config = super().get_config()
        return {**base_config, **config}

    def remove(self):
        """Bake the normalized kernel into the wrapped layer and return it.

        A new variable holding ``g * v / ||v||`` is created and assigned as
        the wrapped layer's kernel (the cell's recurrent kernel for RNNs).

        Returns:
          tf.keras.layers.Layer: the wrapped layer.
        """
        kernel = tf.Variable(
            tf.nn.l2_normalize(self.v, axis=self.kernel_norm_axes) * self.g,
            name="recurrent_kernel" if self.is_rnn else "kernel",
        )

        if self.is_rnn:
            self.layer.cell.recurrent_kernel = kernel
        else:
            self.layer.kernel = kernel

        return self.layer
