#!/usr/bin/env python
import numpy as np
import pandas as pd
import reciprocalspaceship as rs
from reciprocalspaceship.algorithms.scale_merged_intensities import mean_intensity_by_resolution
from argparse import ArgumentParser
from os import environ


# Command-line interface: three positional MTZ inputs plus optional flags.
cli = ArgumentParser()
for positional in ("off_mtz", "on_mtz", "phase_mtz"):
    cli.add_argument(positional)
cli.add_argument("-o", default='difference.mtz', help='output mtz filename')
cli.add_argument("--embed", action='store_true')
cli.add_argument(
    "--positive",
    action='store_true',
    help='convert difference structurefactors to be strictly postive',
)

# The parsed namespace is bound to `parser` because the rest of the script
# reads its options through that name (parser.positive, parser.embed).
parser = cli.parse_args()

# Input / output file names taken from the command line.
off_mtz_filename = parser.off_mtz
on_mtz_filename = parser.on_mtz
phases_filename = parser.phase_mtz
outFN = parser.o

# Preferred phase column; replaced further down if the phase MTZ lacks it.
phase_key = 'PHIF-model'




def _stack_if_anomalous(dataset):
    """Return `dataset` with stack_anomalous() applied when it has any 'G' columns."""
    n_friedel_columns = (dataset.dtypes == 'G').sum()
    if n_friedel_columns:
        return dataset.stack_anomalous()
    return dataset


# Phase/model data (rows with missing values removed), then the two
# experimental data sets.
ds = rs.read_mtz(phases_filename).dropna()
off = rs.read_mtz(off_mtz_filename)
on = rs.read_mtz(on_mtz_filename)

off = _stack_if_anomalous(off)
on = _stack_if_anomalous(on)
ds = _stack_if_anomalous(ds)

# Fall back to the first phase-typed ('P') column when the default name is absent.
if phase_key not in ds:
    phase_columns = ds.dtypes[ds.dtypes == 'P'].keys()
    phase_key = phase_columns[0]

# Restrict all three data sets to the reflections they have in common.
shared = off.index.intersection(on.index)
idx = shared.intersection(ds.index)

off = off.loc[idx]
on = on.loc[idx]
ds = ds.loc[idx]



# Gather the columns of the combined data set: amplitudes and sigmas from the
# off and on data, plus the phase and model amplitude from the phase MTZ.
columns = {}
columns['Foff'] = off['F']
columns['SigFoff'] = off['SigF']
columns['Fon'] = on['F']
columns['SigFon'] = on['SigF']
columns['PhiFoff'] = ds[phase_key]
columns['F-model'] = ds['F-model']

output = rs.DataSet(columns)

# Carry over the crystal metadata from the phase data set.
output.spacegroup = ds.spacegroup
output.cell = ds.cell


# Difference amplitudes (on minus off) and their uncertainties, with the two
# sigmas added in quadrature.
output['DeltaF'] = (output['Fon'] - output['Foff']).astype('F')
output['SigDeltaF'] = np.sqrt(output['SigFon']**2. + output['SigFoff']**2.).astype('Q')


if parser.positive:
    # Combine the signed differences with the off-state phases into complex
    # structure factors, then split them back into a non-negative amplitude
    # (np.abs) and a phase in degrees (np.angle).
    cmplx = rs.utils.to_structurefactor(output.DeltaF, output.PhiFoff)
    output['DeltaF'] = np.abs(cmplx)
    output['PhiDeltaF'] = np.rad2deg(np.angle(cmplx))
    output['DeltaF'] = output.DeltaF.astype('F')
    # Cast the phase column itself: casting DeltaF here (as the code used to)
    # overwrote the freshly computed phases with the amplitudes.
    output['PhiDeltaF'] = output.PhiDeltaF.astype('P')


# Three alternative weight columns for the difference amplitudes.
# Relative variance: sigma^2 divided by its mean over all reflections.
relative_variance = output.SigDeltaF**2./np.mean(output.SigDeltaF**2.)
# Relative squared amplitude, scaled by the 0.05 factor used in the DH weight.
relative_amplitude = 0.05*output.DeltaF**2./np.mean(output.DeltaF**2.)

# W: 1 / (1 + relative variance)
output['W'] = ((1 + relative_variance)**-1.).astype('W')
# DH: as W, with the additional amplitude-dependent term in the denominator
output['DH'] = ((1 + relative_variance + relative_amplitude)**-1.).astype('W')
# ML: plain inverse variance
output['ML'] = (output.SigDeltaF**-2.).astype('W')

from scipy.stats import gaussian_kde
from scipy.stats import norm


def _resolution_normal_pdf(amplitudes, dHKL):
    """Evaluate a normal pdf at `amplitudes` using resolution-dependent moments.

    The mean and the standard deviation are both obtained from
    mean_intensity_by_resolution: first on the amplitudes themselves, then on
    their squared deviations from that mean.
    """
    local_mean = mean_intensity_by_resolution(amplitudes, dHKL)
    local_std = np.sqrt(mean_intensity_by_resolution((amplitudes - local_mean)**2., dHKL))
    return norm(local_mean, local_std).pdf(amplitudes)


# Resolution of every reflection, as a plain array.
d = output.compute_dHKL().dHKL.to_numpy()

poff = _resolution_normal_pdf(output.Foff, d)
pon = _resolution_normal_pdf(output.Fon, d)

# RNORM weight: product of the on and off densities.
output['RNORM'] = rs.DataSeries(pon*poff, dtype='W', index=output.index)

# KDE weight: Gaussian kernel density estimate over (Foff, Fon, dHKL),
# evaluated at each reflection's own coordinates.
# The builtin `float` replaces the `np.float` alias, which was removed in
# NumPy 1.24 and made this line raise AttributeError.
X = output.compute_dHKL()[['Foff', 'Fon', 'dHKL']].to_numpy(dtype=float).T
# Standardize each feature (row) to zero mean and unit standard deviation.
X = (X - X.mean(1)[:,None]) / X.std(1)[:,None]
k = gaussian_kde(X)
output['KDE'] = rs.DataSeries(k(X), dtype='W', index=output.index)

# Rescale the differences by the ratio of the resolution-dependent averages
# (from mean_intensity_by_resolution) of the model and off amplitudes.
dHKL = output.compute_dHKL().dHKL
model_profile = mean_intensity_by_resolution(output['F-model'], dHKL)
off_profile = mean_intensity_by_resolution(output['Foff'], dHKL)
output['ScaledDeltaF'] = output['DeltaF'] * model_profile / off_profile

# Infer MTZ column types for any untyped columns and write the result.
typed_output = output.infer_mtz_dtypes()
typed_output.write_mtz(outFN)

# With --embed, open an interactive IPython session once the output MTZ has
# been written.
if parser.embed:
    # Imported here so IPython and pylab are only required when --embed is used.
    from IPython import embed
    # NOTE(review): the star import adds every public pylab name to this
    # namespace and may shadow script variables of the same name; presumably
    # this is for interactive convenience only.
    from pylab import *
    embed(colors='Linux')
