#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Pipe whose radius follows a Bessel horn equation.
"""

# The description string above must be the first statement of the module so
# that Python binds it to ``__doc__``; placed after the metadata it was a
# no-op expression.
__author__ = ["Juliette Chabassier", "Augustin Ernoult", "Olivier Geber",
              "Alexis Thibault", "Tobias Van Baarsel"]
__copyright__ = "Copyright 2020, Inria"
__credits__ = ["Juliette Chabassier", "Augustin Ernoult", "Olivier Geber",
               "Alexis Thibault", "Tobias Van Baarsel"]
__license__ = "GPL 3.0"
__version__ = "0.4"
__email__ = "openwind-contact@inria.fr"
__status__ = "Dev"


from openwind.design import DesignShape, eval_, diff_
import numpy as np

def bessel(x, x1, x2, r1, r2, alpha):
    r"""Evaluate the radius of a Bessel horn between two points.

    The radius of the pipe follows the equation

    .. math::
        \begin{eqnarray}
        r(x) & = & r_1 \left( \frac{x_1 - x_p}{x - x_p} \right)^{\alpha} \\
        x_p & = & \frac{x_1 - R x_2}{1 - R} \\
        R & = & \left( \frac{r_2}{r_1} \right)^{1/\alpha}
        \end{eqnarray}

    with:

    - \(x_1, x_2\): the endpoints positions of the pipe
    - \(r_1, r_2\): the endpoints radii of the pipe
    - \(\alpha\): the Bessel coefficient

    This choice of the pole \(x_p\) guarantees \(r(x_1) = r_1\) and
    \(r(x_2) = r_2\).

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the radius is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        the radius r(x), broadcast like ``x``

    Raises
    ------
    ValueError
        if ``r1 == r2``: then R = 1 and the pole position is undefined.
    """
    if r1 == r2:
        raise ValueError("A bessel can not have to 2 equal radii (r1!=r2)")
    rr = (r2/r1)**(1/alpha)  # radius ratio R
    xp = (x1 - rr*x2)/(1 - rr)  # abscissa of the pole x_p
    return r1 * ((x1 - xp) / (x - xp))**alpha


def dbessel_dx(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to the position x.

    .. math::
        \frac{\partial r(x)}{\partial x} = - \frac{\alpha}{x - x_p} \, r(x)

    where \(r\), \(x_p\) and \(R\) are defined as in :func:`bessel`.
    Equal radii (``r1 == r2``) are not checked here; they make
    \(1 - R\) vanish in the pole computation.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dx evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    xp = (x1 - rr*x2)/(1 - rr)
    return -alpha/(x-xp) * r1 * ((x1 - xp) / (x - xp))**alpha


def dbessel_dxp(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to the pole x_p.

    The partial derivative treats \(x_p\) as an independent variable
    (x, x_1, r_1 and alpha held fixed):

    .. math::
        \frac{\partial r(x)}{\partial x_p} = \alpha \, r_1
        \frac{x_1 - x}{(x - x_p)^2}
        \left( \frac{x_1 - x_p}{x - x_p} \right)^{\alpha - 1}

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dx_p evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    xp = (x1 - rr*x2)/(1 - rr)
    dy_dxp = r1*alpha*(x1 - x)/(x - xp)**2 * ((x1 - xp)/(x - xp))**(alpha - 1)
    return dy_dxp


def dbessel_drr(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to the radius ratio R.

    R only enters r(x) through the pole \(x_p\), hence

    .. math::
        \frac{\partial r(x)}{\partial R} = \frac{\partial r(x)}{\partial x_p}
        \frac{\partial x_p}{\partial R},
        \qquad
        \frac{\partial x_p}{\partial R} = \frac{x_1 - x_2}{(1 - R)^2}

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dR evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    dxp_drr = (x1 - x2) / (1 - rr)**2
    dy_dxp = dbessel_dxp(x, x1, x2, r1, r2, alpha)
    return dy_dxp * dxp_drr


def dbessel_dr1(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to r_1.

    r_1 appears both as a prefactor of r(x) and inside the ratio R:

    .. math::
        \frac{d r(x)}{d r_1} = \frac{\partial r(x)}{\partial R}
        \frac{\partial R}{\partial r_1}
        + \left( \frac{x_1 - x_p}{x - x_p} \right)^{\alpha},
        \qquad
        \frac{\partial R}{\partial r_1} = - \frac{R}{\alpha \, r_1}

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dr_1 evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    xp = (x1 - rr*x2)/(1 - rr)

    drr_dr1 = -1/(alpha*r1) * rr
    dy_drr = dbessel_drr(x, x1, x2, r1, r2, alpha)
    # Explicit dependence through the r_1 prefactor (x_p held fixed).
    dy_dr1 = ((x1 - xp) / (x - xp))**alpha
    return dy_drr * drr_dr1 + dy_dr1


def dbessel_dr2(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to r_2.

    r_2 only enters r(x) through the ratio R:

    .. math::
        \frac{\partial r(x)}{\partial r_2} = \frac{\partial r(x)}{\partial R}
        \frac{\partial R}{\partial r_2},
        \qquad
        \frac{\partial R}{\partial r_2} = \frac{R}{\alpha \, r_2}

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dr_2 evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    drr_dr2 = 1/(alpha*r2) * rr
    dy_drr = dbessel_drr(x, x1, x2, r1, r2, alpha)
    return dy_drr * drr_dr2


def dbessel_dx1(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to x_1.

    x_1 appears both explicitly in r(x) and inside the pole x_p:

    .. math::
        \frac{d r(x)}{d x_1} = \frac{\partial r(x)}{\partial x_p}
        \frac{\partial x_p}{\partial x_1}
        + \frac{\alpha \, r(x)}{x_1 - x_p},
        \qquad
        \frac{\partial x_p}{\partial x_1} = \frac{1}{1 - R}

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dx_1 evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    xp = (x1 - rr*x2)/(1 - rr)
    dxp_dx1 = 1/(1 - rr)
    dy_dxp = dbessel_dxp(x, x1, x2, r1, r2, alpha)
    # Explicit dependence through the (x1 - xp) numerator (x_p held fixed).
    dy_dx1 = alpha*r1/(x1 - xp) * ((x1 - xp)/(x - xp))**alpha
    return dy_dxp * dxp_dx1 + dy_dx1


def dbessel_dx2(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to x_2.

    x_2 only enters r(x) through the pole x_p:

    .. math::
        \frac{\partial r(x)}{\partial x_2} = \frac{\partial r(x)}{\partial x_p}
        \frac{\partial x_p}{\partial x_2},
        \qquad
        \frac{\partial x_p}{\partial x_2} = - \frac{R}{1 - R}

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dx_2 evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    dxp_dx2 = -rr/(1 - rr)
    dy_dxp = dbessel_dxp(x, x1, x2, r1, r2, alpha)
    return dy_dxp * dxp_dx2


def dbessel_dalpha(x, x1, x2, r1, r2, alpha):
    r"""Differentiate the Bessel radius with respect to alpha.

    alpha appears both as the exponent of r(x) and inside the ratio R.
    With \(A = (x_1 - x_p)/(x - x_p)\):

    .. math::
        \frac{d r(x)}{d \alpha} = \frac{\partial r(x)}{\partial R}
        \frac{\partial R}{\partial \alpha} + r_1 \ln(A) A^{\alpha},
        \qquad
        \frac{\partial R}{\partial \alpha} = - \frac{\ln(r_2 / r_1)}{\alpha^2} R

    where \(x_p\) and \(R\) are defined as in :func:`bessel`.

    Parameters
    ----------
    x : float, array of float
        the position(s) at which the derivative is evaluated
    x1, r1 : float
        position and radius of the first endpoint
    x2, r2 : float
        position and radius of the second endpoint
    alpha : float
        the Bessel coefficient (exponent of the horn)

    Returns
    -------
    float or array of float
        dr/dalpha evaluated at ``x``
    """
    rr = (r2/r1)**(1/alpha)
    xp = (x1 - rr*x2)/(1 - rr)
    A = (x1 - xp)/(x - xp)

    drr_dalpha = -1/alpha**2 * np.log(r2/r1) * rr
    # Explicit dependence through the exponent (x_p held fixed).
    dy_dalpha = r1 * np.log(A) * A**alpha
    dy_drr = dbessel_drr(x, x1, x2, r1, r2, alpha)

    return dy_drr*drr_dalpha + dy_dalpha


class Bessel(DesignShape):
    r"""
    Pipe the radius of which follows a Bessel horn equation.

    The radius of the pipe follows the equation

    .. math::
        \begin{eqnarray}
        r(x) & = & r_1 \left( \frac{x_1 - x_p}{x - x_p} \right)^{\alpha} \\
        x_p & = & \frac{x_1 - R x_2}{1 - R} \\
        R & = & \left( \frac{r_2}{r_1} \right)^{1/\alpha}
        \end{eqnarray}

    with:

    - \(x_1, x_2\): the endpoints positions of the pipe
    - \(r_1, r_2\): the endpoints radii of the pipe
    - \(\alpha\): the Bessel coefficient

    Parameters
    ----------
    *params : 5 openwind.design.design_parameter.DesignParameter
        The five parameters in this order: \(x_1, x_2, r_1, r_2, \alpha\)

    Raises
    ------
    ValueError
        if the number of parameters is not 5, or if the two radius
        parameters compare equal.
    """

    def __init__(self, *params):
        if len(params) != 5:
            raise ValueError("A bessel shape need 5 parameters.")
        # NOTE(review): this compares the parameter objects, not their
        # evaluated values; unless the parameter class defines ``__eq__``
        # it only catches the same object passed twice. ``bessel()``
        # re-checks the evaluated radii at every evaluation.
        if params[2] == params[3]:
            raise ValueError("A bessel can not have to 2 equal radii (r1!=r2)")

        self.params = params

    def __str__(self):
        """Return "<x1> <x2> <r1> <r2> Bessel <alpha>" from the parameters."""
        tokens = ['{}'.format(param) for param in self.params[:-1]]
        tokens.append(type(self).__name__)
        tokens.append('{}'.format(self.params[-1]))
        return ' '.join(tokens)

    def get_radius_at(self, x_norm):
        """Radius at the normalized position(s) ``x_norm``.

        ``x_norm`` is converted to a physical position with
        ``get_position_from_xnorm``, the radius is computed with
        :func:`bessel`, and the position is passed to ``check_bounds``
        against the endpoints ``[x1, x2]``.
        """
        x1, x2, r1, r2, alpha = eval_(self.params)
        x = self.get_position_from_xnorm(x_norm)
        radius = bessel(x, x1, x2, r1, r2, alpha)
        self.check_bounds(x, [x1, x2])
        return radius

    def get_diff_radius_at(self, x_norm, diff_index):
        """Derivative of the radius at ``x_norm`` w.r.t. one design variable.

        Applies the chain rule: the dependence through the physical
        position x (``get_diff_position_from_xnorm``) plus the dependence
        through each of the five shape parameters, weighted by the
        derivatives returned by ``diff_(self.params, diff_index)``.
        Terms whose parameter derivative is zero are skipped.
        """
        x1, x2, r1, r2, alpha = eval_(self.params)
        dx1, dx2, dr1, dr2, dalpha = diff_(self.params, diff_index)
        # Derivative of the physical position with respect to the variable.
        dx_norm = self.get_diff_position_from_xnorm(x_norm, diff_index)
        x = self.get_position_from_xnorm(x_norm)
        diff_radius = dbessel_dx(x, x1, x2, r1, r2, alpha)*dx_norm
        if dx1 != 0:
            diff_radius += dx1*dbessel_dx1(x, x1, x2, r1, r2, alpha)
        if dx2 != 0:
            diff_radius += dx2*dbessel_dx2(x, x1, x2, r1, r2, alpha)
        if dr1 != 0:
            diff_radius += dr1*dbessel_dr1(x, x1, x2, r1, r2, alpha)
        if dr2 != 0:
            diff_radius += dr2*dbessel_dr2(x, x1, x2, r1, r2, alpha)
        if dalpha != 0:
            diff_radius += dalpha*dbessel_dalpha(x, x1, x2, r1, r2, alpha)
        self.check_bounds(x, [x1, x2])
        return diff_radius

    def get_endpoints_position(self):
        """Return the (x1, x2) design parameters (not evaluated)."""
        return self.params[0], self.params[1]

    def get_endpoints_radius(self):
        """Return the (r1, r2) design parameters (not evaluated)."""
        return self.params[2], self.params[3]

    def get_diff_shape_wr_x_norm(self, x_norm):
        """Derivative of the radius with respect to ``x_norm``.

        Computed as dr/dx times ``get_length()``; this assumes the physical
        position depends linearly on ``x_norm`` with slope equal to the
        length (TODO confirm against ``get_position_from_xnorm``).
        """
        x1, x2, r1, r2, alpha = eval_(self.params)
        x = self.get_position_from_xnorm(x_norm)
        self.check_bounds(x, [x1, x2])
        dradius = dbessel_dx(x, x1, x2, r1, r2, alpha)
        return dradius * self.get_length()
