#!/usr/bin/env python3
#
# glimps-map: map snapshots in a fine-grained trajectory to a
#             coarse-grained representation.
#
#             Invoke as:
#
#               glimps-map <fg_traj> <cg_traj> <molspecs>
#
#             where <fg_traj> (existing) and <cg_traj> (created) are in 
#             any MDTraj-supported format and <molspecs> is a molecule 
#             specification file (use "glimps-map -h" to see additional 
#             options)
#
#             The molecule specification file should contain rows with the
#             format:
#
#               <molid> <n_copies>
#
#             where:
#               <molid>:    name of component (matching that used in training)
#               <n_copies>: copies of this component in the trajectories
#
#             The order of the rows must match the order of the components
#             in the trajectory file.
#


import argparse
import json
import os.path as op
import pickle

try:
    import mdtraj as mdt
except ImportError:
    print('You need to install MDTraj ("pip install mdtraj") to use glimps-map')
    exit(1)
import numpy as np
import pandas as pd

from mdplus.multiscale import Glimps
from mdplus.refinement import fix_bumps

def _read_molspecs(path):
    """
    Parse a molecule specification file into an ordered list of specs.

    Each data row has the form "<molid> <n_copies>". Blank lines, rows whose
    first token starts with "#", and rows with fewer than two tokens are
    skipped. Rows are returned in file order, and a molid that appears on
    several rows yields several entries (a dict keyed by molid would silently
    merge them and scramble the atom ordering).

    Raises:
        ValueError: if <n_copies> is not an integer or is less than 1.
    """
    molspecs = []
    with open(path) as f:
        for line in f:
            words = line.split()
            if len(words) < 2 or words[0].startswith("#"):
                continue
            n_copies = int(words[1])
            if n_copies < 1:
                raise ValueError(
                    'Error: invalid number of copies ({}) for molecule {}'.format(
                        n_copies, words[0]))
            molspecs.append({"molid": words[0], "n_copies": n_copies})
    return molspecs


def _load_mapper(datadir, molid):
    """
    Load the mapper data for one molecule type from the database folder.

    "<molid>.json" is preferred; "<molid>.pkl" is the fallback. Either way the
    result is a dict with the keys "transformer", "cg_topology" and
    "fg_topology" (the topologies as pandas DataFrames in the JSON case).

    Raises:
        ValueError: if neither file exists in datadir.
    """
    jsonfile = op.join(datadir, molid + '.json')
    if op.exists(jsonfile):
        with open(jsonfile) as f:
            data = json.load(f)
        transformer = Glimps()
        # NOTE(review): assumes Glimps.set_state() rebuilds a fitted
        # transformer from the JSON-serialised state. Both conventions are
        # tolerated here (mutate in place and return None, or return the
        # restored object) - confirm against the installed mdplus version.
        restored = transformer.set_state(data['transformer'])
        data['transformer'] = transformer if restored is None else restored
        data['cg_topology'] = pd.DataFrame.from_dict(data['cg_topology'])
        data['fg_topology'] = pd.DataFrame.from_dict(data['fg_topology'])
        return data

    pklfile = op.join(datadir, molid + '.pkl')
    if not op.exists(pklfile):
        raise ValueError('Error: no mapper for molecule {} in the database folder'.format(molid))
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point --datadir at mapper files from a trusted source.
    with open(pklfile, "rb") as f:
        return pickle.load(f)


def map(args):
    """
    Forward mapping (FG -> CG)

    Reads the molecule specification file, loads the mapper for every
    molecule type listed, loads the fine-grained trajectory, applies each
    mapper's inverse_transform() to the matching block of atoms and writes
    the resulting coarse-grained trajectory.

    Args:
        args: namespace with the attributes
            intraj  - input (fine-grained) trajectory file
            outtraj - output (coarse-grained) trajectory file
            molspec - molecule specification file
            top     - optional topology file for intraj (may be None)
            datadir - directory holding the <molid>.json / <molid>.pkl mappers

    Raises:
        ValueError: if the specification file lists no molecules, a mapper is
            missing, or the trajectory has fewer atoms than the specification
            requires.
    """
    # Validate the cheap inputs first so that a bad specification or a missing
    # mapper is reported before the (potentially slow) trajectory load.
    molspecs = _read_molspecs(args.molspec)
    if not molspecs:
        raise ValueError(
            'Error: no molecules found in specification file {}'.format(args.molspec))

    # One mapper per distinct molid, even if the molid appears on several rows.
    mappers = {}
    for molspec in molspecs:
        molid = molspec["molid"]
        if molid not in mappers:
            mappers[molid] = _load_mapper(args.datadir, molid)

    print('loading input structure')
    if args.top:
        intraj = mdt.load(args.intraj, top=args.top)
    else:
        intraj = mdt.load(args.intraj)
    n_frames = intraj.n_frames

    cg_top = None
    for molspec in molspecs:
        molid = molspec["molid"]
        print('generating topology for {}'.format(molid))
        data = mappers[molid]

        cg_t = mdt.Topology.from_dataframe(data["cg_topology"])
        fg_t = mdt.Topology.from_dataframe(data["fg_topology"])
        molspec["transformer"] = data["transformer"]
        molspec["n_fg"] = fg_t.n_atoms
        molspec["n_cg"] = cg_t.n_atoms

        # CG topology for all copies of this molecule, appended to the system.
        cg_mol_top = cg_t.copy()
        for _ in range(molspec["n_copies"] - 1):
            cg_mol_top = cg_mol_top.join(cg_t)
        cg_top = cg_mol_top if cg_top is None else cg_top.join(cg_mol_top)

    # Renumber residues consecutively across the whole CG system.
    for i, r in enumerate(cg_top.residues):
        r.resSeq = i + 1

    n_fg_total = sum(m["n_fg"] * m["n_copies"] for m in molspecs)
    if n_fg_total > intraj.n_atoms:
        raise ValueError(
            'Error: molecule specification requires {} atoms but the '
            'trajectory has only {}'.format(n_fg_total, intraj.n_atoms))

    outxyz = np.zeros((n_frames, cg_top.n_atoms, 3), dtype=np.float32)
    j_fg = 0
    j_cg = 0
    for molspec in molspecs:
        print('transforming {}'.format(molspec["molid"]))
        n_copies = molspec["n_copies"]
        # [i_fg, j_fg) and [i_cg, j_cg) are the atom ranges occupied by all
        # copies of this molecule in the FG input and the CG output.
        i_fg = j_fg
        j_fg = i_fg + molspec["n_fg"] * n_copies
        i_cg = j_cg
        j_cg = i_cg + molspec["n_cg"] * n_copies

        # Treat every copy in every frame as an independent sample for the
        # transformer, then fold the copies back into their frames.
        x_fg = intraj.xyz[:, i_fg:j_fg].reshape((n_frames * n_copies, molspec["n_fg"], 3))
        x_cg = molspec["transformer"].inverse_transform(x_fg).reshape((n_frames, -1, 3))
        outxyz[:, i_cg:j_cg] = x_cg

    print('writing output')
    outtraj = mdt.Trajectory(outxyz, cg_top, time=intraj.time,
                             unitcell_lengths=intraj.unitcell_lengths,
                             unitcell_angles=intraj.unitcell_angles)
    outtraj.save(args.outtraj)

def main():
    """Parse the command line and run the FG -> CG mapping."""
    parser = argparse.ArgumentParser()
    parser.add_argument('intraj', help='input trajectory file')
    parser.add_argument('outtraj', help='output trajectory file')
    parser.add_argument('molspec', help='molecule specification file')
    parser.add_argument('--top', help='input topology file')
    # Mappers are looked up as <molid>.json first, then <molid>.pkl.
    parser.add_argument('--datadir', default='.',
                        help='directory with mapper (.json or .pkl) files')

    args = parser.parse_args()
    map(args)


# Only run when executed as a script, so importing this file has no side effects.
if __name__ == '__main__':
    main()
