#Filename: lsst_func.py
#Description: This module calculates useful cosmology functions

import sys,math,numpy as np,scipy
from scipy import interpolate
from scipy import integrate
from scipy.integrate import odeint
from functools import partial
import cosmo_params,glob

####################################################
#		COSMOLOGY
####################################################

#---------------------------------
#FUNCTION rhocrit
#Critical density as a function
#of redshift
#---------------------------------
def rhocrit(z):
    """Return the critical density at redshift ``z``.

    The present-day value ``cosmo_params.rhocrit0`` is rescaled by the squared
    expansion rate: rho_c(z) = rho_c0 * [H(z) / H0]^2. Units follow
    ``cosmo_params.rhocrit0``.
    """
    scale = cosmo_params.rhocrit0 / cosmo_params.H0**2
    return scale * Hubble(z)**2

#----------------------------------
#FUNCTION omegadez
#Dark energy density parameter Omega_DE(z)
#as a function of redshift z
#----------------------------------
def omegadez(z):
    """Return the dark-energy density parameter Omega_DE(z).

    For the CPL equation of state the dark-energy density evolves as
        rho_DE(z) / rho_DE(0) = exp{3 [(1 + w0 + wa) ln(1 + z) - wa z / (1 + z)]},
    which is constant for a cosmological constant (w0 = -1, wa = 0). The
    density parameter is then
        Omega_DE(z) = OmegaL * rho_DE(z) / rho_DE(0) * H0^2 / H(z)^2.

    Parameters
    ----------
    z : float or list of float
        Redshift. A list is evaluated element by element.

    Returns
    -------
    float or list of float
        Omega_DE at ``z``; a list of values when ``z`` is a list.
    """
    if isinstance(z, list):
        # Evaluate each redshift on its own so every entry uses its own z
        # (and its own H(z)), rather than operating on the list as a whole.
        return [omegadez(zval) for zval in z]

    if (cosmo_params.w0 != -1.) or (cosmo_params.wa != 0.):
        exponent = ((1 + cosmo_params.w0 + cosmo_params.wa) * math.log(1 + z)
                    - cosmo_params.wa * z / (1 + z))
        de_density = cosmo_params.OmegaL * np.exp(3 * exponent)
    else:
        de_density = cosmo_params.OmegaL
    return de_density * cosmo_params.H0**2 / Hubble(z)**2

#----------------------------------
#FUNCTION darkw
#w(z) in the CPL parametrization
#----------------------------------
def darkw(x2):
    """Return the CPL dark-energy equation of state w = w0 + wa * (1 - a).

    ``x2`` is the redshift; the scale factor is a = 1 / (1 + x2).
    """
    scale_factor = 1. / (1 + x2)
    return cosmo_params.w0 + cosmo_params.wa * (1 - scale_factor)

#----------------------------------
#FUNCTION subintH
#This is an auxiliary function
#to be integrated to get H(z)
#----------------------------------
def subintH(x1):
    """Return (1 + w(z)) / (1 + z) at redshift ``x1``, with w from ``darkw``.

    For the CPL w(z), the integral of this quantity from 0 to z equals the
    closed-form exponent (1 + w0 + wa) ln(1 + z) - wa z / (1 + z) used by
    ``Hubble`` and ``omegadez``.
    """
    numerator = 1 + darkw(x1)
    return numerator / (1. + x1)

#----------------------------------
#FUNCTION intHubble
#H(z)/c at redshift z (the inverse of the comoving-distance integrand dchi/dz)
#----------------------------------
def intHubble(x):
    """Return H(z) / c at redshift ``x``.

    This is the inverse of the comoving-distance integrand dchi/dz. It
    delegates to ``Hubble`` so that the expansion history (matter plus CPL
    dark energy) is defined in exactly one place; previously this function
    carried a verbatim copy of the same formula.

    Units are those of ``cosmo_params.H0 / cosmo_params.c``.
    """
    return Hubble(x) / cosmo_params.c

#-----------------------------------
#FUNCTION Hubble(z)
#The Hubble parameter H(z) (function of redshift z)
#-----------------------------------
def Hubble(x):
    """Return the Hubble parameter H(z) at redshift ``x``.

    H(z) = H0 * sqrt(OmegaM (1+z)^3 + OmegaL * f_DE(z)), where for the CPL
    equation of state
        f_DE(z) = exp{3 [(1 + w0 + wa) ln(1+z) - wa z / (1+z)]},
    and f_DE = 1 for a cosmological constant (w0 = -1, wa = 0).

    No curvature or radiation term is included, so H(0) = H0 only when
    OmegaM + OmegaL = 1. The result carries the units of ``cosmo_params.H0``.
    """
    is_lambda = (cosmo_params.w0 == -1.) and (cosmo_params.wa == 0.)
    if is_lambda:
        de_exponent = 0.
    else:
        de_exponent = ((1 + cosmo_params.w0 + cosmo_params.wa) * math.log(1 + x)
                       - cosmo_params.wa * x / (1 + x))
    matter = cosmo_params.OmegaM * (1 + x)**3.
    dark_energy = cosmo_params.OmegaL * math.exp(3 * de_exponent)
    return cosmo_params.H0 * np.sqrt(matter + dark_energy)

#-----------------------------------
#FUNCTION: integrandGr
#Integrand of the Growth function
#-----------------------------------
def integrandGr(y):
    """Integrand 1 / [a E(a)]^3 of the LCDM linear growth integral.

    ``y`` is the scale factor a, and E(a) = sqrt(OmegaM / a^3 + 1 - OmegaM)
    (this form assumes a flat universe with a cosmological constant).
    """
    om = cosmo_params.OmegaM
    hubble_ratio = math.sqrt(om / y**3. + 1.0 - om)
    return 1.0 / (y * hubble_ratio)**3.

#-----------------------------------------------
#FUNCTION intgrowthapprox
#Integrand of the growth function 0.1% approx
#-----------------------------------------------
def intgrowthapprox(a1):
    """Integrand [Omega_m(a)^0.55 - 1] / a of the growth-index approximation.

    ``a1`` is the scale factor; z = 1/a - 1 and
    Omega_m(a) = OmegaM (1+z)^3 H0^2 / H(z)^2. The exponent 0.55 is the
    growth index gamma. Integrating over a from 0 gives ln g(a), with g = D/a.
    """
    redshift = 1. / a1 - 1
    omega_m = (cosmo_params.OmegaM * (1 + redshift)**3 * cosmo_params.H0**2
               / Hubble(redshift)**2)
    return 1. / a1 * (omega_m**0.55 - 1)

#-----------------------------------
#FUNCTION gode
#The ordinary diff eq. to be solved
#to obtain growth(z).
#This is the derivation:
#dg/dlna=adg/da
#d^2g/dlna^2=d/dlna(dg/dlna)=ad/da(adg/da) = adg/da+a^2d^2g/da^2
#y=dg/da
#2(ay+a^2dy/da)+[5-3w(a)Omega_DE(a)]*a*dg/da+3[1-w(a)]Omega_DE(a)g=0
#2(adg/da+a^2d^2g/da^2)+[5-3w(a)Omega_DE(a)]*a*dg/da+3[1-w(a)]Omega_DE(a)g=0
#=> d^2g/da^2 = -{2adg/da+[5-3w(a)Omega_DE(a)]*a*dg/da+3[1-w(a)]Omega_DE(a)g}/(2a^2)
#g = \{- 2[a^2 f'(a)+af(a)]-[5-3w(a)\Omega_{\rm DE}(z)]af(a)\}[3[1-w(a)]\Omega_{\rm DE}(a)]^{-1}
#-----------------------------------
def gode(y,x):
    """Right-hand side of the growth ODE, in the form ``odeint`` expects.

    State vector ``y = [g, dg/da]`` with g = D / a; independent variable
    ``x = a`` (scale factor). The second-order equation (Cooray et al. 2004,
    Eq. 1)
        2 d^2g/dlna^2 + [5 - 3 w Omega_DE] dg/dlna + 3 (1 - w) Omega_DE g = 0
    is rewritten with dg/dlna = a g' and d^2g/dlna^2 = a g' + a^2 g'' as
        g'' = -[2 a g' + (5 - 3 w Omega_DE) a g' + 3 (1 - w) Omega_DE g] / (2 a^2).

    Returns
    -------
    list
        ``[dg/da, d^2g/da^2]``.
    """
    g_val, dg_da = y[0], y[1]
    redshift = 1. / x - 1
    w_de = darkw(redshift)
    omega_de = omegadez(redshift)
    numerator = (-2 * x * dg_da
                 - (5. - 3. * w_de * omega_de) * x * dg_da
                 - 3 * (1 - w_de) * omega_de * g_val)
    return [dg_da, numerator / x**2 / 2.]

#-----------------------------------
#FUNCTION: growth(z)
#Growth function 
#-----------------------------------
def growth(z):
    """Return the linear growth factor D(z), not normalised to D(0) = 1.

    LCDM case (w0 = -1, wa = 0): the integral solution
        D(a) = 5/2 * OmegaM * E(a) * int_0^a da' / [a' E(a')]^3,
    with E(a) = sqrt(OmegaM / a^3 + 1 - OmegaM), i.e. assuming flatness.
    NOTE(review): this branch uses 1 - OmegaM, while ``Hubble`` uses OmegaL;
    the two agree only if OmegaM + OmegaL = 1.

    Otherwise (CPL dark energy): the ODE in ``gode`` for g = D/a
    (Cooray et al. 2004, Eq. 1) is integrated on a log grid from a = 1e-10,
    starting from g = 1 and dg/da = 0, and D = a * g is returned.
    """
    afact = 1. / (1. + z)
    lcdm = (cosmo_params.w0 == -1.) and (cosmo_params.wa == 0.)

    if lcdm:
        om = cosmo_params.OmegaM
        integral = integrate.quad(integrandGr, 0.0, afact, epsrel=1.e-6, epsabs=0)[0]
        hubble_ratio = math.sqrt(om * (1.0 + z)**3. + 1.0 - om)
        return 5.0 / 2.0 * om * integral * hubble_ratio

    # odeint imposes the initial condition at the first grid point
    # (a = 1e-10), deep in matter domination where D ~ a, i.e. g = 1, g' = 0.
    a_grid = np.logspace(-10, math.log10(afact), 10000, base=10)
    solution = odeint(gode, (1., 0.), a_grid)
    # Alternative (~0.1% accurate) growth-index approximation. Since
    # dln g/dln a = f - 1 ~ Omega_m(a)^0.55 - 1, the exponent has a PLUS sign:
    #   g(a) = exp[+\int_0^a da'/a' (Omega_m(a')^0.55 - 1)]
    #        = np.exp(integrate.quad(intgrowthapprox, 0, afact)[0])
    return solution[-1, 0] * afact

#-----------------------------------
#FUNCTION: integrandDistance
#Integrand of the comoving radial distance
#-----------------------------------
def integrandDistance(z):
    """Return dchi/dz = 1 / intHubble(z) = c / H(z) at redshift ``z``."""
    hubble_over_c = intHubble(z)
    return 1. / hubble_over_c

#-----------------------------------
#FUNCTION: Distance(z)
#Comoving distance in Mpc
#-----------------------------------
def Distance(z):
    """Return the comoving radial distance chi(z) = int_0^z c / H(z') dz'.

    Integrated with ``quad`` to a relative tolerance of 1e-6. Units follow
    c / H0 (Mpc if c is in km/s and H0 in km/s/Mpc -- confirm against
    cosmo_params).
    """
    value, _abserr = integrate.quad(integrandDistance, 0.0, z,
                                    epsrel=1.e-6, epsabs=0)
    return value

#--------------------------------
#FUNCTION: Tbbks(k)
#The BBKS transfer function
#--------------------------------
def Tbbks(k):
    """Return the BBKS matter transfer function T(k).

    Fitting formula (Bardeen, Bond, Kaiser & Szalay 1986):
        T(q) = ln(1 + 2.34 q) / (2.34 q)
               * [1 + 3.89 q + (16.1 q)^2 + (5.46 q)^3 + (6.71 q)^4]^(-1/4),
    with q = k / (h^2 * OmegaM), both read from ``cosmo_params``.

    Parameters
    ----------
    k : float or numpy.ndarray
        Wavenumber(s). Must be strictly positive: k = 0 gives 0/0 (nan).
        NOTE(review): the q definition suggests k in Mpc^-1 with h the
        dimensionless Hubble parameter -- confirm against cosmo_params.

    Returns
    -------
    float or numpy.ndarray
        T(k), element-wise for array input.
    """
    q = k / (cosmo_params.h**2 * cosmo_params.OmegaM)
    x = 2.34 * q
    # np.log1p rather than math.log(1 + x): it works element-wise on arrays
    # (math.log rejects them) and stays accurate at small q.
    mfact = np.log1p(x) / x
    sq14 = 1 + 3.89 * q + (16.1 * q)**2 + (5.46 * q)**3 + (6.71 * q)**4
    return mfact * sq14**(-0.25)

#----------------------------------------------
# FUNCTION PSbbks(k)
# The BBKS power spectrum at z=0
# *up to some arbitrary normalization
#----------------------------------------------
def PSbbks(k):
    """Return the unnormalised linear BBKS power spectrum at z = 0.

    P(k) = Delta(k) / k^3, with Delta(k) = T_BBKS(k)^2 * k^(n + 3) and
    n = ``cosmo_params.npower``, i.e. P(k) = T(k)^2 k^n. The amplitude is
    arbitrary: no normalisation factor (e.g. to sigma_8) is applied here,
    although Peacock (1991) was cited for the normalisation convention.
    """
    spectral_index = cosmo_params.npower
    transfer_sq = Tbbks(k)**2
    dimensionless = transfer_sq * k**(spectral_index + 3)
    return dimensionless / k**3


#----------------------------------------------
#FUNCTION: intSIGbbks(lnk,R)
#The integrand needed to get sigma(R)
#----------------------------------------------
def intSIGbbks(lnk,R):
    """Integrand, in ln k, of sigma^2(R) for a spherical top-hat window.

    Returns k^3 P(k) W^2(kR) with k = exp(lnk), P = ``PSbbks`` and
    W(x) = 3 (sin x - x cos x) / x^3, so that W^2 = 9 (sin x - x cos x)^2 / x^6.
    """
    k = np.exp(lnk)
    kr = k * R
    window_sq = 9. / kr**6 * (np.sin(kr) - kr * np.cos(kr))**2
    return PSbbks(k) * k * k * k * window_sq
    
#----------------------------------------------
#FUNCTION: sigmaMbbks(M,z)
#sigma(M,z) in BBKS
#----------------------------------------------

def sigmaMbbks(M,z):
    """Return the rms linear fluctuation sigma for a top-hat enclosing mass ``M``.

    The radius follows from M = 4/3 pi rho R^3, with
    rho = cosmo_params.rhocrit0 * OmegaM * (1 + z)^3, and then
        sigma^2(R) = 1 / (2 pi^2) * int k^3 P(k) W^2(kR) dln k
    over [ln glob.kminint, ln glob.kmaxint], using the unnormalised z = 0
    BBKS spectrum.

    NOTE(review): ``z`` enters only through the (1 + z)^3 density, which makes
    R a physical radius while P(k) uses comoving k, and no growth factor D(z)
    is applied. The usual convention is sigma(M, z) = sigma(M, 0) D(z)/D(0),
    with a comoving R from the z = 0 mean density -- confirm the intent with
    callers.
    """
    mean_density = cosmo_params.rhocrit0 * cosmo_params.OmegaM * (1 + z)**3
    radius = (3. / 4. / math.pi * M / mean_density)**(1. / 3.)
    integrand = partial(intSIGbbks, R=radius)
    lnk_min = np.log(glob.kminint)
    lnk_max = np.log(glob.kmaxint)
    integral = integrate.quad(integrand, lnk_min, lnk_max, epsabs=1.e-8)[0]
    return np.sqrt(0.5 / math.pi / math.pi * integral)

#----------------------------------------------
#FUNCTION: sigmaRbbks(R,z)
#sigma(R,z) in BBKS
#----------------------------------------------

def sigmaRbbks(R,z):
    """Return the rms linear fluctuation sigma in a top-hat sphere of radius ``R``.

    sigma^2(R) = 1 / (2 pi^2) * int k^3 P(k) W^2(kR) dln k over
    [ln glob.kminint, ln glob.kmaxint], using the unnormalised z = 0 BBKS
    spectrum.

    Parameters
    ----------
    R : float
        Top-hat radius, in inverse units of k.
    z : float
        Accepted but not used: no growth factor is applied, so the result is
        the z = 0 value. NOTE(review): scale by growth(z)/growth(0) if a
        redshift-dependent sigma(R, z) is intended.
    """
    # The mean-density locals that used to be computed here were never used
    # (R is supplied directly), so they are not evaluated.
    integral = integrate.quad(partial(intSIGbbks, R=R), np.log(glob.kminint),
                              np.log(glob.kmaxint), epsabs=1.e-8)[0]
    return np.sqrt(0.5 / math.pi / math.pi * integral)

#-----------------------------------------------
#FUNCTION: zofchi(x)
#The redshift corresponding to some comoving
#radial distance.
#-----------------------------------------------
def zofchi(x):
    """Return the redshift corresponding to comoving radial distance ``x``.

    Distances at or below the smallest tabulated distance (min of glob.chivec)
    are clamped to the smallest tabulated redshift (min of glob.zvec). Larger
    distances are passed to the callable ``glob.zofchi`` (presumably an
    interpolation table built elsewhere -- verify).
    """
    if x <= min(glob.chivec):
        return min(glob.zvec)
    return glob.zofchi(x)

#----------------------------------------------
#FUNCTION: dndztrunc(z,fdndz,zmin,zmax)
#A truncated dndz outside of [zmin,zmax]
#----------------------------------------------
def dndztrunc(z,fdndz,zmin,zmax):
    """Evaluate ``fdndz`` with zero outside the closed interval [zmin, zmax].

    Parameters
    ----------
    z : float or list of float
        Redshift(s). A list returns a list of the same length.
    fdndz : callable
        Untruncated dN/dz, called with a single scalar redshift.
    zmin, zmax : float
        Inclusive bounds of the support.

    Returns
    -------
    float or list
        ``fdndz(z)`` inside [zmin, zmax], ``0.`` outside. Element-wise for
        lists; previously the list branch skipped the truncation entirely.
    """
    if isinstance(z, list):
        return [dndztrunc(zval, fdndz, zmin, zmax) for zval in z]
    if z < zmin or z > zmax:
        return 0.
    return fdndz(z)

#-----------------------------------------------
#FUNCTION: decomp_subW_1(zmin)
#Lensing kernel
#Decomposition of Wbar: first term
#-----------------------------------------------
def decomp_subW_1(zmin):
    """First term of the lensing-efficiency (Wbar) decomposition.

    Returns int_{zmin}^{glob.zmax} glob.dndzfunc(z) dz.

    ``scipy.integrate.romberg`` (used previously) was deprecated in SciPy 1.12
    and removed in 1.15. Adaptive ``quad`` is used instead, with the same
    relative tolerance (1e-6) and romberg's default absolute tolerance
    (1.48e-8). Extra subintervals help if dN/dz is truncated (discontinuous).
    """
    retdc1, _abserr = integrate.quad(glob.dndzfunc, zmin, glob.zmax,
                                     epsrel=1.e-6, epsabs=1.48e-8, limit=200)
    return retdc1

#-----------------------------------------------
#FUNCTION: integrand_subW2(x)
#Lensing kernel
#Integrand of the second Wbar term
#-----------------------------------------------
def integrand_subW2(x):
    """Integrand n(chi) / chi of the second Wbar term, at comoving distance ``x``.

    n(chi) = (dN/dz) / (dchi/dz), with glob.dndzfunc and integrandDistance
    both evaluated at z = zofchi(x).
    """
    redshift = zofchi(x)
    dn_dchi = glob.dndzfunc(redshift) / integrandDistance(redshift)
    return dn_dchi / x

#-----------------------------------------------
#FUNCTION: decomp_subW_2(xmin)
#Lensing kernel
#Decomposition of Wbar: second term
#-----------------------------------------------
def decomp_subW_2(xmin):
    """Second term of the lensing-efficiency (Wbar) decomposition.

    Returns xmin * int_{xmin}^{glob.chimax} n(chi) / chi dchi, using
    ``integrand_subW2``. NOTE(review): presumably combined by the caller with
    ``decomp_subW_1`` as Wbar = term1 - term2 -- verify.

    ``scipy.integrate.romberg`` (used previously) was removed in SciPy 1.15.
    Adaptive ``quad`` is used instead, with the same relative tolerance (1e-6)
    and romberg's default absolute tolerance (1.48e-8).
    """
    integral, _abserr = integrate.quad(integrand_subW2, xmin, glob.chimax,
                                       epsrel=1.e-6, epsabs=1.48e-8, limit=200)
    return xmin * integral

#-----------------------------------------------
#FUNCTION: lenskernq(chi,fwbar)
#Lensing kernel q,
#uses interpolated Wbar(chi)
#-----------------------------------------------
def lenskernq(chi,fwbar):
    """Return the lensing kernel q(chi) = chi / a(chi) * Wbar(chi).

    a = 1 / (1 + z) with z = zofchi(chi); ``fwbar`` is a callable that returns
    the (interpolated) Wbar at ``chi``.
    NOTE(review): no 3/2 OmegaM (H0/c)^2 prefactor is applied here --
    presumably handled by the caller; confirm.
    """
    scale_factor = 1. / (1. + zofchi(chi))
    return chi / scale_factor * fwbar(chi)

#------------------------------------------------------
#FUNCTION: dndz_task2(z,zmean)
#analytic dndz for the TJP code comparison
#One normalised Gaussian of width sigma = 0.15 centred at zmean per call; the task uses two, centred at z_1 = 1.0 and z_2 = 1.5.
#------------------------------------------------------
def dndz_task2(z,zmean,sigma=0.15):
    """Normalised Gaussian redshift distribution for the TJP code comparison.

    n(z) = exp[-(z - zmean)^2 / (2 sigma^2)] / (sqrt(2 pi) sigma)

    Parameters
    ----------
    z : float, list, tuple or numpy.ndarray
        Redshift(s). A list returns a list. Any other input is evaluated with
        NumPy and returns a scalar, or an array of the same shape (the
        previous ``math.exp`` path rejected arrays and tuples).
    zmean : float
        Centre of the Gaussian.
    sigma : float, optional
        Width of the Gaussian; defaults to the task's value of 0.15.
    """
    if isinstance(z, list):
        return [dndz_task2(zval, zmean, sigma) for zval in z]
    dz = np.asarray(z, dtype=float) - zmean
    return np.exp(-0.5 * dz * dz / sigma / sigma) / np.sqrt(2. * math.pi) / sigma

#------------------------------------------------------
#FUNCTION: nobias(z)
#Just a convenience function for non-evolving bias
#Currently returns a constant bias = 1
#------------------------------------------------------ 
def nobias(z):
    """Return a redshift-independent linear bias of unity.

    ``z`` is accepted for interface compatibility with other bias callables
    and is ignored.
    """
    bias = 1.
    return bias

#------------------------------------------------------
#FUNCTION: dndchi(dndz,chi)
#dN/dchi obtained from dN/dz and the jacobian of the
#chi->z transformation.
#------------------------------------------------------
def dndchi(dndz,chi):
    """Return dN/dchi at comoving distance ``chi`` from a dN/dz callable.

    dN/dchi = (dN/dz) / (dchi/dz), with both factors evaluated at the redshift
    z = zofchi(chi).
    """
    z = zofchi(chi)
    # integrandDistance expects a redshift; passing chi (a distance) here
    # evaluated the Jacobian at the wrong epoch.
    return dndz(z) / integrandDistance(z)

#------------------------------------------------------
#FUNCTION: integrand_Pclu(z,l,b1func,b2func,dndz1,dndz2)
#Redshift integrand (Limber approximation) of the tomographic clustering angular PS
#taking as input properties of the clustering tracers
#-----------------------------------------------------
def integrand_Pclu(z,l,b1func,b2func,dndz1,dndz2):
    """Limber redshift integrand of the tomographic clustering spectrum C_l.

    Returns n1(z) n2(z) b1(z) b2(z) P(k, z) / [chi^2 dchi/dz], where
    chi = Distance(z), k = (l + 1/2) / chi and
    P(k, z) = PSbbks(k) * [D(z) / D(0)]^2.
    """
    chi = Distance(z)
    wavenumber = (l + 0.5) / chi
    d_ratio = growth(z) / growth(0)
    power = PSbbks(wavenumber) * d_ratio * d_ratio
    bias_1 = b1func(z)
    bias_2 = b2func(z)
    dchidz = integrandDistance(z)
    weight = dndz1(z) * dndz2(z)
    return weight * power * bias_1 * bias_2 / chi / chi / dchidz

#------------------------------------------------------
#FUNCTION: integrand_Pggl(z,l,lk2,dndz1,b1func)
#Redshift integrand (Limber approximation) of the tomographic GGL angular PS
#taking as input properties of the clustering tracer
#and the lensing kernel
#-----------------------------------------------------
def integrand_Pggl(z,l,lk2,dndz1,b1func):
    """Limber redshift integrand of the galaxy-galaxy-lensing spectrum C_l.

    Returns b(z) q(chi) n(z) P(k, z) / chi^2, with chi = Distance(z),
    k = (l + 1/2) / chi and P(k, z) = PSbbks(k) * [D(z) / D(0)]^2.
    ``lk2`` is the Wbar callable passed to ``lenskernq``; ``dndz1`` and
    ``b1func`` describe the clustering tracer.
    """
    chi = Distance(z)
    wavenumber = (l + 0.5) / chi
    d_ratio = growth(z) / growth(0)
    power = PSbbks(wavenumber) * d_ratio * d_ratio
    bias = b1func(z)
    kernel = lenskernq(chi, lk2)
    density = dndz1(z)
    return bias * kernel * density * power / chi / chi

#------------------------------------------------------
#FUNCTION: integrand_Pcon (z,l,lk1,lk2)
#Redshift integrand (Limber approximation) of the tomographic convergence angular PS
#taking as input the lensing kernels
#-----------------------------------------------------
def integrand_Pcon(z,l,lk1,lk2):
    """Limber redshift integrand of the tomographic convergence spectrum C_l.

    Returns q1(chi) q2(chi) P(k, z) (dchi/dz) / chi^2, with chi = Distance(z),
    k = (l + 1/2) / chi and P(k, z) = PSbbks(k) * [D(z) / D(0)]^2. ``lk1`` and
    ``lk2`` are the Wbar callables passed to ``lenskernq``.
    """
    chi = Distance(z)
    wavenumber = (l + 0.5) / chi
    d_ratio = growth(z) / growth(0)
    power = PSbbks(wavenumber) * d_ratio * d_ratio
    dchidz = integrandDistance(z)
    q_first = lenskernq(chi, lk1)
    q_second = lenskernq(chi, lk2)
    return q_first * q_second * power * dchidz / chi / chi

#-----------------------------------------------
#FUNCTION: integrand_ACF_gg(l,theta,Pclu)
#Integrand to obtain the
#projected clustering angular correlation function
#Takes ell, theta and C(l) angular
#clustering power spectrum as input
#----------------------------------------------
def integrand_ACF_gg(l,theta,Pclu):
    """Integrand l / (2 pi) * J0(l theta) * C_l of the clustering w(theta).

    ``Pclu`` is a callable returning the angular clustering power C_l at
    multipole ``l``; ``theta`` is in radians (the argument of J0 is l * theta).
    """
    bessel = scipy.special.jn(0, l * theta)
    weight = l / 2. / math.pi
    return weight * bessel * Pclu(l)


#----------------------------------------------
#FUNCTION: integrand_ACF_ggl(l,theta,Pggl)
#Integrand to obtain the
#projected GGL angular correlation function
#Takes ell, theta and C(l) angular
#power spectrum as input
#----------------------------------------------
def integrand_ACF_ggl(l,theta,Pggl):
    """Integrand C_l * l / (2 pi) * J2(l theta) of the GGL correlation function.

    ``Pggl`` is a callable returning the GGL angular power C_l at multipole
    ``l``; ``theta`` is in radians.
    """
    spectrum_value = Pggl(l)
    scaled = spectrum_value * l / 2. / math.pi
    return scaled * scipy.special.jn(2, l * theta)

#----------------------------------------------
#FUNCTION: integrand_ACF_plus(l,theta,Pcon)
#Integrand to obtain the
#projected xi_+ angular correlation function
#Takes ell, theta and C(l) angular
#power spectrum as input
#----------------------------------------------
def integrand_ACF_plus(l,theta,Pcon):
    """Integrand C_l * l / (2 pi) * J0(l theta) of the shear xi_+(theta).

    ``Pcon`` is a callable returning the convergence power C_l at multipole
    ``l``; ``theta`` is in radians.
    """
    spectrum_value = Pcon(l)
    scaled = spectrum_value * l / 2. / math.pi
    return scaled * scipy.special.jn(0, l * theta)

#----------------------------------------------
#FUNCTION: integrand_ACF_minus(l,theta,Pcon)
#Integrand to obtain the
#projected xi_- angular correlation function
#Takes ell, theta and C(l) angular
#power spectrum as input
#----------------------------------------------
def integrand_ACF_minus(l,theta,Pcon):
    """Integrand C_l * l / (2 pi) * J4(l theta) of the shear xi_-(theta).

    ``Pcon`` is a callable returning the convergence power C_l at multipole
    ``l``; ``theta`` is in radians.
    """
    spectrum_value = Pcon(l)
    scaled = spectrum_value * l / 2. / math.pi
    return scaled * scipy.special.jn(4, l * theta)
