# -*- coding: utf-8 -*-
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                         documentation
#        :author: Matthias De Lozzo, Syver Doving Agdestein
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
"""This module contains the base class for classification algorithms.

The :mod:`~gemseo.mlearning.classification.classification` module
implements classification algorithms,
whose goal is to assess the membership of input data to classes.

Classification algorithms provide methods for predicting classes of new input data,
as well as predicting the probabilities of belonging to each of the classes
wherever possible.

This concept is implemented through the :class:`.MLClassificationAlgo` class
which inherits from the :class:`.MLSupervisedAlgo` class.
"""
from __future__ import division, unicode_literals

from typing import Dict, Iterable, List, Optional, Sequence, Union

from numpy import arange, ndarray, unique, zeros

from gemseo.core.dataset import Dataset
from gemseo.mlearning.core.ml_algo import DataType, MLAlgoParameterType, TransformerType
from gemseo.mlearning.core.supervised import MLSupervisedAlgo
from gemseo.mlearning.core.supervised import (
    SavedObjectType as MLSupervisedAlgoSavedObjectType,
)

# Type of the objects returned by ``MLClassificationAlgo._get_objects_to_save``:
# the parent class' saved objects, extended with classification-specific ones
# (e.g. the number of classes, stored as an ``int``).
SavedObjectType = Union[
    MLSupervisedAlgoSavedObjectType,
    Sequence[str],
    Dict[str, ndarray],
    int,
]


class MLClassificationAlgo(MLSupervisedAlgo):
    """Classification Algorithm.

    Inheriting classes shall implement the :meth:`!MLSupervisedAlgo._fit` and
    :meth:`!MLClassificationAlgo._predict` methods, and
    :meth:`!MLClassificationAlgo._predict_proba_soft` method if possible.

    Attributes:
        n_classes (int): The number of classes,
            i.e. the number of distinct values taken by the output variables
            over the whole learning set; ``None`` until :meth:`.learn` is called.
    """

    def __init__(
        self,
        data,  # type: Dataset
        transformer=MLSupervisedAlgo.DEFAULT_TRANSFORMER,  # type: TransformerType
        input_names=None,  # type: Optional[Iterable[str]]
        output_names=None,  # type: Optional[Iterable[str]]
        **parameters  # type: MLAlgoParameterType
    ):  # type: (...) -> None
        """
        Args:
            data: The learning dataset.
            transformer: The transformation strategies,
                forwarded to :class:`.MLSupervisedAlgo`.
            input_names: The names of the input variables,
                forwarded to :class:`.MLSupervisedAlgo`.
            output_names: The names of the output variables,
                forwarded to :class:`.MLSupervisedAlgo`.
            **parameters: The parameters of the machine learning algorithm,
                forwarded to :class:`.MLSupervisedAlgo`.
        """
        super(MLClassificationAlgo, self).__init__(
            data,
            transformer=transformer,
            input_names=input_names,
            output_names=output_names,
            **parameters
        )
        self.n_classes = None

    def learn(
        self,
        samples=None,  # type: Optional[List[int]]
    ):  # type: (...) -> None
        """Count the classes, then train the machine learning algorithm.

        Args:
            samples: The learning samples to use,
                forwarded to :meth:`.MLSupervisedAlgo.learn`.
        """
        # NOTE: the classes are counted over all the output variables of the
        # whole learning set, regardless of ``samples``.
        output_data = self.learning_set.get_data_by_names(self.output_names, False)
        self.n_classes = unique(output_data).shape[0]
        super(MLClassificationAlgo, self).learn(samples)

    @MLSupervisedAlgo.DataFormatters.format_input_output
    def predict_proba(
        self,
        input_data,  # type: DataType
        hard=True,  # type: bool
    ):  # type: (...)-> ndarray
        """Predict the probability of belonging to each class from input data.

        The user can specify these input data either as a numpy array,
        e.g. :code:`array([1., 2., 3.])`
        or as a dictionary,
        e.g.  :code:`{'a': array([1.]), 'b': array([2., 3.])}`.

        If the numpy arrays are of dimension 2,
        their i-th rows represent the input data of the i-th sample;
        while if the numpy arrays are of dimension 1,
        there is a single sample.

        The type of the output data and the dimension of the output arrays
        will be consistent
        with the type of the input data and the size of the input arrays.

        Args:
            input_data: The input data.
            hard: Whether the classification should be hard (True),
                i.e. 0/1 membership indicators, or soft (False),
                i.e. membership probabilities.

        Returns:
            The probability of belonging to each class.

        Raises:
            NotImplementedError: When ``hard`` is False
                and the algorithm does not implement soft classification.
        """
        return self._predict_proba(input_data, hard)

    def _predict_proba(
        self,
        input_data,  # type: ndarray
        hard=True,  # type: bool
    ):  # type: (...)-> ndarray
        """Predict the probability of belonging to each class.

        Args:
            input_data: The input data with shape (n_samples, n_inputs).
            hard: Whether the classification should be hard (True) or soft (False).

        Returns:
            The probability of belonging to each class
                with shape (n_samples, n_classes, n_outputs).
        """
        if hard:
            return self._predict_proba_hard(input_data)
        return self._predict_proba_soft(input_data)

    def _predict_proba_hard(
        self,
        input_data,  # type: ndarray
    ):  # type: (...)-> ndarray
        """Return 1 if the data belongs to a class, 0 otherwise.

        The predicted labels, cast to ``int``, are used directly as indices
        along the class axis, so they are expected to lie in ``[0, n_classes)``.

        Args:
            input_data: The input data with shape (n_samples, n_inputs).

        Returns:
            The indicator of belonging to each class
                with shape (n_samples, n_classes, n_outputs).
        """
        n_samples = input_data.shape[0]
        prediction = self._predict(input_data).astype(int)
        n_outputs = prediction.shape[1]
        probas = zeros((n_samples, self.n_classes, n_outputs))
        # Vectorized one-hot encoding: the three index arrays broadcast to
        # (n_samples, n_outputs), so that for every (sample, output) pair
        # the entry of the predicted class is set to 1.
        probas[arange(n_samples)[:, None], prediction, arange(n_outputs)] = 1
        return probas

    def _predict_proba_soft(
        self,
        input_data,  # type: ndarray
    ):  # type: (...)-> ndarray
        """Predict the probability of belonging to each class.

        Inheriting classes supporting soft classification shall override this
        method and return an array shaped like the one returned by
        :meth:`._predict_proba_hard`.

        Args:
            input_data: The input data with shape (n_samples, n_inputs).

        Returns:
            The probability of belonging to each class
                with shape (n_samples, n_classes, n_outputs).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _get_objects_to_save(self):  # type: (...) -> SavedObjectType
        """Return the objects of the parent class to save, plus the number of classes.

        Returns:
            The objects to save, with the number of classes under ``"n_classes"``.
        """
        objects = super(MLClassificationAlgo, self)._get_objects_to_save()
        objects["n_classes"] = self.n_classes
        return objects
