# Copyright 2021 Carl Zeiss Microscopy GmbH

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#    http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides postprocessing utilities."""
from typing import Any  # noqa # pylint: disable=unused-import
from typing import Optional, Sequence, Union
import collections.abc

import tensorflow as tf


class SigmoidToSoftmaxScores(tf.keras.layers.Layer):
    """A Keras layer for converting sigmoidal output to softmax scores.

    A single sigmoid channel ``p`` is expanded along the last dimension into the
    two-channel score vector ``[p, 1 - p]``.
    """

    @tf.function
    def call(self, inputs: "tf.Tensor", **kwargs: "Any") -> "tf.Tensor":
        """Performs the conversion from sigmoidal outputs to softmax scores in the last dimension.

        Args:
            inputs: An output from a layer with sigmoid activation. The last dimension must be of size 1.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            The probability distribution as generated by a softmax layer: a tensor with the same
            leading dimensions as ``inputs`` and a last dimension of size 2, where channel 0 holds
            ``inputs`` and channel 1 holds ``1 - inputs``. The dtype of ``inputs`` is preserved.

        Raises:
            ValueError: If ``inputs`` is a scalar, or if the last dimension of the input does not
                have exactly size 1 (this includes a last dimension whose size is not statically known).
        """
        # pylint: disable=no-self-use
        # Scalars are rejected explicitly so they raise the documented ValueError instead of
        # an IndexError from indexing an empty shape.
        if inputs.shape.rank == 0 or inputs.shape[-1] != 1:
            raise ValueError(
                "The shape of the sigmoid activated input must be 1 on the last dimension. "
                "Received Tensor of shape {}".format(inputs.shape)
            )
        # Build the complement in the input's own dtype. A float32 ``tf.constant(1.0)`` cannot
        # be subtracted from float16/bfloat16/float64 tensors (TensorFlow does not promote dtypes),
        # which broke this layer under mixed-precision policies.
        complement = tf.ones_like(inputs) - inputs
        return tf.concat([inputs, complement], -1)


def add_postprocessing_layers(
    model: "tf.keras.Model",
    layers: Optional[
        Union["tf.keras.layers.Layer", Sequence["tf.keras.layers.Layer"]]
    ],
) -> "tf.keras.Model":
    """Wraps a Keras model so that its output is passed through post-processing layers.

    Args:
        model: The Keras model to be wrapped. Its ``inputs`` become the inputs of the new model.
        layers: A single layer, a sequence of layers applied in the given order, or ``None``
            to append no layers at all.

    Returns:
        A new Keras model that shares the inputs of ``model`` and whose output is the output
        of ``model`` fed successively through ``layers``.
    """
    # Normalize ``layers`` into something iterable: nothing, the given sequence, or a 1-tuple.
    if layers is None:
        postprocessing = ()
    elif isinstance(layers, collections.abc.Sequence):
        postprocessing = layers
    else:
        postprocessing = (layers,)

    # Re-apply the wrapped model to its own input tensors to obtain the output to extend.
    model_inputs = model.inputs
    result = model(model_inputs)
    for postprocessing_layer in postprocessing:
        result = postprocessing_layer(result)

    return tf.keras.Model(inputs=model_inputs, outputs=result)
