#! /usr/bin/env python
# S. Morisaki, based on his
# https://dcc.ligo.org/LIGO-T2100485

import argparse
import numpy as np
from scipy import integrate
from scipy.special import erfcinv, erf, erfcx

import RIFT.likelihood.factored_likelihood as factored_likelihood


# ---- Command-line interface ------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--d-min", default=1, type=float, help="Minimum distance in volume integral. Used to SET THE PRIOR; changing this value changes the numerical answer.")
parser.add_argument("--d-max", default=10000, type=float, help="Maximum distance in volume integral. Used to SET THE PRIOR; changing this value changes the numerical answer.")
parser.add_argument("--max-snr", default=1e4, type=float, help="Maximum SNR at a reference distance.")
parser.add_argument("--relative-error", default=1e-3, type=float, help="Relative error of numerical integration.")
parser.add_argument("--out", default="distance_marginalization_lookup.npz", help="Output file (format should be .npz)")
opts = parser.parse_args()

# Reject values that would otherwise surface later as a bare ZeroDivisionError
# (d-min == 0, d-min == d-max, infinite d-max) or as a silently malformed grid
# (d-max < d-min, non-positive max-snr / relative-error). The `not a < b`
# form also rejects NaN.
if not 0. < opts.d_min < opts.d_max < np.inf:
    parser.error("--d-min and --d-max must be finite and satisfy 0 < d-min < d-max")
if not 0. < opts.max_snr < np.inf:
    parser.error("--max-snr must be a positive, finite number")
if not 0. < opts.relative_error < np.inf:
    parser.error("--relative-error must be a positive, finite number")

# ---- Derived constants -----------------------------------------------------
dmin = opts.d_min
dmax = opts.d_max
dref = factored_likelihood.distMpcRef
# x = Dref / D, so the largest distance gives the smallest x and vice versa.
xmin = dref / dmax
xmax = dref / dmin
# Range of x0 covered by the lookup table: ten times xmax on either side of 0.
x0min = -10. * xmax
x0max = 10. * xmax
bmax = opts.max_snr
# Scale used by b_to_t / t_to_b to map b onto the t grid.
bref = 1. / ((xmax - xmin) * max(xmin + xmax - 2. * x0min, 2. * x0max - xmin - xmax))
# Target spacings of the s and t grids.
delta_s = 0.1
delta_t = 0.1
# Relative tolerance used both for truncating and for evaluating the integral.
eps = opts.relative_error


def exponent(x, x0, b):
    """Return ``b / 2 * (x0**2 - (x - x0)**2)``.

    For ``b > 0`` this is a downward parabola in ``x`` whose maximum,
    ``b / 2 * x0**2``, is reached at ``x == x0``.
    """
    offset = x - x0
    return b / 2. * (x0**2. - offset**2.)


@np.vectorize
def lnI(x0, b):
    """Return the log of a rescaled integral of ``x**(-4) * exp(exponent(x, x0, b))``.

    The integrand is multiplied by ``exp(-exponent_max)``, where
    ``exponent_max`` is ``exponent`` evaluated at the point of
    ``[xmin, xmax]`` closest to ``x0``, so the value returned is the log of
    the integral of ``x**(-4) * exp(exponent(x, x0, b) - exponent_max)`` over
    ``[xmin_integral, xmax_integral]``.  Those limits are a sub-interval of
    ``[xmin, xmax]`` chosen from the module-level tolerance ``eps``.

    Parameters
    ----------
    x0 : float
        Centre of the quadratic exponent (applied elementwise to arrays via
        ``np.vectorize``).
    b : float
        Coefficient of the quadratic exponent.  The code divides by ``b`` and
        takes ``sqrt(b / 2)``, so it must be strictly positive.

    Returns
    -------
    float
        ``np.log`` of the value returned by ``scipy.integrate.quad``.
    """
    # NOTE(review): the alpha/beta/gamma expressions and the resulting `tmp`
    # thresholds presumably follow the truncation-error bounds derived in
    # LIGO-T2100485 -- confirm against that document.
    if x0 < xmin:
        # Peak of the exponent lies below the range: the integrand's
        # exponential factor is largest at the lower edge xmin.
        exponent_max = exponent(xmin, x0, b)
        xmin_integral = xmin
        # determine the upper limit of the integral
        alpha1 = np.sqrt(b / 2.) * (xmin - x0 + 4. / (b * xmin))
        alpha2 = np.sqrt(b / 2.) * (xmax - x0 + 4. / (b * xmin))
        tmp = eps * (erfcx(alpha1) - np.exp(-(alpha2**2. - alpha1**2.)) * erfcx(alpha2))
        if 0. < tmp < 1.:
            A = -np.log(tmp)
            # dx is the offset from xmin at which the exponent has dropped by
            # A below exponent_max (positive root of the quadratic in dx).
            dx = 2. * A / (np.sqrt(b**2. * (xmin - x0)**2. + 2. * b * A) + b * (xmin - x0))
            xmax_integral = min(xmin + dx, xmax)
        else:
            # Threshold unusable: integrate over the full range.
            xmax_integral = xmax
    elif x0 > xmax:
        # Peak of the exponent lies above the range: the integrand's
        # exponential factor is largest at the upper edge xmax.
        exponent_max = exponent(xmax, x0, b)
        xmax_integral = xmax
        # determine the lower limit of the integral
        gamma1 = np.sqrt(b / 2.) * (x0 - xmax - 4. / (b * xmax))
        gamma2 = np.sqrt(b / 2.) * (x0 - xmin - 4. / (b * xmax))
        tmp = eps * xmin**4. / xmax**4. * (erfcx(gamma1) - np.exp(-(gamma2**2. - gamma1**2.)) * erfcx(gamma2))
        if 0. < tmp < 1.:
            C = -np.log(tmp)
            # dx is the offset below xmax at which the exponent has dropped by
            # C below exponent_max (positive root of the quadratic in dx).
            dx = 2. * C / (np.sqrt(b**2. * (x0 - xmax)**2. + 2. * b * C) + b * (x0 - xmax))
            xmin_integral = max(xmax - dx, xmin)
        else:
            # Threshold unusable: integrate over the full range.
            xmin_integral = xmin
    else:
        # x0 lies inside [xmin, xmax]: the exponent peaks at x0 itself.
        exponent_max = exponent(x0, x0, b)
        # determine the lower and upper limits of the integral
        beta1 = np.sqrt(b / 2.) * (xmin - x0 + 4. / (b * x0))
        beta2 = np.sqrt(b / 2.) * (xmax - x0 + 4. / (b * x0))
        tmp = eps / 2. * (xmin / x0)**4. * (erf(beta2) - erf(beta1))
        if 0. < tmp < 1.:
            # Symmetric window of half-width dx around x0, clipped to the range.
            dx = np.sqrt(2. / b) * erfcinv(tmp)
            xmin_integral, xmax_integral = max(x0 - dx, xmin), min(x0 + dx, xmax)
        else:
            # Threshold unusable: integrate over the full range.
            xmin_integral, xmax_integral = xmin, xmax
    # Subtracting exponent_max inside exp() keeps the exponential factor <= 1.
    # NOTE(review): epsabs=-1 presumably leaves epsrel as the only active
    # tolerance in quad -- confirm against the scipy.integrate.quad docs.
    result, _ = integrate.quad(lambda x, x0, b: x**(-4.) * np.exp(exponent(x, x0, b) - exponent_max),
                               xmin_integral, xmax_integral, args=(x0, b), epsabs=-1, epsrel=eps)
    return np.log(result)


def x0_to_s(x0):
    """Map ``x0`` to the grid coordinate ``s``.

    ``s`` is the difference of two ``arcsinh`` terms: one of the offset of
    ``x0`` above ``xmin`` and one of its offset below ``xmax``, both scaled by
    ``sqrt(bmax)``.
    """
    scale = np.sqrt(bmax)
    from_lower_edge = np.arcsinh(scale * (x0 - xmin))
    to_upper_edge = np.arcsinh(scale * (xmax - x0))
    return from_lower_edge - to_upper_edge


@np.vectorize
def s_to_x0(s):
    """Invert ``x0_to_s`` numerically by bisection on ``[x0min, x0max]``.

    The bracket is halved until the images of its end points under
    ``x0_to_s`` differ by at most ``1e-2 * delta_s``; the midpoint of the
    final bracket is returned.  ``s`` must lie in ``[smin, smax]``.
    """
    assert smin <= s <= smax
    lower, upper = x0min, x0max
    s_lower, s_upper = x0_to_s(lower), x0_to_s(upper)
    tolerance = 1e-2 * delta_s
    midpoint = (lower + upper) / 2.
    # bisection search
    while s_upper - s_lower > tolerance:
        s_mid = x0_to_s(midpoint)
        if s_mid > s:
            upper, s_upper = midpoint, s_mid
        else:
            lower, s_lower = midpoint, s_mid
        midpoint = (lower + upper) / 2.
    return midpoint


def b_to_t(b):
    """Map ``b`` to the grid coordinate ``t = arcsinh(b / bref)``."""
    scaled_b = b / bref
    return np.arcsinh(scaled_b)


def t_to_b(t):
    """Map the grid coordinate ``t`` back to ``b = bref * sinh(t)``."""
    sinh_t = np.sinh(t)
    return bref * sinh_t


# Limits of the table in the transformed coordinates s (for x0) and t (for b).
smin = x0_to_s(x0min)
smax = x0_to_s(x0max)
tmax = b_to_t(bmax)

# Uniform grids in s and t, with spacings of roughly delta_s and delta_t.
s_array = np.linspace(smin, smax, num=int((smax - smin) / delta_s))
t_array = np.linspace(0., tmax, num=int(tmax / delta_t))

# Map the grids back to x0 and b.  The end points of the x0 grid are set
# exactly to x0min / x0max; only the interior points go through bisection.
x0_array = np.concatenate(([x0min], s_to_x0(s_array[1:-1]), [x0max]))
b_array = t_to_b(t_array)
x0_grid, b_grid = np.meshgrid(x0_array, b_array, indexing="ij")

# Fill the table: the first column (t = 0, i.e. b = 0) has a closed form,
# every other column is integrated numerically.
lnI_array = np.empty(x0_grid.shape)
lnI_array[:, 0] = np.log((xmin**(-3.) - xmax**(-3.)) / 3.)  # analytical solution for b=0
lnI_array[:, 1:] = lnI(x0_grid[:, 1:], b_grid[:, 1:])

# save the list
np.savez(
    opts.out,
    phase_marginalization=False,
    bmax=bmax,
    bref=bref,
    x0min=x0min,
    x0max=x0max,
    s_array=s_array,
    t_array=t_array,
    lnI_array=lnI_array,
)
