# Q-MM: A Python Quadratic Majorization Minimization toolbox
# Copyright (C) 2021 François Orieux <francois.orieux@universite-paris-saclay.fr>

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""The ``qmm`` module
==================

This module implements Quadratic Majorize-Minimize optimization algorithms.

"""

# pylint: disable=bad-continuation

import abc
import itertools as it
import time
from functools import reduce
from operator import iadd
from typing import Callable, List, Sequence, Tuple, Union

import numpy as np  # type: ignore
import numpy.linalg as la  # type: ignore
from numpy import ndarray as array

# Type accepted for objective data and operator outputs: either one array or a
# sequence of arrays (`Objective` stacks a list of arrays into one column).
ArrOrSeq = Union[array, Sequence[array]]


class OptimizeResult(dict):
    """Represents the optimization result.

    Items are also reachable as attributes: ``res.x`` is ``res["x"]``.

    x: array
        The solution of the optimization, with same shape than `init`.
    success: bool
        Whether or not the optimizer exited successfully.
    status: int
        Termination status of the optimizer. Its value depends on the underlying
        solver. Refer to message for details.
    message: str
        Description of the cause of the termination.
    nit: int
        Number of iterations performed by the optimizer.
    grad_norm: list of float
        The gradient norm at each iteration
    diff: list of float
        The value of ||x^(k+1) - x^(k)||² at each iteration
    time: list of float
        The time at each iteration, starting at 0, in seconds.
    fun: float
        The value of the objective function (last item of `objv_val`).
    jac: array
        The gradient of the objective function.
    objv_val: list of float
        The objective value at each iteration

    Notes
    -----
    :class:`OptimizeResult` mimics `OptimizeResult` of scipy for compatibility.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The defaults are built per instance so the history lists are never
        # shared between results, and they only fill keys that the caller did
        # not already provide through `args`/`kwargs`.
        defaults = {
            "maxcv": 0,
            "nfev": 0,
            "nhev": 0,
            "jac": None,
            "hess": None,
            "hess_inv": None,
            "success": False,
            "status": 99,
            "message": "Not applicable",
            "njev": 0,
            "nit": 0,
            "grad_norm": [],
            "diff": [],
            "time": [],
            "objv_val": [],
            "x": None,
        }
        for key, value in defaults.items():
            self.setdefault(key, value)

    @property
    def fun(self):
        """The last recorded objective value, that is ``objv_val[-1]``.

        Raises
        ------
        AttributeError
            If no objective value has been recorded, so that ``hasattr`` and
            ``getattr(res, "fun", default)`` behave as expected.
        """
        try:
            return self["objv_val"][-1]
        except (KeyError, IndexError) as err:
            raise AttributeError("No such attribute: fun") from err

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: fall back to items.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError("No such attribute: " + name) from err

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError as err:
            raise AttributeError("No such attribute: " + name) from err

    def __dir__(self):
        # Expose the items for introspection and tab completion.
        return list(super().__dir__()) + [key for key in self if isinstance(key, str)]


def mmmg(
    objv_list: Sequence["BaseObjective"],
    init: array,
    tol: float = 1e-4,
    max_iter: int = 500,
    callback: Callable[[OptimizeResult], None] = None,
    calc_objv: bool = False,
) -> OptimizeResult:
    r"""The Majorize-Minimize Memory Gradient (`3mg`) algorithm.

    The `mmmg` (`3mg`) algorithm is a subspace memory-gradient optimization
    algorithm with an explicit step formula based on Majorize-Minimize Quadratic
    approach [2]_.

    Parameters
    ----------
    objv_list : list of `BaseObjective`
        A list of :class:`BaseObjective` objects that each represent a `μ ψ(V·x - ω)`.
        The objectives are implicitly summed.
    init : array
        The initial point. It is not modified. Integer inputs are promoted to
        floating point.
    tol : float, optional
        The stopping tolerance. The algorithm is stopped when the gradient norm
        is less than `init.size * tol`.
    max_iter : int, optional
        The maximum number of iterations.
    callback : callable, optional
        A function that receives the `OptimizeResult` at the end of each
        iteration.
    calc_objv: boolean, optional
        If True, objective function is computed at each iteration with low
        overhead. False by default. Not used by the algorithm.

    Returns
    -------
    result : OptimizeResult
        `nit` is the number of updates of `x` and `njev` the number of gradient
        evaluations. `diff` holds `||x^(k+1) - x^(k)||²`.

    Notes
    -----
    The `calc_objv` flag of each objective is restored on exit, even if an
    exception is raised.

    References
    ----------
    .. [2] E. Chouzenoux, J. Idier, and S. Moussaoui, “A Majorize-Minimize
       Strategy for Subspace Optimization Applied to Image Restoration,” IEEE
       Trans. on Image Process., vol. 20, no. 6, pp. 1517–1528, Jun. 2011, doi:
       10.1109/TIP.2010.2103083.
    """
    res = OptimizeResult()
    previous_flag = [objv.calc_objv for objv in objv_list]
    for objv in objv_list:
        objv.calc_objv = calc_objv

    try:
        # Private floating point (or complex) copy: the in-place updates below
        # would fail on an integer `init` and must never modify `init`.
        res["x"] = np.array(init, dtype=np.result_type(init, 1.0)).reshape((-1, 1))

        # The first previous moves are initialized with 0 array. Consequently,
        # the first iterations implementation can be improved, at the cost of if
        # statement.
        move = np.zeros_like(res["x"])
        op_directions = [
            np.tile(_vect(objv.operator, move, init.shape), 2) for objv in objv_list
        ]
        step = np.ones((2, 1))

        res["time"].append(time.time())

        nit = 0  # number of updates of x
        njev = 0  # number of gradient evaluations
        for _ in range(max_iter):
            # Vectorized gradient
            grad = _gradient(objv_list, res["x"], init.shape)
            njev += 1
            res["grad_norm"].append(la.norm(grad))
            res["jac"] = grad.reshape(init.shape)
            res["objv_val"].append(_lastv(objv_list))

            # Stopping test
            if res["grad_norm"][-1] < init.size * tol:
                res["success"] = True
                res["status"] = 0
                break

            # Memory gradient directions D = [-grad, previous move]
            directions = np.c_[-grad, move]

            # Step by Majorize-Minimize. For linear operators, the columns
            # built here are exactly -V·D: V·grad = -V·(-grad), and the memory
            # column inherits the same global sign from the previous iteration.
            # The normal matrix WᵀBW is invariant to this sign, so the step is
            # unaffected.
            op_directions = [
                np.c_[_vect(objv.operator, grad, init.shape), i_op_dir @ step]
                for objv, i_op_dir in zip(objv_list, op_directions)
            ]
            step = -la.pinv(
                sum(
                    objv.norm_mat_major(i_op_dir, res["x"].reshape(init.shape))
                    for objv, i_op_dir in zip(objv_list, op_directions)
                )
            ) @ (directions.T @ grad)
            move = directions @ step

            # update
            res["x"] += move
            nit += 1

            # Squared norm of the move, as documented in OptimizeResult.
            res["diff"].append(np.sum(np.abs(move) ** 2))
            res["time"].append(time.time())

            if callback is not None:
                callback(res)

        if res["status"] == 0:
            res["message"] = "Stopping conditions reached"
        else:
            res["success"] = False
            res["status"] = 1
            res["message"] = "Maximum number of iteration reached"
        res["x"] = res["x"].reshape(init.shape)
        res["njev"] = njev
        res["nit"] = nit
        res["time"] = list(np.asarray(res["time"]) - res["time"][0])
    finally:
        for objv, flag in zip(objv_list, previous_flag):
            objv.calc_objv = flag

    return res


def mmcg(
    objv_list: Sequence["BaseObjective"],
    init: array,
    precond: Callable[[array], array] = None,
    tol: float = 1e-4,
    max_iter: int = 500,
    callback: Callable[[OptimizeResult], None] = None,
    calc_objv: bool = False,
) -> OptimizeResult:
    """The Majorize-Minimize Conjugate Gradient (MM-CG) algorithm.

    The MM-CG is a nonlinear conjugate gradient (NL-CG) optimization algorithm
    with an explicit step formula based on Majorize-Minimize Quadratic approach
    [1]_.

    Parameters
    ----------
    objv_list : list of `BaseObjective`
        A list of :class:`BaseObjective` objects that each represent a `μ ψ(V·x - ω)`.
        The objectives are implicitly summed.
    init : ndarray
        The initial point. It is not modified. Integer inputs are promoted to
        floating point.
    precond : callable, optional
        A callable that must implement a preconditioner, that is `M⁻¹·x`. Must
        be a callable with a unique input parameter `x` and unique output.
    tol : float, optional
        The stopping tolerance. The algorithm is stopped when the gradient norm
        is less than `init.size * tol`.
    max_iter : int, optional
        The maximum number of iterations.
    callback : callable, optional
        A function that receives the `OptimizeResult` at the end of each
        iteration.
    calc_objv: boolean, optional
        If True, objective function is computed at each iteration with low
        overhead. False by default. Not used by the algorithm.

    Returns
    -------
    result : OptimizeResult
        `success` is True and `status` is 0 when the stopping condition is
        met. `nit` is the number of updates of `x` and `njev` the number of
        gradient evaluations, including the one at `init`. `diff` holds
        `||x^(k+1) - x^(k)||²`.

    Notes
    -----
    The `calc_objv` flag of each objective is restored on exit, even if an
    exception is raised.

    References
    ----------
    .. [1] C. Labat and J. Idier, “Convergence of Conjugate Gradient Methods
       with a Closed-Form Stepsize Formula,” J Optim Theory Appl, p. 18, 2008.
    """
    if precond is None:
        precond = lambda x: x
    res = OptimizeResult()
    previous_flag = [objv.calc_objv for objv in objv_list]
    for objv in objv_list:
        objv.calc_objv = calc_objv

    try:
        # Private floating point (or complex) copy: the in-place updates below
        # would fail on an integer `init` and must never modify `init`.
        res["x"] = np.array(init, dtype=np.result_type(init, 1.0)).reshape((-1, 1))

        residual = -_gradient(objv_list, res["x"], init.shape)
        njev = 1
        # Record the state at `init` so that `jac` and `fun` are available even
        # if the initial point already satisfies the stopping condition.
        res["jac"] = -residual.reshape(init.shape)
        res["objv_val"].append(_lastv(objv_list))

        sec = _vect(precond, residual, init.shape)
        direction = sec
        delta = residual.T @ sec

        res["time"].append(time.time())

        nit = 0  # number of updates of x
        for _ in range(max_iter):
            # Stop test
            res["grad_norm"].append(la.norm(residual))
            if res["grad_norm"][-1] < init.size * tol:
                res["success"] = True
                res["status"] = 0
                break

            # Closed-form MM step: s = dᵀr / Σ (V·d)ᵀ·diag(b)·(V·d)
            op_direction = [
                _vect(objv.operator, direction, init.shape) for objv in objv_list
            ]
            step = (direction.T @ residual) / sum(
                objv.norm_mat_major(i_op_dir, res["x"].reshape(init.shape))
                for objv, i_op_dir in zip(objv_list, op_direction)
            )
            move = step * direction

            # update
            res["x"] += move
            nit += 1

            # Squared norm of the move, as documented in OptimizeResult.
            res["diff"].append(np.sum(np.abs(move) ** 2))
            res["time"].append(time.time())

            # Gradient
            residual = -_gradient(objv_list, res["x"], init.shape)
            njev += 1
            res["jac"] = -residual.reshape(init.shape)
            res["objv_val"].append(_lastv(objv_list))

            # Preconditioned Polak-Ribière conjugate direction. The direction
            # falls back to the preconditioned steepest descent whenever the
            # coefficient is negative (PR+).
            delta_old = delta
            delta_mid = residual.T @ sec
            sec = _vect(precond, residual, init.shape)
            delta = residual.T @ sec
            beta = (delta - delta_mid) / delta_old
            if beta >= 0:
                direction = sec + beta * direction
            else:
                direction = sec

            if callback is not None:
                callback(res)

        if res["status"] == 0:
            res["message"] = "Stopping conditions reached"
        else:
            res["success"] = False
            res["status"] = 1
            res["message"] = "Maximum number of iteration reached"
        res["x"] = res["x"].reshape(init.shape)
        res["njev"] = njev
        res["nit"] = nit
        res["time"] = list(np.asarray(res["time"]) - res["time"][0])
    finally:
        for objv, flag in zip(objv_list, previous_flag):
            objv.calc_objv = flag

    return res


def lcg(
    objv_list: Sequence["QuadObjective"],
    init: array,
    precond: Callable[[array], array] = None,
    tol: float = 1e-4,
    max_iter: int = 500,
    callback: Callable[[OptimizeResult], None] = None,
    calc_objv: bool = False,
) -> OptimizeResult:
    """Linear Conjugate Gradient (CG) algorithm.

    Linear Conjugate Gradient optimization algorithm for quadratic objective.

    Parameters
    ----------
    objv_list : list of `QuadObjective`
        A list of :class:`QuadObjective` objects that each represent a `½ μ
        ||V·x - ω||²`. The objectives are implicitly summed.
    init : ndarray
        The initial point. It is not modified. Integer inputs are promoted to
        floating point.
    precond : callable, optional
        A callable that must implement a preconditioner, that is `M⁻¹·x`. Must
        be a callable with a unique input parameter `x` and unique output.
    tol : float, optional
        The stopping tolerance. The algorithm is stopped when the
        (preconditioned) residual norm `sqrt(rᵀ·M⁻¹·r)` is less than
        `init.size * tol`.
    max_iter : int, optional
        The maximum number of iterations.
    callback : callable, optional
        A function that receives the `OptimizeResult` at the end of each
        iteration.
    calc_objv: boolean, optional
        If True, the objective value is computed at each iteration from the
        residual with low overhead. False by default. Not used by the algorithm.

    Returns
    -------
    result : OptimizeResult
        `diff` holds `||x^(k+1) - x^(k)||²` of the step actually taken.

    Notes
    -----
    The objectives are not modified, except for their `calc_objv` flag, which
    is restored on exit even if an exception is raised.
    """
    if precond is None:
        precond = lambda x: x
    res = OptimizeResult()
    previous_flag = [objv.calc_objv for objv in objv_list]
    for objv in objv_list:
        objv.calc_objv = calc_objv

    try:
        # Private floating point (or complex) copy: the in-place updates below
        # would fail on an integer `init` and must never modify `init`.
        res["x"] = np.array(init, dtype=np.result_type(init, 1.0)).reshape((-1, 1))

        # b = Σ data_t. The built-in `sum` allocates a new total, whereas
        # `reduce(iadd, ...)` would accumulate into, and corrupt, the first
        # objective's `data_t` array.
        second_term = np.reshape(sum(objv.data_t for objv in objv_list), (-1, 1))
        # c = Σ μ ωᵀ·B·ω, only needed to report the objective value.
        constant = sum(objv.constant for objv in objv_list) if calc_objv else 0

        def hessian(arr):
            """Return `Q·arr = Σ hessp(arr)` as a column vector."""
            return reduce(iadd, (_vect(c.hessp, arr, init.shape) for c in objv_list))

        # Residual (minus the gradient) at init: r = b - Q·x
        residual = second_term - hessian(res["x"])
        direction = _vect(precond, residual, init.shape)

        # `grad_norm` temporarily stores rᵀ·M⁻¹·r; square roots are taken at the end.
        res["grad_norm"].append(np.sum(np.real(np.conj(residual) * direction)))
        res["time"].append(time.time())

        nit = 0  # number of updates of x
        for iteration in range(max_iter):
            hess_dir = hessian(direction)
            # Optimal step s = rᵀ·M⁻¹·r / dᵀ·Q·d
            step = res["grad_norm"][-1] / np.sum(np.real(np.conj(direction) * hess_dir))

            # Descent x^(i+1) = x^(i) + s*d
            move = step * direction
            res["x"] += move
            nit += 1
            res["diff"].append(np.sum(np.abs(move) ** 2))

            # r^(i+1) = r^(i) - s * Q·d, recomputed from scratch periodically to
            # limit the accumulation of rounding errors.
            if iteration % 50 == 0:
                residual = second_term - hessian(res["x"])
            else:
                residual -= step * hess_dir
            res["jac"] = -residual.reshape(init.shape)

            if calc_objv:
                # `gradient` is never called here, so the objectives' `lastv`
                # is not refreshed: use J(x) = ½ Re(xᴴ·(-b - r)) + ½ c instead.
                res["objv_val"].append(
                    (
                        np.real(np.sum(np.conj(res["x"]) * (-second_term - residual)))
                        + constant
                    )
                    / 2
                )
            else:
                res["objv_val"].append(_lastv(objv_list))

            # Conjugate direction with preconditionner
            secant = _vect(precond, residual, init.shape)
            res["grad_norm"].append(np.sum(np.real(np.conj(residual) * secant)))
            direction = secant + (res["grad_norm"][-1] / res["grad_norm"][-2]) * direction

            res["time"].append(time.time())

            # Stopping condition
            if np.sqrt(res["grad_norm"][-1]) < init.size * tol:
                res["success"] = True
                res["status"] = 0
                break

            if callback is not None:
                callback(res)

        if res["status"] == 0:
            res["message"] = "Stopping conditions reached"
        else:
            res["success"] = False
            res["status"] = 1
            res["message"] = "Maximum number of iteration reached"
        res["x"] = res["x"].reshape(init.shape)
        res["njev"] = nit
        res["nit"] = nit
        res["grad_norm"] = list(np.sqrt(res["grad_norm"]))
        res["time"] = list(np.asarray(res["time"]) - res["time"][0])
    finally:
        for objv, flag in zip(objv_list, previous_flag):
            objv.calc_objv = flag

    return res


# Vectorized call
def _vect(func: Callable[[array], array], point: array, shape: Tuple) -> array:
    """Call func with point reshaped as shape and return vectorized output"""
    return np.reshape(func(np.reshape(point, shape)), (-1, 1))


# Vectorized gradient
def _gradient(
    objv_list: Sequence["BaseObjective"], point: array, shape: Tuple
) -> array:
    """Return the sum of the objectives' gradients as a column vector.

    Parameters
    ----------
    objv_list : sequence of BaseObjective
        The (implicitly summed) objectives. Must not be empty.
    point : array
        The vectorized point where gradients are evaluated.
    shape : tuple
        The shape expected by the objectives' `gradient` methods.

    Raises
    ------
    ValueError
        If `objv_list` is empty.
    """
    if len(objv_list) == 0:
        raise ValueError("objv_list must contain at least one objective")
    gradients = (_vect(objv.gradient, point, shape) for objv in objv_list)
    # Starting from 0.0 makes the first addition allocate the accumulator
    # (promoted to at least floating point); the following terms are then added
    # in place. Accumulating directly into the first gradient would write into
    # an array owned by the objective, which may be a view of `point` itself.
    return reduce(iadd, gradients, 0.0)


def _lastv(objv_list: Sequence["BaseObjective"]):
    """Return the value of objective computed after gradient evaluation"""
    return sum(getattr(objv, "lastv") for objv in objv_list)


class BaseObjective(abc.ABC):
    r"""Abstract interface shared by every objective term

    .. math::
        J(x) = \mu \Psi \left(V x - \omega \right)

    with :math:`\Psi(u) = \sum_i \varphi(u_i)`.

    Subclasses provide `operator`, `gradient` and `norm_mat_major`. The
    `calc_objv` flag asks `gradient` to also cache the objective value in
    `lastv`.
    """

    def __init__(self):
        # Sentinel until a gradient evaluation with `calc_objv` stores a value.
        self.calc_objv = False
        self._lastv = -1

    @property
    def lastv(self):
        """Objective value cached by the last gradient evaluation."""
        return self._lastv

    @lastv.setter
    def lastv(self, val):
        """Store the objective value computed during a gradient evaluation."""
        self._lastv = val

    @abc.abstractmethod
    def operator(self, point: array) -> array:
        """Return `V·x` for the given point."""
        return NotImplemented

    @abc.abstractmethod
    def gradient(self, point: array) -> array:
        """Return the gradient of the objective at the given point."""
        return NotImplemented

    @abc.abstractmethod
    def norm_mat_major(self, vecs: array, point: array) -> array:
        """Normal matrix of the quadratic majorant restricted to a subspace.

        For vectors `W = V·S`, where `S` spans the subspace, return
        `μ Wᵀ·diag(b)·W` with `b` the Geman & Reynolds coefficients at `point`.

        Parameters
        ----------
        vecs : array
            The `W` vectors, one per column.
        point : array
            The point at which the coefficients `b` are evaluated.

        Returns
        -------
        out : array
            The normal matrix.
        """
        return NotImplemented


class Objective(BaseObjective):
    r"""An objective function defined as

    .. math::
        J(x) = \mu \Psi \left(V x - \omega \right)

    with :math:`\Psi(u) = \sum_i \varphi(u_i)`.


    data : array
        The `data` array, or the vectorized list of array given at init.
    hyper : float
        The hyperparameter value `μ`.
    loss : Loss
        The loss `φ`.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        operator: Callable[[array], ArrOrSeq],
        adjoint: Callable[[ArrOrSeq], array],
        loss: "Loss",
        data: ArrOrSeq = None,
        hyper: float = 1,
    ):
        """A objective function `μ ψ(V·x - ω)`.

        Parameters
        ----------
        operator: callable
            A callable that compute the output `V·x`.
        adjoint: callable
            A callable that compute `Vᵀ·e`.
        loss: Loss
            The loss `φ`.
        data: array or list of array, optional
            The data vector `ω`.
        hyper: float, optional
            The hyperparameter `μ`.

        Notes
        -----
        For implementation issue, `operator` and `adjoint` are wrapped by
        methods of same name.

        If `data` is a list of array, `operator` must return a similar list with
        arrays of same shape, and `adjoint` must accept a similar list also.

        In that case, however, and for algorithm purpose, everything is
        internally stacked as a column vector and values are therefore copied.
        This is not efficient but flexible. Users are encouraged to do the
        vectorization themselves and not use the list of array feature.
        """
        super().__init__()
        self._operator = operator
        self._adjoint = adjoint

        if isinstance(data, list):
            # Shapes and slice bounds of each array in the stacked column, used
            # by `_vec2list` to undo the stacking.
            self._shape = [arr.shape for arr in data]
            self._idx = np.cumsum([0] + [arr.size for arr in data])
            self.data = self._list2vec(data)
        else:
            self.data = 0 if data is None else data

        self.hyper = hyper
        self.loss = loss

    @staticmethod
    def _list2vec(arr_list: Sequence[array]) -> array:
        """Vectorize a list of array into one column."""
        return np.vstack([arr.reshape((-1, 1)) for arr in arr_list])

    def _vec2list(self, arr: array) -> List[array]:
        """De-vectorize a column into a list of array with the data shapes."""
        return [
            np.reshape(arr[self._idx[i] : self._idx[i + 1]], shape)
            for i, shape in enumerate(self._shape)
        ]

    def operator(self, point: array) -> array:
        """Return `V·x`, stacked as a column if `data` was a list."""
        if hasattr(self, "_shape"):
            return self._list2vec(self._operator(point))
        return self._operator(point)

    def adjoint(self, point: array) -> array:
        """Return `Vᵀ·x`, splitting a stacked column if `data` was a list."""
        if hasattr(self, "_shape"):
            return self._adjoint(self._vec2list(point))
        return self._adjoint(point)

    def value(self, point: array) -> float:
        """The value of the objective function at given point

        Return `μ ψ(V·x - ω)`.
        """
        return self.hyper * np.sum(self.loss(self.operator(point) - self.data))

    def gradient(self, point: array) -> array:
        """The gradient and value at given point

        Return `μ Vᵀ·φ'(V·x - ω)`. If `calc_objv` is set, the value
        `μ ψ(V·x - ω)` is also stored in `lastv`.
        """
        residual = self.operator(point) - self.data
        if self.calc_objv:
            self.lastv = self.hyper * np.sum(self.loss(residual))
        return self.hyper * self.adjoint(self.loss.gradient(residual))

    def norm_mat_major(self, vecs: array, point: array) -> array:
        """Return the normal matrix `μ Wᵀ·diag(b)·W` of the quadratic majorant.

        The hyperparameter `μ` scales the curvature exactly as it scales the
        gradient, which the closed-form MM steps rely on.

        Parameters
        ----------
        vecs : array
            The `W = V·S` vectors, one per column.
        point : array
            The point where the Geman & Reynolds coefficients `b` are computed.

        Returns
        -------
        out : float or array
            A Python scalar for a single direction, the matrix otherwise.
        """
        coeffs = np.reshape(self.gr_coeffs(point), (-1, 1))
        matrix = self.hyper * (vecs.T @ (coeffs * vecs))
        # `.item()` replaces `float(matrix)`, deprecated for arrays with ndim > 0.
        return matrix.item() if matrix.size == 1 else matrix

    def gr_coeffs(self, point: array) -> array:
        """The Geman & Reynolds coefficients at given point

        Given `x` return `φ'(V·x - ω) / (V·x - ω)`
        """
        obj = self.operator(point) - self.data
        return self.loss.gr_coeffs(obj)

    def __call__(self, point: array) -> float:
        return self.value(point)


class QuadObjective(Objective):
    r"""A quadratic objective function

    .. math::
        :nowrap:

        \begin{aligned}
        J(x) & = \frac{1}{2} \mu \|V x - \omega\|_B^2 \\
             & = \frac{1}{2} \mu (V x - \omega)^tB(V x - \omega) \\
        \end{aligned}

    data : array
        The `data` array, or the vectorized list of array given at init.
    hyper : float
        The hyperparameter value `μ`.
    data_t : array
        The retroprojected data `b = μ Vᵀ·B·ω`.
    constant : float
        The constant `c = μ ωᵀ·B·ω`.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        operator: Callable[[array], ArrOrSeq],
        adjoint: Callable[[ArrOrSeq], array],
        hessp: Callable[[array], array] = None,
        data: array = None,
        hyper: float = 1,
        metric: array = None,
    ):
        """A quadratic objective `½ μ ||V·x - ω||²_B`

        Parameters
        ----------
        operator: callable
            A callable that compute the output `V·x`.
        adjoint: callable
            A callable that compute `Vᵀ·e`.
        hessp: callable, optional
            A callable that compute `Q·x` as `Q·x = Vᵀ·B·V·x`
        data: array or list of array, optional
            The data vector `ω`.
        hyper: float, optional
            The hyperparameter `μ`.
        metric: array, optional
            The **diagonal** of the metric matrix `B`, in the data space (same
            shape as `V·x`). Equivalent to Identity if not provided.

        Notes
        -----
        The `hessp` (`Q`) callable is used for gradient computation as `∇ = μ
        (Q·x - Vᵀ·B·ω)` instead of `∇ = μ Vᵀ·B·(V·x - ω)`. This is optional and
        in some case this is more efficient.

        The variable `b = μ Vᵀ·B·ω` (`data_t`) is computed at object creation.

        `hyper` is baked into `hessp`, `data_t` and `constant` at creation.
        """
        super().__init__(operator, adjoint, Square(), data=data, hyper=hyper)
        self._metric = metric

        if hessp is not None:
            self.hessp = lambda x: hyper * hessp(x)
        else:
            self.hessp = lambda x: hyper * adjoint(self._metricp(operator(x)))

        if data is None:
            self.data_t = 0
            self.constant = 0  # c = μ ωᵀ·B·ω
        else:
            # b = μ Vᵀ·B·ω: the metric acts in the data space, before the
            # adjoint, exactly as in `hessp`. `self.data` is the (possibly
            # stacked) data and `self.adjoint` handles the list case.
            self.data_t = hyper * self.adjoint(self._metricp(self.data))
            self.constant = hyper * np.sum(self.data * self._metricp(self.data))

    def _metricp(self, arr: array) -> array:
        """Return `B·arr` (element-wise product with the metric diagonal)."""
        if self._metric is None:
            return arr
        return self._metric * arr

    def value(self, point: array) -> float:
        """The value of the objective function at given point

        Return `½ μ ||V·x - ω||²_B`.
        """
        return (
            self.hyper
            * np.sum(self._metricp((self.operator(point) - self.data) ** 2))
            / 2
        )

    def gradient(self, point: array) -> array:
        """The gradient and value at given point

        Return `∇ = Q·x - b = μ Vᵀ·B·(V·x - ω)`.

        Notes
        -----
        Objective value is computed with low overhead thanks to the relation

        `J(x) = ½ (xᵀ·Q·x - 2 xᵀ·b + μ ωᵀ·B·ω)`
        """
        hess_point = self.hessp(point)  # evaluated once, reused for the value
        if self.calc_objv:
            self.lastv = self._value_hessp(point, hess_point)
        return hess_point - self.data_t

    def _value_hessp(self, point, hessp):
        """Return `J(x)` value given `q = Qx`

        thanks to relation

        `J(x) =  ½ (xᵀ·q - 2 xᵀ·b + μ ωᵀ·B·ω)`"""
        return (np.sum(point * (hessp - 2 * self.data_t)) + self.constant) / 2

    def value_residual(self, point, residual):
        """Return `J(x)` value given `x` and `r = b - Qx`

        thanks to relation

        `J(x) =  ½ (xᵀ·(-b - r) + μ ωᵀ·B·ω)`"""
        return (np.sum(point * (-self.data_t - residual)) + self.constant) / 2

    def norm_mat_major(self, vecs: array, point: array) -> array:
        """Return `μ Wᵀ·B·W`, the exact curvature along the `W = V·S` vectors.

        `point` is unused since the objective is quadratic.
        """
        if self._metric is None:
            return self.hyper * (vecs.T @ vecs)
        return self.hyper * (vecs.T @ (np.reshape(self._metric, (-1, 1)) * vecs))

    def gr_coeffs(self, point: array) -> array:
        """Return 1."""
        return 1

    def __call__(self, point: array) -> float:
        return self.value(point)


class Vmin(BaseObjective):
    r"""A minimum value objective function

    .. math::

        J(x) = \frac{1}{2} \mu \|P_{]-\infty, m]}(x) - m\|_2^2.

    vmin : float
        The minimum value `m`.
    hyper : float
        The hyperparameter value `μ`.
    """

    def __init__(self, vmin: float, hyper: float):
        """A minimum value objective function

        `J(x) = ½ μ ||P_]-∞, m](x) - m||²`.

        Parameters
        ----------
        vmin : float
            The minimum value `m`.
        hyper : float
            The hyperparameter value `μ`.
        """
        super().__init__()
        self.vmin = vmin
        self.hyper = hyper

    def operator(self, point):
        """Return `x`: the penalty acts directly on the unknown (`V = I`).

        The MM solvers apply `operator` to search directions to build the
        majorant curvature. `½ min(u - m, 0)²` has a 1-Lipschitz derivative,
        so `μ Sᵀ·S` is a valid majorant curvature. Masking the direction with
        `<= vmin` would instead mix a value threshold with a direction and drop
        curvature.
        """
        return point

    def value(self, point: array) -> array:
        """Return the value at current point."""
        return self.hyper * np.sum((point[point <= self.vmin] - self.vmin) ** 2) / 2

    def gradient(self, point: array) -> array:
        """Return `μ (x - m)` where `x <= m` and 0 elsewhere.

        If `calc_objv` is set, the value is also stored in `lastv`.
        """
        idx = point <= self.vmin
        if self.calc_objv:
            self.lastv = self.hyper * np.sum((point[idx] - self.vmin) ** 2) / 2
        return self.hyper * np.where(idx, point - self.vmin, 0)

    def norm_mat_major(self, vecs: array, point: array) -> array:
        """Return `μ Wᵀ·W`, the curvature of the majorant (scaled by `μ` like the gradient)."""
        return self.hyper * (vecs.T @ vecs)


class Vmax(BaseObjective):
    r"""A maximum value objective function

    .. math::

        J(x) = \frac{1}{2} \mu \|P_{[M, +\infty[}(x) - M\|_2^2.

    vmax : float
        The maximum value `M`.
    hyper : float
        The hyperparameter value `μ`.
    """

    def __init__(self, vmax: float, hyper: float):
        """A maximum value objective function

        Return `J(x) = ½ μ ||P_[M, +∞[(x) - M||²`.

        Parameters
        ----------
        vmax : float
            The maximum value `M`.
        hyper : float
            The hyperparameter value `μ`.
        """
        super().__init__()
        self.vmax = vmax
        self.hyper = hyper

    def operator(self, point):
        """Return `x`: the penalty acts directly on the unknown (`V = I`).

        The MM solvers apply `operator` to search directions to build the
        majorant curvature. `½ max(u - M, 0)²` has a 1-Lipschitz derivative,
        so `μ Sᵀ·S` is a valid majorant curvature. Masking the direction with
        `>= vmax` would instead mix a value threshold with a direction and drop
        curvature.
        """
        return point

    def value(self, point: array) -> array:
        """Return the value at current point."""
        return self.hyper * np.sum((point[point >= self.vmax] - self.vmax) ** 2) / 2

    def gradient(self, point: array) -> array:
        """Return `μ (x - M)` where `x >= M` and 0 elsewhere.

        If `calc_objv` is set, the value is also stored in `lastv`.
        """
        idx = point >= self.vmax
        if self.calc_objv:
            self.lastv = self.hyper * np.sum((point[idx] - self.vmax) ** 2) / 2
        return self.hyper * np.where(idx, point - self.vmax, 0)

    def norm_mat_major(self, vecs: array, point: array) -> array:
        """Return `μ Wᵀ·W`, the curvature of the majorant (scaled by `μ` like the gradient)."""
        return self.hyper * (vecs.T @ vecs)


class Loss(abc.ABC):
    """An abstract base class for loss `φ`.

    The class has the following attributes.

    inf : float
      The value of `lim_{u→0} φ'(u) / u`.
    convex : boolean
      A flag indicating if the loss is convex (not used).
    coercive : boolean
      A flag indicating if the loss is coercive (not used).
    """

    def __init__(self, inf: float, convex: bool = False, coercive: bool = False):
        """The loss φ

        Parameters
        ----------
        inf : float
          The value of `lim_{u→0} φ'(u) / u`.
        convex : boolean
          A flag indicating if the loss is convex.
        coercive : boolean
          A flag indicating if the loss is coercive.
        """
        self.inf = inf
        self.convex = convex
        self.coercive = coercive

    @abc.abstractmethod
    def value(self, point: array) -> array:
        """The value `φ(·)` at given point."""
        return NotImplemented

    @abc.abstractmethod
    def gradient(self, point: array) -> array:
        """The gradient `φ'(·)` at given point."""
        return NotImplemented

    def gr_coeffs(self, point: array) -> array:
        """The Geman & Reynolds `φ'(·)/·` coefficients at given point.

        Where `point` is 0, the limit `inf` is used instead of the ratio.
        """
        # The buffer is at least floating point: with an integer `point` and an
        # integer `inf`, an integer buffer would truncate the ratios below.
        aux = np.full_like(point, self.inf, dtype=np.result_type(point, 1.0))
        idx = point != 0
        aux[idx] = self.gradient(point[idx]) / point[idx]
        return aux

    def __call__(self, point: array) -> array:
        """The value at given point."""
        return self.value(point)


class Square(Loss):
    r"""The convex and coercive square loss

    .. math::

       \varphi(u) = \frac{1}{2} u^2.

    Its Geman & Reynolds coefficient is 1 everywhere.
    """

    def __init__(self):
        """Build the square loss `φ(u) = ½ u²`."""
        super().__init__(inf=1, convex=True, coercive=True)

    def value(self, point: array) -> array:
        """Return `½ u²` element-wise."""
        squared = point ** 2
        return squared / 2

    def gradient(self, point: array) -> array:
        """Return `φ'(u) = u` (the input itself, not a copy)."""
        return point

    def __repr__(self):
        return """φ(u) = ½ u²
"""


class Huber(Loss):
    r"""The convex coercive Huber loss

    .. math::

       \varphi(u) =
       \begin{cases}
          \frac{1}{2} u^2 & \text{, if } |u| \leq \delta, \\
          \delta |u| - \frac{\delta^2}{2} & \text{, otherwise.}
       \end{cases}

    """

    def __init__(self, delta: float):
        """The Huber loss.

        Parameters
        ----------
        delta : float
            The threshold `δ` between the quadratic and the linear regimes.
        """
        super().__init__(inf=1, convex=True, coercive=True)
        self.delta = delta

    def value(self, point: array) -> array:
        """Return `½ u²` where `|u| <= δ` and `δ |u| - δ²/2` elsewhere."""
        return np.where(
            np.abs(point) <= self.delta,
            point ** 2 / 2,
            self.delta * (np.abs(point) - self.delta / 2),
        )

    def gradient(self, point: array) -> array:
        """Return `u` where `|u| <= δ` and `δ sign(u)` elsewhere."""
        return np.where(np.abs(point) <= self.delta, point, self.delta * np.sign(point))

    def __repr__(self):
        return f"""{type(self)}

       ⎛
       ⎜ ½ u²        , if |u| < δ
φ(u) = ⎜
       ⎜ δ|u| - δ²/2 , otherwise.
       ⎝

with δ = {self.delta}
"""


class Hyperbolic(Loss):
    r"""The convex and coercive hyperbolic loss, also known as Pseudo-Huber

    .. math::

       \varphi(u) = \delta^2 \left( \sqrt{1 + \frac{u^2}{\delta^2}} -1 \right)

    It behaves like `½ u²` near 0 and like `δ |u|` for large `|u|`.
    """

    def __init__(self, delta: float):
        """Build the hyperbolic loss with scale `δ`."""
        super().__init__(inf=1, convex=True, coercive=True)
        self.delta = delta

    def value(self, point: array) -> array:
        """Return `δ² (sqrt(1 + u²/δ²) - 1)` element-wise."""
        delta_sq = self.delta ** 2
        root = np.sqrt(1 + (point ** 2) / delta_sq)
        return delta_sq * (root - 1)

    def gradient(self, point: array) -> array:
        """Return `u / sqrt(1 + u²/δ²)` element-wise."""
        root = np.sqrt(1 + (point ** 2) / self.delta ** 2)
        return point / root

    def __repr__(self):
        return f"""{type(self)}
               _______
          ⎛   ╱     u²     ⎞
φ(u) = δ²⋅⎜  ╱  1 + ──  - 1⎟
          ⎝╲╱       δ²     ⎠


with δ = {self.delta}
"""


class HebertLeahy(Loss):
    r"""The non-convex but coercive Hebert & Leahy loss

    .. math::

       \varphi(u) = \log \left(1 + \frac{u^2}{\delta^2} \right)

    Its Geman & Reynolds coefficient tends to `2 / δ²` at 0.
    """

    def __init__(self, delta: float):
        """Build the Hebert & Leahy loss with scale `δ`."""
        super().__init__(inf=2 / delta ** 2, convex=False, coercive=True)
        self.delta = delta

    def value(self, point: array) -> array:
        """Return `log(1 + u²/δ²)` element-wise."""
        ratio = point ** 2 / self.delta ** 2
        return np.log(1 + ratio)

    def gradient(self, point: array) -> array:
        """Return `2u / (δ² + u²)` element-wise."""
        denominator = self.delta ** 2 + point ** 2
        return 2 * point / denominator

    def __repr__(self):
        return f"""{type(self)}

          ⎛    u²⎞
φ(u) = log⎜1 + ──⎟
          ⎝    δ²⎠

with δ = {self.delta}
"""


class GemanMcClure(Loss):
    r"""The non-convex and non-coercive Geman & Mc Clure loss

    .. math::

       \varphi(u) = \frac{u^2}{2\delta^2 + u^2}

    The loss is bounded by 1; its Geman & Reynolds coefficient tends to
    `1 / δ²` at 0.
    """

    def __init__(self, delta: float):
        r"""Build the Geman & Mc Clure loss with scale `δ`."""
        super().__init__(1 / delta ** 2, convex=False, coercive=False)
        self.delta = delta

    def value(self, point: array) -> array:
        """Return `u² / (2δ² + u²)` element-wise."""
        denominator = 2 * self.delta ** 2 + point ** 2
        return point ** 2 / denominator

    def gradient(self, point: array) -> array:
        """Return `4 u δ² / (2δ² + u²)²` element-wise."""
        denominator = 2 * self.delta ** 2 + point ** 2
        return 4 * point * self.delta ** 2 / denominator ** 2

    def __repr__(self):
        return f"""{type(self)}

          u²
φ(u) = ─────────
       u² + 2⋅δ²

with δ = {self.delta}
"""


class TruncSquareApprox(Loss):
    r"""The non-convex and non-coercive truncated square approximation

    .. math::

       \varphi(u) = 1 - \exp \left(- \frac{u^2}{2\delta^2} \right)

    The loss is bounded by 1; its Geman & Reynolds coefficient tends to
    `1 / δ²` at 0.
    """

    def __init__(self, delta: array):
        """Build the truncated square approximation with scale `δ`."""
        super().__init__(inf=1 / (delta ** 2), convex=False, coercive=False)
        self.delta = delta

    def value(self, point: array) -> array:
        """Return `1 - exp(-u² / (2δ²))` element-wise."""
        gaussian = np.exp(-(point ** 2) / (2 * self.delta ** 2))
        return 1 - gaussian

    def gradient(self, point: array) -> array:
        """Return `u / δ² · exp(-u² / (2δ²))` element-wise."""
        gaussian = np.exp(-(point ** 2) / (2 * self.delta ** 2))
        return point / (self.delta ** 2) * gaussian

    def __repr__(self):
        return f"""{type(self)}

               u²
            - ────
              2⋅δ²
φ(u) = 1 - e

with δ = {self.delta}
"""


### Local Variables:
### ispell-local-dictionary: "english"
### End:
