#!/usr/bin/env python3
#
# glimps-backmap: backmap snapshots in a coarse-grained trajectory to a
#                 fine-grained representation.
#
#                 Invoke as:
#
#                   glimps-backmap <cg_traj> <fg_traj> <molspecs>
#
#                 where <cg_traj> (existing) and <fg_traj> (created) are in
#                 any MDTraj-supported format and <molspecs> is a molecule
#                 specification file (use "glimps-backmap -h" to see additional
#                 options).
#
#                 The molecule specification file should contain rows with the
#                 format:
#
#                   <molid> <n_copies>
#
#                 where:
#                   <molid>:    name of the component (matching the name used in training)
#                   <n_copies>: number of copies of this component in the trajectories
#
#                 The order of the rows must match the order of the components
#                 in the trajectory file.
#


import argparse
import pickle
import json
import os.path as op
try:
    import mdtraj as mdt
except ImportError:
    print('You need to install MDTraj ("pip install mdtraj") to use glimps-backmap')
    exit(1)
import numpy as np
from mdplus.refinement import fix_bumps
from mdplus.multiscale import Glimps
try:
    import pandas as pd
except ImportError:
    print('You need to install Pandas ("pip install pandas") to use glimps-backmap')
    exit(1)

def _read_molspecs(path):
    """
    Parse a molecule specification file into an ordered list of rows.

    Each data row has the form ``<molid> <n_copies>``. Lines whose first
    field starts with ``#`` are comments; lines with fewer than two fields
    are ignored. Rows are kept in file order and are NOT merged by molid,
    so the same component may appear in more than one row when the
    trajectory contains it in several separate blocks.

    Args:
        path: path of the molecule specification file.

    Returns:
        list of dict, one ``{"molid": str, "n_copies": int}`` per row.

    Raises:
        ValueError: if a copy count is not a non-negative integer.
    """
    rows = []
    with open(path) as f:
        for lineno, line in enumerate(f, start=1):
            words = line.split()
            if len(words) < 2 or words[0].startswith('#'):
                continue
            try:
                n_copies = int(words[1])
            except ValueError:
                raise ValueError(
                    'Error: {} line {}: copy count "{}" is not an integer'.format(
                        path, lineno, words[1])) from None
            if n_copies < 0:
                raise ValueError(
                    'Error: {} line {}: copy count {} is negative'.format(
                        path, lineno, n_copies))
            rows.append({"molid": words[0], "n_copies": n_copies})
    return rows


def _load_mapper(datadir, molid):
    """
    Load the trained backmapping data for one molecule type.

    ``<molid>.json`` in *datadir* is preferred; ``<molid>.pkl`` is the
    fallback. For the JSON format the stored transformer state is restored
    into a new ``Glimps`` object and both topologies are converted to
    pandas DataFrames. The pickle is returned as-is and is assumed to have
    the same layout (keys "transformer", "cg_topology", "fg_topology").

    Args:
        datadir: directory holding the mapper files.
        molid: component name; used as the file stem.

    Returns:
        dict with keys "transformer", "cg_topology" and "fg_topology".

    Raises:
        ValueError: if neither mapper file exists.
    """
    jsonfile = op.join(datadir, molid + '.json')
    if op.exists(jsonfile):
        with open(jsonfile) as f:
            data = json.load(f)
        transformer = Glimps()
        transformer.set_state(data['transformer'])
        data['transformer'] = transformer
        data['cg_topology'] = pd.DataFrame.from_dict(data['cg_topology'])
        data['fg_topology'] = pd.DataFrame.from_dict(data['fg_topology'])
        return data

    pklfile = op.join(datadir, molid + '.pkl')
    if not op.exists(pklfile):
        raise ValueError('Error: no mapper for molecule {} in the database folder'.format(molid))
    # NOTE(security): pickle.load can execute arbitrary code embedded in
    # the file; only use .pkl mapper files obtained from a trusted source.
    with open(pklfile, "rb") as f:
        return pickle.load(f)


def map(args):
    """
    Backmap CG -> FG.

    Reads the coarse-grained trajectory ``args.intraj`` (with the optional
    topology file ``args.top``). It then builds a fine-grained topology
    from the rows of the molecule specification file ``args.molspec`` and
    the mapper data in ``args.datadir``, transforms every frame, applies
    ``fix_bumps`` and writes the result to ``args.outtraj``.

    CG atoms are consumed from the start of the input trajectory in the
    order of the specification rows. Each row takes ``n_copies`` consecutive
    groups of ``n_cg`` atoms, where ``n_cg`` is the length of that
    component's CG topology. Rows with 0 copies are skipped.

    Note: the name shadows the builtin ``map`` within this module; it is
    kept for compatibility with existing callers.

    Args:
        args: namespace with ``intraj``, ``outtraj``, ``molspec``, ``top``
            and ``datadir`` attributes.

    Raises:
        ValueError: if a mapper file is missing, the specification lists
            no molecules, or the trajectory has fewer CG atoms than the
            specification requires.
    """
    print('loading input structure')
    if args.top:
        intraj = mdt.load(args.intraj, top=args.top)
    else:
        intraj = mdt.load(args.intraj)
    n_frames = intraj.n_frames

    # An ordered list (not a dict keyed by molid), so a component that
    # occurs in several separate blocks of the trajectory is handled.
    molspecs = []
    for row in _read_molspecs(args.molspec):
        if row["n_copies"] == 0:
            print('skipping {} (0 copies)'.format(row["molid"]))
        else:
            molspecs.append(row)
    if not molspecs:
        raise ValueError('Error: no molecules listed in {}'.format(args.molspec))

    mappers = {}  # molid -> mapper data, loaded once even if the molid repeats
    fg_top = None
    seqoff = 0
    for molspec in molspecs:
        molid = molspec["molid"]
        print('generating topology for {}'.format(molid))
        if molid not in mappers:
            mappers[molid] = _load_mapper(args.datadir, molid)
        data = mappers[molid]

        fg_d = data["fg_topology"]
        molspec["transformer"] = data["transformer"]
        molspec["n_fg"] = len(fg_d)
        molspec["n_cg"] = len(data["cg_topology"])

        # Offset resSeq/chainID/serial per copy so the copies stay distinct
        # inside the concatenated dataframe (every copy gets its own
        # chainID). resSeq and serial are renumbered globally further down.
        topologies = []
        seroff = 0
        for _ in range(molspec["n_copies"]):
            t = fg_d.copy()
            t.loc[:, 'resSeq'] += seqoff
            t.loc[:, 'chainID'] += seqoff
            t.loc[:, 'serial'] += seroff
            seqoff += 1
            seroff += len(t)
            topologies.append(t)
        fulltop = pd.concat(topologies, ignore_index=True)
        fg_mol_top = mdt.Topology.from_dataframe(fulltop, None)
        fg_top = fg_mol_top if fg_top is None else fg_top.join(fg_mol_top)

    for i, r in enumerate(fg_top.residues):
        r.resSeq = i + 1
    for i, a in enumerate(fg_top.atoms):
        a.serial = i + 1

    # Validate up front instead of failing inside a reshape.
    n_cg_needed = sum(m["n_cg"] * m["n_copies"] for m in molspecs)
    if n_cg_needed > intraj.n_atoms:
        raise ValueError(
            'Error: molecule specification needs {} CG atoms but {} has only {}'.format(
                n_cg_needed, args.intraj, intraj.n_atoms))
    if n_cg_needed < intraj.n_atoms:
        print('Warning: the last {} CG atoms of {} are not covered by the '
              'molecule specification and will be ignored'.format(
                  intraj.n_atoms - n_cg_needed, args.intraj))

    outxyz = np.zeros((n_frames, fg_top.n_atoms, 3), dtype=np.float32)
    j_fg = 0
    j_cg = 0
    for molspec in molspecs:
        print('transforming {}'.format(molspec["molid"]))
        n_copies = molspec["n_copies"]
        i_fg, j_fg = j_fg, j_fg + molspec["n_fg"] * n_copies
        i_cg, j_cg = j_cg, j_cg + molspec["n_cg"] * n_copies

        # (n_frames, n_copies * n_cg, 3) -> (n_frames * n_copies, n_cg, 3):
        # one row per (frame, copy) pair, transformed in a single call.
        x_cg = intraj.xyz[:, i_cg:j_cg].reshape((n_frames * n_copies, molspec["n_cg"], 3))
        x_fg = molspec["transformer"].transform(x_cg).reshape((n_frames, -1, 3))
        outxyz[:, i_fg:j_fg] = x_fg

    print('fixing bumps')
    # NOTE(review): the meaning of 0.1 / 0.15 is defined by
    # mdplus.refinement.fix_bumps; confirm there before changing them.
    outxyz = fix_bumps(outxyz, 0.1, 0.15)
    print('writing output')
    outtraj = mdt.Trajectory(outxyz, fg_top, time=intraj.time,
                             unitcell_lengths=intraj.unitcell_lengths,
                             unitcell_angles=intraj.unitcell_angles)
    outtraj.save(args.outtraj)

if __name__ == "__main__":
    # Parse the command line and run only when executed as a script, so
    # importing this module neither consumes sys.argv nor starts a backmap.
    parser = argparse.ArgumentParser()
    parser.add_argument('intraj', help='input trajectory file')
    parser.add_argument('outtraj', help='output trajectory file')
    parser.add_argument('molspec', help='molecule specification file')
    parser.add_argument('--top', help='input topology file')
    # Mappers are looked up as <molid>.json first, then <molid>.pkl.
    parser.add_argument('--datadir', default='.',
                        help='directory with .json/.pkl mapper files')

    args = parser.parse_args()
    map(args)
