#! /usr/bin/env python
# S. Morisaki, based on his
# https://dcc.ligo.org/LIGO-T2100485

import argparse
import numpy as np
from scipy import integrate
from scipy.special import erfcinv, erf, erfcx, i0e
import warnings
import functools

import RIFT.likelihood.factored_likelihood as factored_likelihood


# Command-line interface.  --d-min / --d-max define the support of the distance
# prior, so they change the tabulated values themselves, not just the grid.
parser = argparse.ArgumentParser()
parser.add_argument("--d-min", default=1, type=float, help="Minimum distance in volume integral. Used to SET THE PRIOR; changing this value changes the numerical answer.")
parser.add_argument("--d-max", default=10000, type=float, help="Maximum distance in volume integral. Used to SET THE PRIOR; changing this value changes the numerical answer.")
parser.add_argument("--max-snr", default=1e4, type=float, help="Maximum SNR at a reference distance.")
parser.add_argument("--likelihood-threshold", default=1e-30, type=float, help="Likelihood threshold to determine distance integration range. Distance range where likelihood is smaller than its maximum multiplied by the threshold is discarded.")
parser.add_argument("--hermgauss-degree", default=50, type=int, help="Degree of Gauss-Hermite quadrature used for high SNR.")
parser.add_argument("--laggauss-degree", default=50, type=int, help="Degree of Gauss-Laguerre quadrature used for high SNR.")
parser.add_argument("--out", default="distance_marginalization_lookup.npz", help="Output file (format should be .npz)")
parser.add_argument("--phase-marginalization", default=False, action="store_true", help="Analytical phase marginalization is used if True. Applicable only for 2-2-mode model.")
# The help must name the exact string compared against further down
# ('pseudo_cosmo', with an underscore); advertising a different spelling makes
# the script silently fall back to the Euclidean prior.
parser.add_argument("--d-prior", default='Euclidean', type=str, help="Distance prior for dL.  Options are 'Euclidean' (dL^2, the default) and 'pseudo_cosmo'.")
opts = parser.parse_args()

# Support of the distance prior and the reference distance used to define the
# dimensionless integration variable.
dmin, dmax = opts.d_min, opts.d_max
dref = factored_likelihood.distMpcRef
# x = Dref / D, hence the largest distance gives the smallest x.
xmin, xmax = dref / dmax, dref / dmin
# Tabulated range of x0: it starts at zero when phase marginalization is
# requested and is symmetric about zero otherwise.
x0max = 10. * xmax
x0min = 0. if opts.phase_marginalization else -10. * xmax
# Largest tabulated b and the scale entering the b <-> t mapping.
bmax = opts.max_snr**2
bref = 1. / ((xmax - xmin) * max(xmin + xmax - 2. * x0min, 2. * x0max - xmin - xmax))
log_likelihood_threshold = np.log(opts.likelihood_threshold)
# Target grid spacings in the transformed coordinates s (for x0) and t (for b).
delta_s = 0.1
delta_t = 0.1


def prior(d):
    """Euclidean (distance-squared) prior density at distance ``d``.

    Normalized to unit integral over ``[dmin, dmax]``.  This is the default;
    the module-level name ``prior`` is rebound when the pseudo-cosmological
    prior is selected on the command line.
    """
    volume_norm = dmax**3 - dmin**3
    return 3 * d**2 / volume_norm

# Optionally replace the Euclidean prior by the pseudo-cosmological one.
# The --d-prior help has historically advertised the hyphenated spelling
# 'pseudo-cosmo' while only 'pseudo_cosmo' was matched, so following the help
# silently produced a Euclidean table.  Accept both spellings.
if opts.d_prior.replace('-', '_') == 'pseudo_cosmo':
    # exactly as in ILE itself
    import RIFT.likelihood.priors_utils as priors_utils
    # Normalization constant of the prior over [d_min, d_max].
    nm = priors_utils.dist_prior_pseudo_cosmo_eval_norm(opts.d_min, opts.d_max)
    # xpy selects the array module; numpy is used here (doesn't matter here!).
    dist_prior_pdf = functools.partial(priors_utils.dist_prior_pseudo_cosmo, nm=nm, xpy=np)

    # assign
    prior = dist_prior_pdf


def effective_prior(x, a):
    """Prior density expressed in x = dref / d.

    The distance prior is evaluated at d = dref / x and multiplied by the
    Jacobian dref / x**2.  When phase marginalization is enabled, the result
    is additionally multiplied by the exponentially scaled Bessel function
    ``i0e(a * x)``.
    """
    density = dref / x**2 * prior(dref / x)
    if not opts.phase_marginalization:
        return density
    return density * i0e(a * x)


def exponent(x, x0, b):
    """Return b / 2 * (x0**2 - (x - x0)**2), the exponent of the integrand."""
    offset = x - x0
    return b / 2. * (x0**2. - offset**2.)


def get_max_exponent(x0, b):
    """Return the exponent evaluated at x0 clipped into [xmin, xmax]."""
    x_peak = np.clip(x0, xmin, xmax)
    return exponent(x_peak, x0, b)


def get_integration_range(x0, b):
    """Return the (lower, upper) limits in x used for numerical integration.

    For b > 0 the limits bracket the region where the exponent stays above
    its maximum plus ``log_likelihood_threshold``, clipped to [xmin, xmax];
    the part of the range below that threshold is discarded.  Otherwise the
    full range (xmin, xmax) is returned.
    """
    if not b > 0.:
        return xmin, xmax
    peak = get_max_exponent(x0, b)
    half_width = np.sqrt(x0**2 - 2 / b * (peak + log_likelihood_threshold))
    lower = max(x0 - half_width, xmin)
    upper = min(x0 + half_width, xmax)
    return lower, upper


# Nodes and weights of the Gauss-Hermite and Gauss-Laguerre rules, computed
# once at the degrees requested on the command line.
hermgauss_samples, hermgauss_weights = np.polynomial.hermite.hermgauss(
    deg=opts.hermgauss_degree
)
laggauss_samples, laggauss_weights = np.polynomial.laguerre.laggauss(
    deg=opts.laggauss_degree
)


@np.vectorize
def lnI(x0, b):
    """Calculate logarithm of marginalized likelihood subtracted by the max
    exponent.

    The function is wrapped in ``np.vectorize`` and is therefore evaluated
    element by element on (x0, b).  One of four schemes is chosen:

    * xmin < x0 < xmax, b * x0**2 > 10000 and every Gauss-Hermite node
      x0 + u * sqrt(2 / b) inside (xmin, xmax): Gauss-Hermite quadrature.
    * x0 < xmin, b * (xmin - x0) * xmin > 100 and the largest Gauss-Laguerre
      node mapped below xmax: Gauss-Laguerre quadrature.
    * x0 > xmax, b * (x0 - xmax) * xmax > 100 and the largest Gauss-Laguerre
      node mapped above xmin: Gauss-Laguerre quadrature (mirrored).
    * otherwise: ``scipy.integrate.quad`` over the range returned by
      ``get_integration_range``.

    The conditions rely on short-circuit evaluation: ``2 / b`` is only
    evaluated after a test that fails for b = 0.

    Returns the natural logarithm of the computed integral.
    """
    if (
        xmin < x0 < xmax and
        b * x0**2 > 10000 and
        x0 + np.min(hermgauss_samples) * np.sqrt(2 / b) > xmin and
        x0 + np.max(hermgauss_samples) * np.sqrt(2 / b) < xmax
    ):
        # Gauss-Hermite: nodes u are mapped to x = x0 + sqrt(2 / b) * u and
        # the sum is scaled by the Jacobian sqrt(2 / b).
        result = np.sqrt(2 / b) * np.sum(
            effective_prior(np.sqrt(2 / b) * hermgauss_samples + x0, b * x0) *
            hermgauss_weights
        )
    elif (
        x0 < xmin and
        b * (xmin - x0) * xmin > 100 and
        x0 + np.sqrt(2 / b * np.max(laggauss_samples) + (xmin - x0)**2) < xmax
    ):
        # Gauss-Laguerre: tmp is the distance x - x0 corresponding to each
        # node; the prior is evaluated on the side x > x0.
        tmp = np.sqrt(2 / b * laggauss_samples + (xmin - x0)**2)
        result = np.sum(
            laggauss_weights / (b * tmp) * effective_prior(x0 + tmp, b * x0)
        )
    elif (
        x0 > xmax and
        b * (x0 - xmax) * xmax > 100 and
        x0 - np.sqrt(2 / b * np.max(laggauss_samples) + (x0 - xmax)**2) > xmin
    ):
        # Mirror image of the previous branch: the prior is evaluated on the
        # side x < x0.
        tmp = np.sqrt(2 / b * laggauss_samples + (x0 - xmax)**2)
        result = np.sum(
            laggauss_weights / (b * tmp) * effective_prior(x0 - tmp, b * x0)
        )
    else:
        # Adaptive quadrature; the integrand is scaled by exp(-max_exponent)
        # and the limits come from get_integration_range.
        max_exponent = get_max_exponent(x0, b)
        xmin_integral, xmax_integral = get_integration_range(x0, b)
        result, _ = integrate.quad(
            lambda x, x0, b: effective_prior(x, b * x0) * np.exp(exponent(x, x0, b) - max_exponent),
            xmin_integral,
            xmax_integral,
            args=(x0, b)
        )
    return np.log(result)


def x0_to_s(x0):
    """Map x0 to the grid coordinate s.

    s is the difference of two arcsinh terms, one measuring the offset of x0
    from xmin and the other from xmax, both scaled by sqrt(bmax).
    """
    root_bmax = np.sqrt(bmax)
    from_lower = np.arcsinh(root_bmax * (x0 - xmin))
    from_upper = np.arcsinh(root_bmax * (xmax - x0))
    return from_lower - from_upper


@np.vectorize
def s_to_x0(s):
    """Invert ``x0_to_s`` by bisection on [x0min, x0max].

    ``s`` must lie in [smin, smax].  The bracket is halved until its extent
    in s drops to 1e-5 * delta_s or below; the midpoint of the final bracket
    is returned.
    """
    assert smin <= s <= smax
    lo, hi = x0min, x0max
    s_lo, s_hi = x0_to_s(lo), x0_to_s(hi)
    tolerance = 1e-5 * delta_s
    # bisection search
    mid = (lo + hi) / 2.
    while s_hi - s_lo > tolerance:
        s_mid = x0_to_s(mid)
        if s_mid > s:
            hi, s_hi = mid, s_mid
        else:
            lo, s_lo = mid, s_mid
        mid = (lo + hi) / 2.
    return mid


def b_to_t(b):
    """Map b to the grid coordinate t = arcsinh(b / bref)."""
    scaled = b / bref
    return np.arcsinh(scaled)


def t_to_b(t):
    """Map the grid coordinate t back to b = bref * sinh(t)."""
    unscaled = np.sinh(t)
    return bref * unscaled


# Extent of the table in the transformed coordinates (s for x0, t for b).
smin = x0_to_s(x0min)
smax = x0_to_s(x0max)
tmax = b_to_t(bmax)

# Uniform grids in s and t; the number of points is set by the target spacing.
num_s = int((smax - smin) / delta_s)
num_t = int(tmax / delta_t)
s_array = np.linspace(smin, smax, num_s)
t_array = np.linspace(0., tmax, num_t)

# Map the grids back to x0 and b.  The end points of x0 are set to x0min and
# x0max directly; only the interior points go through the bisection.
x0_array = np.concatenate(([x0min], s_to_x0(s_array[1:-1]), [x0max]))
b_array = t_to_b(t_array)

x0_grid, b_grid = np.meshgrid(x0_array, b_array, indexing="ij")
lnI_array = np.zeros(x0_grid.shape)
# lnI is vanishing for b=0, so don't compute it for the 0-th column.
lnI_array[:, 1:] = lnI(x0_grid[:, 1:], b_grid[:, 1:])

# save the list
np.savez(
    opts.out,
    phase_marginalization=opts.phase_marginalization,
    bmax=bmax,
    bref=bref,
    x0min=x0min,
    x0max=x0max,
    dmin=dmin,
    dmax=dmax,
    s_array=s_array,
    t_array=t_array,
    lnI_array=lnI_array,
)
