"""
Version : 1.0 (06-09-2022).

Author  : Mbaye DIONGUE

Copyright (C) 2022

This file is part of the codes provided at http://proximity-operator.net

By downloading and/or using any of these files, you implicitly agree to
all the terms of the license CeCill-B (available online).
"""

from typing import Union
import numpy as np


class Log:
    r"""Compute the proximity operator and the evaluation of gamma*f.

    Here f is the negative logarithm, extended by +INF outside its domain:

                    /  -log(x)     if x > 0
             f(x) = |
                    \  +INF        otherwise

    'gamma' is the scale factor.

    When the input 'x' is an array, the output is computed element-wise:

    - When calling the function (and not the proximity operator), the
      result is the SUM of the element-wise values. So the command
      >>> Log()(x) returns a scalar even if x is a vector.

    - For the proximity operator (method 'prox'), the output has the same
      shape as the input 'x'. So the command >>> Log().prox(x) returns an
      array with the same shape as 'x'.

     INPUTS
    ========
     x     - scalar, sequence or ND array
     gamma - positive scalar, or ND array with the same size as 'x'
             [default: gamma=1]

    ========
    Examples
    ========

    Evaluate the function f:

    >>> Log()(np.e)
    -1.0

    Compute the result as an element-wise sum when the input is a vector:

    >>> Log()([1, 3, np.e])
    -2.09861228866811

    Compute the proximity operator at a given point:

    >>> Log().prox([-2, 3, 4])
    array([0.41421356, 3.30277564, 4.23606798])

    Use a scale factor 'gamma' > 0 to compute the proximity operator of the
    function 'gamma*f':

    >>> Log().prox([-2, 3, 4, np.e], gamma=2.5)
    array([0.87082869, 3.67944947, 4.54950976, 3.44415027])
    """

    def __init__(self):
        pass

    def prox(self, x: np.ndarray, gamma: Union[float, np.ndarray] = 1) -> np.ndarray:
        """Proximity operator of gamma*f, evaluated element-wise.

        prox(x) is the positive root p of p**2 - x*p - gamma = 0, i.e.
        0.5 * (x + sqrt(x**2 + 4*gamma)).

        Parameters
        ----------
        x : scalar, sequence or ND array
        gamma : strictly positive scalar, or array with the same size as 'x'

        Returns
        -------
        A NumPy scalar for scalar input, otherwise an array with the shape
        of 'x'. Entries are positive (up to floating-point underflow).

        Raises
        ------
        ValueError
            If 'gamma' is not strictly positive, or a non-scalar 'gamma' does
            not match 'x' in size/shape.
        """
        # asarray also converts length-1 lists, on which `x**2` (and the
        # gamma check) would otherwise raise a TypeError.
        x = np.asarray(x)
        gamma = np.asarray(gamma)
        self._check(x, gamma)
        # r = sqrt(x**2 + 4*gamma), computed with hypot so it neither
        # overflows for large |x| nor wraps around for integer inputs.
        r = np.hypot(x, 2 * np.sqrt(gamma))
        # s = r + |x| > 0 involves no cancellation. For x >= 0 the prox is
        # s / 2. For x < 0 the textbook 0.5*(x + r) subtracts nearly equal
        # numbers (it even returns 0 for x = -1e9), so use the equivalent
        # 2*gamma / (r - x) = 2*gamma / s (the two roots multiply to -gamma).
        s = r + np.abs(x)
        p = np.where(x >= 0, 0.5 * s, 2 * gamma / s)
        # [()] unwraps a 0-d result into a NumPy scalar; arrays are unchanged.
        return p[()]

    def __call__(self, x: np.ndarray) -> float:
        """Evaluate f at 'x', summed over all entries.

        Returns +INF if any entry of 'x' is <= 0. NaN entries propagate to a
        NaN result. The sum is accumulated in float64.
        """
        x = np.asarray(x)
        if np.any(x <= 0):
            return np.float64(np.inf)
        # Every entry is positive (or NaN) here, so log emits no warning.
        return np.sum(-np.log(x), dtype=np.float64)

    def _check(self, x, gamma):
        """Validate 'gamma' against 'x' for the proximity operator.

        Raises ValueError if any entry of 'gamma' is not strictly positive
        (NaN included), or if a non-scalar 'gamma' does not have the same
        size as 'x' or would broadcast the result to a shape other than the
        shape of 'x'.
        """
        x = np.asarray(x)
        gamma = np.asarray(gamma)
        # `not all(> 0)` also rejects NaN, which `any(<= 0)` let through.
        if not np.all(gamma > 0):
            raise ValueError(
                "'gamma' (or all of its components if it is an array)"
                + " must be strictly positive"
            )
        if np.size(gamma) > 1:
            try:
                out_shape = np.broadcast(x, gamma).shape
            except ValueError:
                out_shape = None
            # Equal sizes alone are not enough: x of shape (6,) with gamma of
            # shape (6, 1) would silently produce a (6, 6) result.
            if np.size(gamma) != np.size(x) or out_shape != np.shape(x):
                raise ValueError(
                    "'gamma' must be either scalar or the same size as 'x'"
                )
