#!/usr/bin/env python

"""
Plot CChalf vs resolution. 

Usage
-----
ccplot <half1_dataset_filename> <half2_dataset_filename>

"""
import reciprocalspaceship as rs
import numpy as np
from careless.io.formatter import get_first_key_of_dtype
from matplotlib import pyplot as plt
import matplotlib as mpl
from sys import argv


mpl.rcParams['font.size'] = 14

# Exit with the usage line rather than an IndexError traceback when one or
# both filenames are missing. Extra arguments are still ignored, as before.
if len(argv) < 3:
    raise SystemExit(
        "Usage: ccplot <half1_dataset_filename> <half2_dataset_filename>"
    )

half1_filename = argv[1]
half2_filename = argv[2]


def _stack_if_anomalous(ds):
    """
    Stack Friedel-pair columns into single columns when the dataset has any.

    A dataset is treated as anomalous when any of its columns has dtype 'K'
    or 'G' (the MTZ column types for I(+)/I(-) and F(+)/F(-)). Such a dataset
    is passed through ``stack_anomalous()`` and only rows with ``N > 0`` are
    kept (presumably N counts observations -- confirm against the writer of
    these files). Any other dataset is returned unchanged.
    """
    if (ds.dtypes == 'K').any() or (ds.dtypes == 'G').any():
        ds = ds.stack_anomalous()
        ds = ds[ds.N > 0]
    return ds


# Read each half-dataset, add the dHKL resolution column that the binning
# further down relies on, and drop every row containing a missing value.
half1 = rs.read_mtz(half1_filename).compute_dHKL().dropna()
half2 = rs.read_mtz(half2_filename).compute_dHKL().dropna()

half1 = _stack_if_anomalous(half1)
half2 = _stack_if_anomalous(half2)

# Column names that, when present in a dataset, take precedence over the
# dtype-based column choice made in get_isigi_keys.
overrides = [
    'DeltaF',
]

def get_isigi_keys(ds, dtype_preferences=None):
    """
    Pick the data column to analyse and its matching uncertainty column.

    Parameters
    ----------
    ds
        Dataset to inspect. It is handed to ``get_first_key_of_dtype`` and
        must support ``name in ds`` membership tests on column names.
    dtype_preferences : list of str, optional
        Column dtype codes to try, in priority order. Defaults to
        ``['J', 'F']``.

    Returns
    -------
    tuple
        ``(ikey, sigkey)`` where ``sigkey`` is ``"Sig" + ikey`` if that
        column exists and ``"SIG" + ikey`` otherwise.

    Raises
    ------
    KeyError
        If no data column could be selected, or if neither spelling of the
        uncertainty column is present in ``ds``.
    """
    preferred = ['J', 'F'] if dtype_preferences is None else dtype_preferences

    # Probe the dtypes lazily and in order, stopping at the first lookup
    # that yields something other than None.
    probed = (get_first_key_of_dtype(ds, dtype) for dtype in preferred)
    ikey = next((found for found in probed if found is not None), None)

    # An override column that is present beats the dtype-based choice; when
    # several are present the last one listed in `overrides` wins.
    present = [name for name in overrides if name in ds]
    if present:
        ikey = present[-1]

    if ikey is None:
        raise KeyError("No compatible keys found")

    print(f"choosing intensity key {ikey} ...")

    # Accept either capitalisation of the sigma prefix, mixed case first.
    for sigkey in (f"Sig{ikey}", f"SIG{ikey}"):
        if sigkey in ds:
            return ikey, sigkey

    raise KeyError(f"No matching std deviation key for intensity key {ikey}")

# The progress bar is only used by the binning loop in this section.
from tqdm import tqdm

# Choose the data column and its uncertainty column separately for each half.
ikey1, sigkey1 = get_isigi_keys(half1)
ikey2, sigkey2 = get_isigi_keys(half2)

nbins = 20

# Only reflections present in both halves can be correlated.
idx = half1.index.intersection(half2.index)
if len(idx) == 0:
    # Without this check the failure surfaces later, inside the percentile
    # and correlation calls, with a far less helpful message.
    raise ValueError("The two half datasets have no reflections in common")
half1 = half1.loc[idx]
half2 = half2.loc[idx]

# Equal-population resolution bins: percentile edges of dHKL running from the
# largest spacing (100th percentile) down to the smallest (0th percentile).
# After vstack, row 0 holds each bin's upper edge and row 1 its lower edge.
bins = np.percentile(half1.dHKL, np.linspace(100, 0, nbins + 1))
bins = np.vstack((bins[:-1], bins[1:]))
colors = np.zeros(len(half1))

# Bin k is coloured k / (nbins - 1), the same value the scatter plot below
# assigns to its k-th point, so both figures share one colour per bin.
color_scale = max(nbins - 1, 1)
last_bin = nbins - 1

cc = []
ticks = []
for binnumber, (i, j) in tqdm(enumerate(zip(*bins)), total=nbins):
    # i is the upper dHKL edge and j the lower one. Bins are half-open,
    # (j, i], so a reflection sitting exactly on a shared edge is counted
    # once; only the last bin also includes its lower edge, which is the
    # minimum dHKL.
    if binnumber == last_bin:
        idx = (half1.dHKL >= j) & (half1.dHKL <= i)
    else:
        idx = (half1.dHKL > j) & (half1.dHKL <= i)
    idx = idx.astype(bool)
    colors[idx] = binnumber / color_scale
    cc.append(np.corrcoef(half1[idx][ikey1], half2[idx][ikey2])[0, 1])
    ticks.append(f'{i:0.2f}-{j:0.2f}')

# Figure 1: correlation coefficient per resolution bin.
plt.ylim(0., 1.)

plt.grid(linestyle='-.')

plt.plot(cc, '-k')
# vmin/vmax pin the colour normalisation to [0, 1] so that it matches the
# values stored in `colors` for the error-bar figure.
plt.scatter(
    np.arange(nbins),
    cc,
    c=np.linspace(0, 1, nbins),
    s=100,
    vmin=0,
    vmax=1,
)
plt.title(r"$|F|$ Pearson Correlation")

plt.xticks(np.arange(nbins), ticks, rotation=45, ha='right', rotation_mode='anchor')
plt.xlabel(r"$Resolution\ (\AA)$")
plt.ylabel(r"$CC$")
plt.tight_layout()

# Figure 2: half 1 against half 2 for every shared reflection, drawn as
# error bars only and coloured by resolution bin.
plt.figure()
cmap = plt.get_cmap()
plt.errorbar(
    half1[ikey1].to_numpy('float32'),
    half2[ikey2].to_numpy('float32'),
    xerr=half1[sigkey1].to_numpy('float32'),
    yerr=half2[sigkey2].to_numpy('float32'),
    ecolor=cmap(colors),
    ls='none',
    alpha=0.2,
)
plt.xlabel(f"$|{ikey1}_1|$")
plt.ylabel(f"$|{ikey2}_2|$")
plt.show()

#from IPython import embed
#embed(colors='Linux')
