#!/usr/bin/env python
import mdtraj as md
import numpy as np
import sys
import multiprocessing as mp
from optparse import OptionParser
from optparse import OptionGroup

# command-line options
def parse():
    """Parse the command-line options of the contact-analysis script.

    The topology (-p) and trajectory (-f) files are compulsory: when either
    one is missing the parser prints the usage message and exits instead of
    handing the placeholder value 'nil' on to the caller.  Every other
    option has a default.

    Returns:
        optparse.Values: parsed options with the attributes ``p`` and ``f``
        (input files), ``n`` (number of processes), ``a`` (percentage
        cut-off), ``c`` (distance cut-off in angstrom) and ``x``, ``y``,
        ``z``, ``o`` (output file names).

    Raises:
        SystemExit: raised by optparse for invalid options, for -h/--help,
        or when a compulsory argument is missing.
    """
    usage = "usage: %prog -p TOPOLOGY -f TRAJECTORY [options]"
    parser = OptionParser(usage=usage)

    parser.add_option(
        "-p", "--parmtop", type="string", dest="p", default='nil',
        help="parameter/topology file [compulsory argument 1] "
             "(.prmtop/.psf/.gro/.pdb or any topology format compatible "
             "with MDTRAJ python module)")
    parser.add_option(
        "-f", "--trajectory", type="string", dest="f", default='nil',
        help="coordinates/trajectory file [compulsory argument 2] "
             "(.nc/.xtc/.dcd/.gro/.pdb any coordinate format compatible "
             "with MDTRAJ python module)")

    group = OptionGroup(
        parser, "Suggested Options",
        "Note: choosing a higher number of processors when available can "
        "speed up the calculation.")

    group.add_option(
        "-n", "--nproc", type="int", dest="n", default=1,
        help="number of processors to run the script (default: 1)")
    group.add_option(
        "-a", "--cutpercent", type="float", dest="a", default=75,
        help="cut-off for percentage of frames for defining contacts "
             "[percentage without % symbol] (default: 75)")
    # "--cutfoffdist" is the historical (misspelled) name; it is kept as an
    # alias so existing command lines keep working.
    group.add_option(
        "-c", "--cutoffdist", "--cutfoffdist", type="float", dest="c",
        default=4.5,
        help="heavy atom cut-off distance for calculating contacts "
             "[in angstrom] (default: 4.5)")
    group.add_option(
        "-x", "--contactname", type="string", dest="x",
        default='contactResNames.dat',
        help="output file with all residue names [serial format] "
             "(default: contactResNames.dat)")
    group.add_option(
        "-y", "--contactmatrixframes", type="string", dest="y",
        default='contactMatrixResFrames.dat',
        help="output file of matrix with number of frames in contact "
             "(default: contactMatrixResFrames.dat)")
    group.add_option(
        "-z", "--contactmatrixfraction", type="string", dest="z",
        default='contactMatrixFraction.dat',
        help="output file of matrix with fraction of frames in contact "
             "[weighted adjacency matrix] "
             "(default: contactMatrixFraction.dat)")
    group.add_option(
        "-o", "--contact", type="string", dest="o", default='contact.dat',
        help="output file of matrix satisfying given percent cutoff "
             "[unweighted adjacency matrix] (default: contact.dat)")
    parser.add_option_group(group)

    options, _arguments = parser.parse_args()

    # 'nil' is the placeholder default of the two compulsory inputs.
    missing = [flag
               for flag, value in (("-p/--parmtop", options.p),
                                   ("-f/--trajectory", options.f))
               if value == 'nil']
    if missing:
        parser.error("missing compulsory argument(s): " + ", ".join(missing))
    return options

"""
pairs_dist() builds every possible heavy-atom pair between two residues and
computes the distance of each pair in all frames. The returned array has one
row per frame and one column per atom pair.
"""
def pairs_dist(topo, traj, res1, res2):
    """Return the heavy-atom pair distances between two residues.

    Every atom of ``res1`` whose element name is not 'hydrogen' is paired
    with every such atom of ``res2``; the distances of all these pairs are
    then evaluated over ``traj`` with ``md.compute_distances`` (periodic
    and optimised options switched on).

    ``topo`` is accepted so the signature matches the other helpers, but it
    is not used in this function.
    """
    pair_indices = np.array([
        [first.index, second.index]
        for first in res1.atoms if first.element.name != 'hydrogen'
        for second in res2.atoms if second.element.name != 'hydrogen'
    ])
    return md.compute_distances(traj, pair_indices, periodic=True, opt=True)

"""
cut_dist() takes the smallest heavy-atom distance between two residues in each
frame and counts the number of frames in which that distance satisfies the
cut-off condition.
"""
def cut_dist(topo, traj, res1, res2, cut):
    """Count the frames in which two residues are within ``cut`` of each other.

    The per-frame minimum of the pair distances returned by ``pairs_dist``
    is compared with ``cut``.  Returns the number of frames that pass, as
    the NumPy integer produced by summing the per-frame 0/1 flags.
    """
    closest_per_frame = pairs_dist(topo, traj, res1, res2).min(axis=1)
    # Negating both sides turns "distance <= cut" into "value >= bin edge",
    # which is what digitize reports as bin 1 for a single edge.
    in_contact = np.digitize(-closest_per_frame, bins=[-cut])
    return in_contact.sum()
         
def iloop(i, topo, traj, upto, n_apart, Norm, cut):
    """Collect the contact counts between residue ``i`` and later residues.

    Residue ``i`` is paired with every residue index from ``i + n_apart`` up
    to (but excluding) ``upto``.  For each pair one row
    ``[i, residue_i, j, residue_j, n_frames_in_contact]`` is produced, where
    the last entry comes from ``cut_dist``.  ``Norm`` is accepted but not
    used here.

    Returns the list of rows; it is empty when ``i + n_apart >= upto``.
    """
    anchor = topo.residue(i)
    rows = []
    for partner_index in range(i + n_apart, upto):
        partner = topo.residue(partner_index)
        n_contact = cut_dist(topo, traj, anchor, partner, cut)
        rows.append([i, anchor, partner_index, partner, n_contact])
    return rows

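"""
find_weight() distributes the residue pairs over a pool of worker processes.
Each worker runs iloop(), which calls cut_dist() for one pair of residues at a
time and records the number of frames in contact. The percentage cut-off that
identifies the contacts is applied afterwards in main().
"""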
def find_weight(topo, traj, upto, n_apart, nproc, cut, Norm):
    """Compute the contact counts of all residue pairs in parallel.

    One ``iloop`` task is created for every starting residue index in
    ``range(upto - n_apart)`` and the tasks are spread over ``nproc``
    worker processes.

    Args:
        topo: topology object handed on to ``iloop``.
        traj: trajectory object handed on to ``iloop``.
        upto: exclusive upper bound of the residue indices.
        n_apart: minimum index separation between the residues of a pair.
        nproc: number of worker processes in the pool.
        cut: distance cut-off handed on to ``iloop``.
        Norm: handed on to ``iloop`` unchanged.

    Returns:
        list: one entry per starting residue, each being the list of rows
        returned by ``iloop`` for that residue.
    """
    tasks = [(i, topo, traj, upto, n_apart, Norm, cut)
             for i in range(upto - n_apart)]
    # starmap blocks until every task has finished, so leaving the context
    # manager afterwards only releases the worker processes, which the
    # original code never did.
    with mp.Pool(nproc) as pool:
        return pool.starmap(iloop, tasks)

def main():
    """Run the contact analysis and write all output files.

    Steps:
      1. read the command-line options and load the trajectory;
      2. count, for every residue pair, the frames in which the pair is in
         contact (``find_weight``);
      3. write the per-pair listings and three matrices: frame counts,
         fraction of frames, and the 0/1 adjacency matrix obtained by
         applying the percentage cut-off to the fractions.

    The per-pair listings go to the file named by option ``x`` and to the
    fixed file 'contact_num_only.dat'; the matrices go to the files named
    by options ``y``, ``z`` and ``o``.
    """
    options = parse()
    print ('Reading trajectory....')

    f_in1 = str(options.p)
    f_in2 = str(options.f)

    traj = md.load(f_in2, top=f_in1)
    topo = traj.topology

    cut = options.c/10.0 # converted to nanometer
    cutpercent = options.a/100.0
    nproc = options.n

    f1name = str(options.x)
    f3name = str(options.y)
    f4name = str(options.z)
    f5name = str(options.o)

    Norm = float(traj.n_frames)
    upto = topo.n_residues
    n_apart = 1
    results = find_weight(topo,traj,upto,n_apart,nproc,cut,Norm)

    cmat = np.zeros([upto,upto])
    # The context manager closes both listings even if writing fails.
    with open(f1name, 'w') as f1, open('contact_num_only.dat', 'w') as f2:
        # Header labels follow the order of the columns written below.
        print ('#res1_index   res1    res2_index    res2     n_frames_contact    fraction', file = f1)
        for rows in results:
            for res1_index, res1, res2_index, res2, count in rows:
                fract = round(count/Norm, 2)
                print(res1_index, res1, res2_index, res2, count, fract, file=f1)
                print(res1_index, res2_index, count, fract, file=f2)
                # The contact matrix is symmetric: fill both triangles.
                cmat[res1_index, res2_index] = count
                cmat[res2_index, res1_index] = count

    cmat_fract = cmat/Norm
    # With a single bin edge digitize yields 1 where the fraction reaches
    # the cut-off and 0 elsewhere.
    cmat_adj = np.digitize(cmat_fract,bins=[cutpercent])

    np.savetxt(f3name,cmat,fmt='%d')
    np.savetxt(f4name,cmat_fract,fmt='%.2f')
    np.savetxt(f5name,cmat_adj,fmt='%d')

# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

