
#-----------------------------------------------
#
#Obtain angular power spectra and correlation
#functions for galaxy-galaxy lensing.
#
#-----------------------------------------------

import sys,math,numpy as np,scipy
from scipy import interpolate
from scipy import integrate
from scipy.integrate import odeint
from functools import partial
import cosmo_params,glob
import lsst_func
import matplotlib.pyplot as plt
import datetime

# Wavenumber limits stored on the shared `glob` module (presumably read by
# lsst_func -- confirm). kmin/kmax scale with cosmo_params.h; the integration
# range starts at kmin and extends a factor 1e10 beyond kmax.
glob.kmin = 1.e-3 * cosmo_params.h
glob.kmax = 10. * cosmo_params.h
glob.kminint = glob.kmin
glob.kmaxint = glob.kmax * 1e10

# Location of the CCL checkout holding the benchmark files; change this to yours.
ccldir = '../lsst_ccl/CCL/'

# Redshift-distribution mode: 'analytic' (dndz_task2) or 'histo' (tabulated file).
strmode = 'analytic'
#strmode = 'histo'

# Grid sizes: nl points for lvec, nw points for the lensing-kernel z/chi grid.
nl = 501
nw = 1501

# Angular bins taken from column 0 of the CCL benchmark file (degrees),
# converted to radians.
thetameans = (
    np.loadtxt(
        ccldir + 'tests/benchmark/codecomp_step2_outputs/run_b2b2analytic_log_wt_ll_pp.txt',
        usecols=0,
    )
    / 180. * math.pi
)

print("Mode: " + strmode)

# BBKS power-spectrum normalisation for this cosmology.
# sigR8 is lsst_func.sigmaRbbks evaluated at R = 8/h (second argument 0.,
# presumably z = 0 -- confirm). The spectra computed below are divided by
# glob.psnorm**2, i.e. rescaled by (sigma8 / sigR8)**2.
rsig8 = 8.0 / cosmo_params.h
sigR8 = lsst_func.sigmaRbbks(rsig8, 0.)
glob.psnorm = bbksnorm = 1. / cosmo_params.sigma8 * sigR8

# Construct auxiliary quantities and populate the shared global variables.

# Initial redshift -> distance table on z in [0, 30] (1001 points),
# filled in place via lsst_func.Distance.
glob.zvec = np.linspace(0., 30, 1001)
glob.chivec = [0.] * len(glob.zvec)
for i, z in enumerate(glob.zvec):
    glob.chivec[i] = lsst_func.Distance(z)

# nl log-spaced values from 1 to 1e5, and sigg = 0.15.
# NOTE(review): neither lvec nor sigg is read elsewhere in the visible script.
lvec = np.logspace(0, 5, nl)
sigg = 0.15
    
# Redshift limits shared by both modes; zmax depends on the mode.
zmin = 0.
zminlens = 1.e-4

if strmode == 'analytic':
    # Analytic distributions centred (via zmean) at z = 1 and z = 1.5.
    dndz_bin1 = partial(lsst_func.dndz_task2, zmean=1.)
    dndz_bin2 = partial(lsst_func.dndz_task2, zmean=1.5)
    zmax = 3.
else:
    # Tabulated distributions: column 0 is z, columns 1 and 2 are the two bins.
    data = np.loadtxt('../lsst_forecast/elisa/z_DESC-CC')
    zbin = data[:, 0]
    dndz1 = data[:, 1]
    dndz2 = data[:, 2]

    # Pad with zero-valued points from z = 0 up to one step below min(zbin),
    # and the same number of points above max(zbin), so the cubic
    # interpolants are anchored at zero outside the tabulated range.
    dz = zbin[1] - zbin[0]
    nextend = int(round((min(zbin) - dz) / dz))
    zextend = np.linspace(0., min(zbin) - dz, nextend + 1)
    zextend2 = np.linspace(max(zbin) + dz, max(zbin) + dz + nextend * dz, nextend + 1)
    dndzextend = [0.] * (nextend + 1)
    zpadded = np.concatenate((zextend, zbin, zextend2))

    dndz_tmp1 = scipy.interpolate.interp1d(
        zpadded, np.concatenate((dndzextend, dndz1, dndzextend)), kind='cubic')
    dndz_tmp2 = scipy.interpolate.interp1d(
        zpadded, np.concatenate((dndzextend, dndz2, dndzextend)), kind='cubic')

    zmax = max(zbin)
    dndz_bin1 = partial(lsst_func.dndztrunc, fdndz=dndz_tmp1, zmin=zmin, zmax=zmax)
    dndz_bin2 = partial(lsst_func.dndztrunc, fdndz=dndz_tmp2, zmin=zmin, zmax=zmax)

glob.zmax = zmax


#********************************************
#Lensing kernels for power spectra
#********************************************
print("Starting lensing power spectrum")

# Redshift-bin combinations: (1,1), (1,2), (2,2).
# Work arrays for the lensing kernel, one row per redshift bin.
wbar = np.zeros((2, nw))
lenskern12 = np.zeros((2, nw))

# Resample redshift -> distance (lsst_func.Distance) on nw points over
# [zminlens, zmax] and install it as the shared global table.
newzvec = np.linspace(zminlens, zmax, nw)
newchivec = [0.] * nw
for i, z in enumerate(newzvec):
    newchivec[i] = lsst_func.Distance(z)
i = 0
glob.zvec = newzvec
glob.chivec = newchivec

# Cubic interpolants in both directions (z -> chi and chi -> z), plus the
# largest tabulated distance.
glob.dinterp = scipy.interpolate.interp1d(glob.zvec, glob.chivec, kind='cubic')
glob.zofchi = scipy.interpolate.interp1d(glob.chivec, glob.zvec, kind='cubic')
glob.chimax = max(newchivec)

# bar(W) on the grid for each bin: decomp_subW_1(z) - decomp_subW_2(chi),
# evaluated while glob.dndzfunc points at that bin's redshift distribution.
# Row 0 of wbar holds bin 1, row 1 holds bin 2.
wbar1 = [0.] * nw
wbar2 = [0.] * nw
for ibin, dndz in enumerate((dndz_bin1, dndz_bin2)):
    glob.dndzfunc = dndz
    for j in range(nw):
        if j % 100 == 0:
            print(j, "of", nw)
        wbar1[j] = lsst_func.decomp_subW_1(newzvec[j])
        wbar2[j] = lsst_func.decomp_subW_2(newchivec[j])
        wbar[ibin, j] = wbar1[j] - wbar2[j]
print("bar(W) function needed for lensing kernel done.")
    
# Cubic interpolants of bar(W) for bins 1 and 2 as functions of distance.
fw1 = scipy.interpolate.interp1d(glob.chivec, wbar[0, :], kind='cubic')
fw2 = scipy.interpolate.interp1d(glob.chivec, wbar[1, :], kind='cubic')

# Kernel values (lsst_func.lenskernq) at every grid distance for both bins.
# NOTE(review): lenskern12 is not read elsewhere in the visible script.
for j, chi in enumerate(glob.chivec):
    lenskern12[0, j] = lsst_func.lenskernq(chi, fw1)
    lenskern12[1, j] = lsst_func.lenskernq(chi, fw2)
print("Lensing kernels done")

j = 0
# Lensing kernel normalization: 3/2 * H0^2 / c^2 * Omega_M.
knorm = 3. / 2. * cosmo_params.H0**2 / cosmo_params.c**2 * cosmo_params.OmegaM
    
#********************************************
#Tomographic GGL power spectrum
#********************************************
print("Starting GGL power spectrum")
# Only the (1,2) combination: dndz_bin1 is passed as dndz1 and the bin-2
# bar(W) interpolant fw2 as lk2. Multipoles 2..3000 are integrated here and
# written to file; powerggl keeps them for the interpolant built later.
nmaxl = 60001
powerggl = np.zeros(nmaxl)
i12 = 0.0
# `with` guarantees the file is closed even if an integration raises.
with open('run_b1b2' + strmode + '_log_cl_dl.txt', 'w') as file:
    # ell = 0 and ell = 1 are not integrated: written as zero (powerggl is
    # already zero there). Each row carries its own multipole label -- the
    # original wrote "0" on both rows.
    for i in (0, 1):
        file.write(str(i) + ' ' + str(i12) + '\n')
    for i in range(2, 3001):
        if i % 1000 == 0:
            print(i, "of", nmaxl)
        partialPggl2 = partial(lsst_func.integrand_Pggl, l=i, lk2=fw2,
                               dndz1=dndz_bin1, b1func=lsst_func.nobias)
        i12 = integrate.quad(partialPggl2, zmin, zmax, epsrel=1.e-7, epsabs=0,
                             limit=1000)[0] * knorm / glob.psnorm**2
        file.write(str(i) + ' ' + str(i12) + '\n')
        powerggl[i] = i12

# Histogram mode stops after writing the ell <= 3000 spectrum above.
# sys.exit is used rather than the site-module `exit` helper, which is not
# guaranteed to exist (e.g. under `python -S` or when embedded).
if strmode == 'histo':
    sys.exit()

# Analytic mode: extend the spectrum to ell = nmaxl-1 so the correlation
# function below can integrate over it. These values are kept in memory
# only (the output file was closed above).
for i in range(3001, nmaxl):
    if i % 1000 == 0:
        print(i, "of", nmaxl)
    partialPggl2 = partial(lsst_func.integrand_Pggl, l=i, lk2=fw2,
                           dndz1=dndz_bin1, b1func=lsst_func.nobias)
    i12 = integrate.quad(partialPggl2, zmin, zmax, epsrel=1.e-7, epsabs=0,
                         limit=1000)[0] * knorm / glob.psnorm**2
    powerggl[i] = i12

# Cubic interpolant of the GGL spectrum over the integer multipoles 0..nmaxl-1.
interpPggl12 = scipy.interpolate.interp1d(np.arange(0, nmaxl), powerggl, kind='cubic')
    
    
#********************************************************************
#GGL correlation function
#********************************************************************
print("Starting lensing wggl")
# Only the (1,2) combination. For each angular bin, integrate
# lsst_func.integrand_ACF_ggl over ell from 0 to nmaxl-1 (= 6e4). The upper
# limit is tied to nmaxl so it never exceeds the range of interpPggl12.
wggl = np.zeros(len(thetameans))
with open('run_b1b2' + strmode + '_log_wt_dl.txt', 'w') as file:
    for i, theta in enumerate(thetameans):
        partialACF = partial(lsst_func.integrand_ACF_ggl, theta=theta, Pggl=interpPggl12)
        wggl[i] = integrate.quad(partialACF, 0, nmaxl - 1, epsrel=1.e-4, epsabs=0,
                                 limit=2000)[0]
        # Angle written back in degrees.
        file.write(str(theta * 180. / math.pi) + ' ' + str(wggl[i]) + '\n')


sys.exit()
