# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Feature Pyramid Networks.

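Feature Pyramid Networks were proposed in:
[1] Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
    and Serge Belongie.
    Feature Pyramid Networks for Object Detection. CVPR 2017.

The bidirectional variant (BiFPN) was proposed in:
[2] Mingxing Tan, Ruoming Pang, and Quoc V. Le.
    EfficientDet: Scalable and Efficient Object Detection. CVPR 2020.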
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools
import logging

import tensorflow as tf

from tensorflow.python.keras import backend
from . import nn_ops
from ..ops import spatial_transform_ops
from ..utils.efficientdet_utils import get_feat_sizes, activation_fn
from xl_tensorflow.utils import hparams_config


class Fpn(object):
    """Feature pyramid networks.

    Builds, per output level, a 1x1 lateral convolution, a top-down pathway
    using nearest-neighbour upsampling, a post-hoc 3x3 convolution, extra
    stride-2 levels above the highest backbone level, and (optionally) a
    normalization layer.
    """

    def __init__(self,
                 min_level=3,
                 max_level=7,
                 fpn_feat_dims=256,
                 use_separable_conv=False,
                 activation='relu',
                 use_batch_norm=True,
                 norm_activation=nn_ops.norm_activation_builder(
                     activation='relu')):
        """FPN initialization function.

        Args:
          min_level: `int` minimum level in FPN output feature maps.
          max_level: `int` maximum level in FPN output feature maps.
          fpn_feat_dims: `int` number of filters in FPN layers.
          use_separable_conv: `bool`, if True use separable convolution for
            convolution in FPN layers.
          activation: `str`, activation applied between consecutive extra
            coarse levels; one of 'relu' or 'swish'.
          use_batch_norm: 'bool', indicating whether batchnorm layers are added.
          norm_activation: an operation that includes a normalization layer
            followed by an optional activation layer.

        Raises:
          ValueError: if `activation` is not 'relu' or 'swish'.
        """
        self._min_level = min_level
        self._max_level = max_level
        self._fpn_feat_dims = fpn_feat_dims
        if use_separable_conv:
            self._conv2d_op = functools.partial(
                tf.keras.layers.SeparableConv2D, depth_multiplier=1)
        else:
            self._conv2d_op = tf.keras.layers.Conv2D
        if activation == 'relu':
            self._activation_op = tf.nn.relu
        elif activation == 'swish':
            self._activation_op = tf.nn.swish
        else:
            raise ValueError('Unsupported activation `{}`.'.format(activation))
        self._use_batch_norm = use_batch_norm
        self._norm_activation = norm_activation

        # Per-level layers, created once here and applied in `__call__`.
        self._norm_activations = {}
        self._lateral_conv2d_op = {}
        self._post_hoc_conv2d_op = {}
        self._coarse_conv2d_op = {}
        for level in range(self._min_level, self._max_level + 1):
            if self._use_batch_norm:
                self._norm_activations[level] = norm_activation(
                    use_activation=False, name='p%d-bn' % level)
            self._lateral_conv2d_op[level] = self._conv2d_op(
                filters=self._fpn_feat_dims,
                kernel_size=(1, 1),
                padding='same',
                name='l%d' % level)
            self._post_hoc_conv2d_op[level] = self._conv2d_op(
                filters=self._fpn_feat_dims,
                strides=(1, 1),
                kernel_size=(3, 3),
                padding='same',
                name='post_hoc_d%d' % level)
            self._coarse_conv2d_op[level] = self._conv2d_op(
                filters=self._fpn_feat_dims,
                strides=(2, 2),
                kernel_size=(3, 3),
                padding='same',
                name='p%d' % level)

    def __call__(self, multilevel_features, is_training=None):
        """Returns the FPN features for a given multilevel features.

        Args:
          multilevel_features: a `dict` containing `int` keys for continuous feature
            levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
            shape [batch_size, height_l, width_l, num_filters].
          is_training: `bool` if True, the model is in training mode.

        Returns:
          a `dict` containing `int` keys for continuous feature levels
          [min_level, min_level + 1, ..., max_level]. The values are corresponding
          FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].

        Raises:
          ValueError: if `multilevel_features` is empty, if its minimum level is
            above the FPN minimum level, or if all of its levels are below the
            FPN minimum level.
        """
        if not multilevel_features:
            raise ValueError('`multilevel_features` must not be empty.')
        input_levels = list(multilevel_features.keys())
        if min(input_levels) > self._min_level:
            raise ValueError(
                'The minimum backbone level %d should be ' % (min(input_levels)) +
                'less than or equal to FPN minimum level %d.' % (self._min_level))
        if max(input_levels) < self._min_level:
            # Without this check the top-down path below fails with a KeyError
            # because no lateral feature would be produced.
            raise ValueError(
                'The maximum backbone level %d should be ' % (max(input_levels)) +
                'greater than or equal to FPN minimum level %d.' % (self._min_level))
        # Highest level that comes directly from the backbone; levels above it
        # are synthesized by strided convolutions.
        backbone_max_level = min(max(input_levels), self._max_level)
        with backend.get_graph().as_default(), tf.name_scope('fpn'):
            # Adds lateral connections.
            feats_lateral = {}
            for level in range(self._min_level, backbone_max_level + 1):
                feats_lateral[level] = self._lateral_conv2d_op[level](
                    multilevel_features[level])

            # Adds top-down path.
            feats = {backbone_max_level: feats_lateral[backbone_max_level]}
            for level in range(backbone_max_level - 1, self._min_level - 1, -1):
                feats[level] = spatial_transform_ops.nearest_upsampling(
                    feats[level + 1], 2) + feats_lateral[level]

            # Adds post-hoc 3x3 convolution kernel.
            for level in range(self._min_level, backbone_max_level + 1):
                feats[level] = self._post_hoc_conv2d_op[level](feats[level])

            # Adds coarser FPN levels introduced for RetinaNet.
            for level in range(backbone_max_level + 1, self._max_level + 1):
                feats_in = feats[level - 1]
                # The first extra level consumes the post-hoc output directly;
                # later extra levels apply the activation first.
                if level > backbone_max_level + 1:
                    feats_in = self._activation_op(feats_in)
                feats[level] = self._coarse_conv2d_op[level](feats_in)
            if self._use_batch_norm:
                # Adds batch_norm layer.
                for level in range(self._min_level, self._max_level + 1):
                    feats[level] = self._norm_activations[level](
                        feats[level], is_training=is_training)
        return feats


class BiFpn(object):
    """Bidirectional feature pyramid network (BiFPN), as used by EfficientDet."""

    def __init__(self,
                 params,
                 min_level=3,
                 max_level=7,
                 use_separable_conv=False,
                 activation='relu',
                 use_batch_norm=True,
                 output_size=(640, 640),
                 norm_activation=nn_ops.norm_activation_builder(
                     activation='relu')):
        """BiFPN initialization function.

        Args:
          params: configuration object; this constructor reads
            `params.fpn.fpn_feat_dims` (number of filters in FPN layers).
          min_level: `int` minimum level in FPN output feature maps.
          max_level: `int` maximum level in FPN output feature maps.
          use_separable_conv: `bool`, if True use separable convolution for
            convolution in FPN layers.
          activation: `str`, one of 'relu' or 'swish'.
          use_batch_norm: 'bool', indicating whether batchnorm layers are added.
          output_size: `(height, width)` of the model input; `__call__` passes
            `output_size[0]` to `get_feat_sizes` to derive per-level sizes.
          norm_activation: an operation that includes a normalization layer
            followed by an optional activation layer.

        Raises:
          ValueError: if `activation` is not 'relu' or 'swish'.
        """
        self._min_level = min_level
        self._max_level = max_level
        self._output_size = output_size
        if use_separable_conv:
            self._conv2d_op = functools.partial(
                tf.keras.layers.SeparableConv2D, depth_multiplier=1)
        else:
            self._conv2d_op = tf.keras.layers.Conv2D
        if activation == 'relu':
            self._activation_op = tf.nn.relu
        elif activation == 'swish':
            self._activation_op = tf.nn.swish
        else:
            raise ValueError('Unsupported activation `{}`.'.format(activation))
        self._use_batch_norm = use_batch_norm
        self._norm_activation = norm_activation

        # NOTE(review): the per-level layers below are not referenced by
        # `build_bifpn_layer` or `__call__` in this class; kept for
        # compatibility with external users of these attributes.
        self._norm_activations = {}
        self._lateral_conv2d_op = {}
        self._post_hoc_conv2d_op = {}
        self._coarse_conv2d_op = {}
        for level in range(self._min_level, self._max_level + 1):
            if self._use_batch_norm:
                self._norm_activations[level] = norm_activation(
                    use_activation=False, name='p%d-bn' % level)
            self._lateral_conv2d_op[level] = self._conv2d_op(
                filters=params.fpn.fpn_feat_dims,
                kernel_size=(1, 1),
                padding='same',
                name='l%d' % level)
            self._post_hoc_conv2d_op[level] = self._conv2d_op(
                filters=params.fpn.fpn_feat_dims,
                strides=(1, 1),
                kernel_size=(3, 3),
                padding='same',
                name='post_hoc_d%d' % level)
            self._coarse_conv2d_op[level] = self._conv2d_op(
                filters=params.fpn.fpn_feat_dims,
                strides=(2, 2),
                kernel_size=(3, 3),
                padding='same',
                name='p%d' % level)

    def get_fpn_config(self, fpn_name, min_level, max_level, weight_method):
        """Returns the node configuration for the named BiFPN variant.

        Args:
          fpn_name: one of 'bifpn_sum', 'bifpn_fa', 'bifpn_dyn'; falsy values
            default to 'bifpn_fa'.
          min_level: minimum level (used only by 'bifpn_dyn').
          max_level: maximum level (used only by 'bifpn_dyn').
          weight_method: fusion method (used only by 'bifpn_dyn').

        Raises:
          KeyError: if `fpn_name` is not a known variant.
        """
        if not fpn_name:
            fpn_name = 'bifpn_fa'
        # Build only the requested config instead of all of them.
        name_to_config = {
            'bifpn_sum': self.bifpn_sum_config,
            'bifpn_fa': self.bifpn_fa_config,
            'bifpn_dyn': lambda: self.bifpn_dynamic_config(
                min_level, max_level, weight_method),
        }
        return name_to_config[fpn_name]()

    def fuse_features(self, nodes, weight_method):
        """Fuse features from different resolutions and return a weighted sum.

        Args:
          nodes: a list of tensorflow features at different levels
          weight_method: feature fusion method. One of:
            - "attn" - Softmax weighted fusion
            - "fastattn" - Fast normalized feature fusion
            - "sum" - a sum of inputs

        Returns:
          A tensor denoting the fused feature.

        Raises:
          ValueError: if `weight_method` is unknown.
        """
        dtype = nodes[0].dtype

        # NOTE(review): the fusion weights are plain `tf.Variable`s created on
        # every call rather than layer weights; confirm they are tracked as
        # trainable by the surrounding model.
        if weight_method == 'attn':
            edge_weights = [tf.cast(tf.Variable(1.0, name='WSM'), dtype=dtype)
                            for _ in nodes]
            normalized_weights = tf.nn.softmax(tf.stack(edge_weights))
            nodes = tf.stack(nodes, axis=-1)
            new_node = tf.reduce_sum(nodes * normalized_weights, -1)
        elif weight_method == 'fastattn':
            edge_weights = [
                tf.nn.relu(tf.cast(tf.Variable(1.0, name='WSM'), dtype=dtype))
                for _ in nodes
            ]
            weights_sum = tf.add_n(edge_weights)
            # The small epsilon keeps the division stable when all weights are 0.
            nodes = [node * weight / (weights_sum + 0.0001)
                     for node, weight in zip(nodes, edge_weights)]
            new_node = tf.add_n(nodes)
        elif weight_method == 'sum':
            new_node = tf.add_n(nodes)
        else:
            raise ValueError(
                'unknown weight_method {}'.format(weight_method))

        return new_node

    def build_bifpn_layer(self, feats, feat_sizes, params):
        """Builds one BiFPN cell on top of the previous feature pyramid.

        Args:
          feats: list of feature tensors ordered from `self._min_level` to
            `self._max_level`. The caller's list is not modified.
          feat_sizes: indexable by level; each entry is a dict with 'height'
            and 'width' keys.
          params: configuration object (reads `params.fpn.*`,
            `params.act_type`, `params.is_training_bn`, `params.use_tpu`,
            `params.data_format`).

        Returns:
          A dict mapping each level in [self._min_level, self._max_level] to the
          output of the last node built at that level.
        """
        p = params  # use p to denote the network config.
        if p.fpn.fpn_config:
            fpn_config = p.fpn.fpn_config
        else:
            # Node offsets index into `feats`, which spans exactly
            # [self._min_level, self._max_level], so use the same range here.
            fpn_config = self.get_fpn_config(p.fpn.fpn_name, self._min_level,
                                             self._max_level,
                                             p.fpn.fpn_weight_method)

        # Copy: new nodes are appended below and must not leak into the caller.
        feats = list(feats)
        bn_axis = 1 if params.data_format == 'channels_first' else -1
        if p.fpn.use_separable_conv:
            conv_op = functools.partial(
                tf.keras.layers.SeparableConv2D, depth_multiplier=1)
        else:
            conv_op = tf.keras.layers.Conv2D

        for i, fnode in enumerate(fpn_config.nodes):
            with tf.name_scope('fnode{}'.format(i)):
                logging.info('fnode %d : %s', i, fnode)
                new_node_height = feat_sizes[fnode['feat_level']]['height']
                new_node_width = feat_sizes[fnode['feat_level']]['width']
                nodes = []
                for idx, input_offset in enumerate(fnode['inputs_offsets']):
                    # Resize/project every input to this node's resolution
                    # and channel count before fusion.
                    nodes.append(spatial_transform_ops.resample_feature_map(
                        feats[input_offset],
                        '{}_{}_{}'.format(idx, input_offset, len(feats)),
                        new_node_height, new_node_width, p.fpn.fpn_feat_dims,
                        p.fpn.apply_bn_for_resampling, p.is_training_bn,
                        p.fpn.conv_after_downsample,
                        p.fpn.use_native_resize_op,
                        p.fpn.pooling_type,
                        use_tpu=p.use_tpu,
                        data_format=params.data_format))

                new_node = self.fuse_features(nodes, fpn_config.weight_method)

                with tf.name_scope('op_after_combine{}'.format(len(feats))):
                    # act -> conv -> bn, or conv -> bn -> act when
                    # `conv_bn_act_pattern` is set.
                    if not p.fpn.conv_bn_act_pattern:
                        new_node = activation_fn(new_node, p.act_type)

                    new_node = conv_op(
                        filters=p.fpn.fpn_feat_dims,
                        kernel_size=(3, 3),
                        padding='same',
                        use_bias=not p.fpn.conv_bn_act_pattern,
                        data_format=params.data_format,
                        name='conv')(new_node)

                    # Keras BatchNormalization does not accept the
                    # `is_training_bn`/`act_type`/`use_tpu` kwargs; the training
                    # mode goes to the call and the activation is applied below.
                    # NOTE(review): cross-replica (TPU) batch norm is not used.
                    new_node = tf.keras.layers.BatchNormalization(
                        axis=bn_axis, name='bn')(
                            new_node, training=p.is_training_bn)
                    if p.fpn.conv_bn_act_pattern:
                        new_node = activation_fn(new_node, p.act_type)

                feats.append(new_node)

        # The output for each level is the most recently built node at it.
        output_feats = {}
        for level in range(self._min_level, self._max_level + 1):
            for i, fnode in enumerate(reversed(fpn_config.nodes)):
                if fnode['feat_level'] == level:
                    output_feats[level] = feats[-1 - i]
                    break
        return output_feats

    def bifpn_sum_config(self):
        """BiFPN config with sum."""
        p = hparams_config.Config()
        p.nodes = [
            {'feat_level': 6, 'inputs_offsets': [3, 4]},
            {'feat_level': 5, 'inputs_offsets': [2, 5]},
            {'feat_level': 4, 'inputs_offsets': [1, 6]},
            {'feat_level': 3, 'inputs_offsets': [0, 7]},
            {'feat_level': 4, 'inputs_offsets': [1, 7, 8]},
            {'feat_level': 5, 'inputs_offsets': [2, 6, 9]},
            {'feat_level': 6, 'inputs_offsets': [3, 5, 10]},
            {'feat_level': 7, 'inputs_offsets': [4, 11]},
        ]
        p.weight_method = 'sum'
        return p

    def bifpn_fa_config(self):
        """BiFPN config with fast weighted sum."""
        p = self.bifpn_sum_config()
        p.weight_method = 'fastattn'
        return p

    def bifpn_dynamic_config(self, min_level, max_level, weight_method):
        """A dynamic bifpn config that can adapt to different min/max levels."""
        p = hparams_config.Config()
        p.weight_method = weight_method or 'fastattn'

        # Node id starts from the input features and monotonically increase whenever
        # a new node is added. Here is an example for level P3 - P7:
        #     P7 (4)              P7" (12)
        #     P6 (3)    P6' (5)   P6" (11)
        #     P5 (2)    P5' (6)   P5" (10)
        #     P4 (1)    P4' (7)   P4" (9)
        #     P3 (0)              P3" (8)
        # So output would be like:
        # [
        #   {'feat_level': 6, 'inputs_offsets': [3, 4]},  # for P6'
        #   {'feat_level': 5, 'inputs_offsets': [2, 5]},  # for P5'
        #   {'feat_level': 4, 'inputs_offsets': [1, 6]},  # for P4'
        #   {'feat_level': 3, 'inputs_offsets': [0, 7]},  # for P3"
        #   {'feat_level': 4, 'inputs_offsets': [1, 7, 8]},  # for P4"
        #   {'feat_level': 5, 'inputs_offsets': [2, 6, 9]},  # for P5"
        #   {'feat_level': 6, 'inputs_offsets': [3, 5, 10]},  # for P6"
        #   {'feat_level': 7, 'inputs_offsets': [4, 11]},  # for P7"
        # ]
        num_levels = max_level - min_level + 1
        node_ids = {min_level + i: [i] for i in range(num_levels)}

        level_last_id = lambda level: node_ids[level][-1]
        level_all_ids = lambda level: node_ids[level]
        id_cnt = itertools.count(num_levels)

        p.nodes = []
        for i in range(max_level - 1, min_level - 1, -1):
            # top-down path.
            p.nodes.append({
                'feat_level': i,
                'inputs_offsets': [level_last_id(i), level_last_id(i + 1)]
            })
            node_ids[i].append(next(id_cnt))

        for i in range(min_level + 1, max_level + 1):
            # bottom-up path.
            p.nodes.append({
                'feat_level': i,
                'inputs_offsets': level_all_ids(i) + [level_last_id(i - 1)]
            })
            node_ids[i].append(next(id_cnt))

        return p

    def __call__(self, multilevel_features, params, is_training=None):
        """Returns the BiFPN features for a given multilevel features.

        Args:
          multilevel_features: a `dict` containing `int` keys for continuous feature
            levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
            shape [batch_size, height_l, width_l, num_filters]. It must contain
            `min_level`; missing higher levels are synthesized by downsampling.
          params: configuration object (reads `params.fpn.*`,
            `params.is_training_bn` and the fields used by `build_bifpn_layer`).
          is_training: `bool` if True, the model is in training mode (unused).

        Returns:
          a `dict` containing `int` keys for continuous feature levels
          [min_level, min_level + 1, ..., max_level]. The values are corresponding
          FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].
          With zero cell repeats the (completed) input pyramid is returned.

        Raises:
          ValueError: if `multilevel_features` lacks `min_level`.
        """
        if self._min_level not in multilevel_features:
            raise ValueError(
                'multilevel_features must contain the BiFPN minimum level '
                '%d.' % self._min_level)
        levels = list(range(self._min_level, self._max_level + 1))
        # Step 1: build the input levels missing from the backbone (e.g. 6, 7).
        feats = []
        with backend.get_graph().as_default(), tf.name_scope('bifpn'):
            for level in levels:
                if level in multilevel_features:
                    feats.append(multilevel_features[level])
                    continue
                # Only channels_last is supported here: H is axis 1, W is axis 2.
                h_id, w_id = (1, 2)
                feats.append(
                    spatial_transform_ops.resample_feature_map(
                        feats[-1],
                        name='p%d' % level,
                        target_height=(feats[-1].shape[h_id] - 1) // 2 + 1,
                        target_width=(feats[-1].shape[w_id] - 1) // 2 + 1,
                        target_num_channels=params.fpn.fpn_feat_dims,
                        apply_bn=params.fpn.apply_bn_for_resampling,
                        is_training=params.is_training_bn,
                        conv_after_downsample=params.fpn.conv_after_downsample,
                        use_native_resize_op=params.fpn.use_native_resize_op,
                        pooling_type=params.fpn.pooling_type,
                        use_tpu=False,
                        data_format="channels_last"
                    ))
            feat_sizes = get_feat_sizes(self._output_size[0], self._max_level)
            # TODO: feature-size verification (_verify_feats_size) not yet implemented.

            # Defined up front so zero repeats returns the input pyramid
            # instead of raising on an unbound name.
            new_feats = dict(zip(levels, feats))
            with tf.name_scope("bifpn_cells"):
                for rep in range(params.fpn.fpn_cell_repeats):
                    logging.info('building cell %d', rep)
                    new_feats = self.build_bifpn_layer(feats, feat_sizes, params)
                    # Chain cells: the next cell consumes this cell's outputs.
                    feats = [new_feats[level] for level in levels]
                    # TODO: feature-size verification (_verify_feats_size) not yet implemented.

            return new_feats
