#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2016-2022 Stéphane Caron and the qpsolvers contributors.
#
# This file is part of qpsolvers.
#
# qpsolvers is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# qpsolvers is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with qpsolvers. If not, see <http://www.gnu.org/licenses/>.

"""
Tests for the `solve_ls` function.
"""

import unittest
import warnings

from numpy import array, dot
from numpy.linalg import norm

from qpsolvers import available_solvers, solve_ls
from qpsolvers.exceptions import NoSolverSelected, SolverNotFound


def solve_ls_with_test_params(*args, **kwargs):
    """
    Call ``solve_ls`` with additional solver parameters.

    All arguments are forwarded to ``solve_ls``. When the ``solver`` keyword
    argument is ``"proxqp"``, the ``eps_abs`` keyword argument is set to
    ``1e-9``, replacing any value passed by the caller.

    Parameters
    ----------
    *args :
        Positional arguments forwarded to ``solve_ls``.
    **kwargs :
        Keyword arguments forwarded to ``solve_ls``.

    Returns
    -------
    :
        Value returned by ``solve_ls`` for these arguments.
    """
    # Use ``get`` so that a call without a ``solver`` keyword is forwarded to
    # ``solve_ls`` rather than failing here with a KeyError.
    if kwargs.get("solver") == "proxqp":
        kwargs["eps_abs"] = 1e-9
    return solve_ls(*args, **kwargs)


class TestSolveLS(unittest.TestCase):

    """
    Test fixture for the README example problem.

    The problem data is created in ``setUp`` and stored on the instance as
    ``R``, ``s``, ``G``, ``h``, ``A`` and ``b``. Per-solver test methods are
    produced by ``get_test``.
    """

    def setUp(self):
        """
        Prepare test fixture.

        Silences deprecation and user warnings, then defines the problem
        matrices and vectors.
        """
        warnings.simplefilter("ignore", category=DeprecationWarning)
        warnings.simplefilter("ignore", category=UserWarning)
        self.R = array([[1.0, 2.0, 0.0], [2.0, 3.0, 4.0], [0.0, 4.0, 1.0]])
        self.s = array([3.0, 2.0, 3.0])
        self.G = array([[1.0, 2.0, 1.0], [2.0, 0.0, 1.0], [-1.0, 2.0, -1.0]])
        self.h = array([3.0, 2.0, -2.0])
        self.A = array([1.0, 1.0, 1.0])
        self.b = array([1.0])

    def get_problem(self):
        """
        Get problem as a sextuple of values to unpack.

        Returns
        -------
        R :
            Least-squares matrix.
        s :
            Least-squares vector.
        G :
            Linear inequality matrix.
        h :
            Linear inequality vector.
        A :
            Linear equality matrix.
        b :
            Linear equality vector.
        """
        return self.R, self.s, self.G, self.h, self.A, self.b

    @staticmethod
    def get_test(solver: str):
        """
        Get test function for a given solver.

        The returned function solves the fixture problem twice (with and
        without ``sym_proj=True``), compares both results to the known
        solution and checks the constraints on the first result.

        Parameters
        ----------
        solver :
            Name of the solver to test.

        Returns
        -------
        :
            Test function for that solver.
        """

        def test(self):
            R, s, G, h, A, b = self.get_problem()
            x = solve_ls_with_test_params(R, s, G, h, A, b, solver=solver)
            x_sp = solve_ls_with_test_params(
                R, s, G, h, A, b, solver=solver, sym_proj=True
            )
            self.assertIsNotNone(x)
            self.assertIsNotNone(x_sp)
            known_solution = array([2.0 / 3, -1.0 / 3, 2.0 / 3])

            # Solver-specific tolerances, with a default for other solvers
            sol_tolerance = {"osqp": 5e-3, "ecos": 1e-5}.get(solver, 1e-6)
            eq_tolerance = 1e-9
            ineq_tolerance = {"osqp": 1e-3, "scs": 2e-7}.get(solver, 1e-9)

            self.assertLess(norm(x - known_solution), sol_tolerance)
            self.assertLess(norm(x_sp - known_solution), sol_tolerance)
            self.assertLess(max(dot(G, x) - h), ineq_tolerance)

            # The equality residual has to be bounded on both sides: checking
            # only ``< eq_tolerance`` would accept any negative residual.
            # NOTE(review): solvers with looser default precision may need
            # their own entry for ``eq_tolerance`` -- confirm on CI.
            eq_residual = dot(A, x) - b
            self.assertLess(max(eq_residual), eq_tolerance)
            self.assertGreater(min(eq_residual), -eq_tolerance)

        return test

    def test_no_solver_selected(self):
        """
        Check that NoSolverSelected is raised when applicable.
        """
        R, s, G, h, A, b = self.get_problem()
        with self.assertRaises(NoSolverSelected):
            solve_ls_with_test_params(R, s, G, h, A, b, solver=None)

    def test_solver_not_found(self):
        """
        Check that SolverNotFound is raised when the solver does not exist.
        """
        R, s, G, h, A, b = self.get_problem()
        with self.assertRaises(SolverNotFound):
            solve_ls_with_test_params(R, s, G, h, A, b, solver="ideal")


# Attach one ``test_<name>`` method to the fixture for every entry of
# ``available_solvers``
for solver in available_solvers:
    setattr(TestSolveLS, "test_%s" % (solver,), TestSolveLS.get_test(solver))
