#! /usr/bin/env python
import numpy as np
import json
import pandas as pd
import benchml
import csv
log = benchml.log

def csv_to_xyz(csv_file, args, xyz_file=""):
    """Build one structure config per CSV row from the row's SMILES string.

    Args:
        csv_file: Path of the CSV file to read; the first row is the header.
        args: Parsed options. Reads ``args.smiles`` (name of the SMILES
            column), ``args.no_gen3d`` (truthy: use zero-coordinate
            placeholder structures instead of running the external tools),
            and, when 3D generation is requested, ``args.corina`` and
            ``args.babel``.
        xyz_file: If non-empty, the configs are also written to this path.

    Returns:
        The configs. Each ``config.info`` is the CSV row dict plus a
        ``"csv_fieldnames"`` entry holding the CSV header.

    Raises:
        KeyError: if a row has no ``args.smiles`` column.
        AssertionError: if the number of configs differs from the number
            of rows.
    """
    log << log.mg << "Converting csv to extxyz format" << log.endl
    # Read everything inside a with-block so the handle is closed promptly;
    # newline="" lets the csv module do its own line-ending handling.
    with open(csv_file, newline="") as ifs:
        reader = csv.DictReader(ifs)
        data = list(reader)
        fieldnames = reader.fieldnames
    smiles = [row[args.smiles] for row in data]
    if args.no_gen3d:
        configs = list(smiles_to_pseudo_xyz(smiles))
    else:
        configs = smiles_to_xyz(smiles, args.corina, args.babel)
    assert len(configs) == len(data)
    for config, meta in zip(configs, data):
        config.info = meta
        # Keep the header so xyz_to_csv can restore the column order.
        config.info["csv_fieldnames"] = fieldnames
    if xyz_file != "":
        log << "Writing structures to" << xyz_file << log.endl
        benchml.write(xyz_file, configs)
    return configs

def smiles_to_pseudo_xyz(smiles):
    """Yield a placeholder config for each SMILES string.

    The element symbols come from the atoms of the RDKit molecule parsed
    from the SMILES; every position is set to the origin, so the configs
    carry composition only and no real geometry.

    Args:
        smiles: Iterable of SMILES strings.

    Yields:
        One ``benchml.readwrite.ExtendedXyz`` per input string, in order.

    Raises:
        ValueError: if RDKit cannot parse one of the strings.
    """
    import rdkit.Chem as Chem
    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        # MolFromSmiles signals a parse failure by returning None; fail with
        # the offending string instead of an AttributeError on None.
        if mol is None:
            raise ValueError("Could not parse SMILES %r" % (smi,))
        symbols = [a.GetSymbol() for a in mol.GetAtoms()]
        pos = np.zeros((len(symbols), 3))
        yield benchml.readwrite.ExtendedXyz(pos=pos, symbols=symbols)

def smiles_to_xyz(smiles, corina, babel):
    """Generate structures for SMILES strings via external corina and babel.

    The SMILES are written one per line to ``tmp/tmp.smi``; the corina
    command turns that into ``tmp/tmp.sdf`` and the babel command turns the
    SDF into ``tmp/tmp.xyz``, which is read back with ``benchml.read``.
    Commands are handed to ``log >>`` (presumably benchml's shell-out
    operator -- confirm against benchml).

    Args:
        smiles: Sequence of SMILES strings.
        corina: Path of the corina executable.
        babel: Path of the babel executable.

    Returns:
        The configs read from ``tmp/tmp.xyz``.

    Raises:
        AssertionError: if the number of configs read differs from
            ``len(smiles)``.
    """
    purge_scratch = 'rm -f tmp/tmp.smi tmp/tmp.xyz tmp/tmp.sdf'
    # Start from a clean scratch directory.
    log >> 'mkdir -p tmp'
    log >> purge_scratch
    with open('tmp/tmp.smi', 'w') as smi_file:
        smi_file.writelines('%s\n' % entry for entry in smiles)
    # SMILES -> SDF -> XYZ
    corina_cmd = '%s -d wh -i t=smiles tmp/tmp.smi tmp/tmp.sdf' % corina
    log >> corina_cmd
    babel_cmd = '%s -isdf tmp/tmp.sdf -oxyz tmp/tmp.xyz' % babel
    log >> babel_cmd
    structures = benchml.read('tmp/tmp.xyz')
    assert len(structures) == len(smiles)
    # Remove the intermediates and move the corina trace file into tmp/.
    log >> purge_scratch
    log >> 'mv corina.trc tmp/.'
    return structures

def xyz_to_csv(xyz_file, csv_file):
    """Write the ``info`` metadata of every config in an xyz file to a CSV.

    Column order: the header stored under ``info["csv_fieldnames"]`` of the
    first config (if present), followed by any other ``info`` keys found in
    the configs, sorted. The ``"csv_fieldnames"`` entry itself is never
    written. Nothing is written when the xyz file holds no configs.

    Args:
        xyz_file: Path of the structure file, read with ``benchml.read``.
        csv_file: Path of the CSV file to create or overwrite.
    """
    configs = benchml.read(xyz_file)
    if len(configs) < 1:
        return
    marker = "csv_fieldnames"
    first_info = configs[0].info
    # Copy the stored header rather than appending to the list held by the
    # config, and tolerate files that never went through csv_to_xyz.
    if marker in first_info.keys():
        fields = [name for name in first_info[marker] if name != marker]
    else:
        fields = []
    # Collect extra keys from all configs, not only the first, so a row with
    # an additional key does not make DictWriter raise ValueError.
    known = set(fields)
    extras = set()
    for config in configs:
        extras.update(key for key in config.info.keys() if key not in known)
    extras.discard(marker)
    fields.extend(sorted(extras))
    # The with-block guarantees the output is flushed and closed;
    # newline="" is what the csv module asks for on files it writes.
    with open(csv_file, "w", newline="") as ofs:
        writer = csv.DictWriter(ofs, fieldnames=fields)
        writer.writeheader()
        for config in configs:
            # Build the row separately: no KeyError when the marker is
            # missing and the config's info is left untouched.
            row = {
                key: config.info[key]
                for key in config.info.keys()
                if key != marker
            }
            writer.writerow(row)

def validate_args(args):
    """Check that the combination of input/output options is usable.

    Rules:
      * at least one of ``args.input`` / ``args.from_csv`` must be set;
      * when ``args.to_csv`` is set, ``args.input`` is mandatory;
      * otherwise ``args.output`` is mandatory.

    Args:
        args: Parsed options exposing ``input``, ``from_csv``, ``to_csv``
            and ``output``; an empty string means "not given".

    Returns:
        ``args`` itself, unchanged.

    Raises:
        ValueError: if one of the rules above is violated.
    """
    input_missing = args.input == ""
    if input_missing and args.from_csv == "":
        raise ValueError("Need to specify either --input or --from_csv")
    exporting_csv = args.to_csv != ""
    if exporting_csv:
        # CSV export reads its structures from --input.
        if input_missing:
            raise ValueError("Require --input")
    elif args.output == "":
        raise ValueError("Require --output")
    return args

if __name__ == "__main__":
    # Command-line entry point: register the options on the benchml logger,
    # parse and validate them, then run one of three modes:
    #   --to_csv    : export the info of the --input structures to a CSV
    #   --from_csv  : build structures from a CSV with a SMILES column
    #   (neither)   : read --input as is
    # The last two optionally apply --filter and then write to --output.
    log.Connect()
    log.AddArg("input", str, default="")
    log.AddArg("filter", str, default="")
    log.AddArg("output", str, default="")
    # CSV <> XYZ conversion
    log.AddArg("from_csv", str, default="")
    log.AddArg("to_csv", str, default="")
    # NOTE(review): csv_to_xyz skips corina/babel whenever args.no_gen3d is
    # truthy, and the default here is True -- confirm how benchml's "toggle"
    # type combines this default with the flag being passed.
    log.AddArg("no_gen3d", "toggle", default=True)
    # Name of the CSV column that holds the SMILES strings.
    log.AddArg("smiles", str, default="SMILES")
    log.AddArg("corina", str, default='/software/corina/bin/corina')
    log.AddArg("babel", str, default='/software/babel/bin/babel')
    args = validate_args(log.Parse())

    configs = None
    if args.to_csv != "":
        xyz_to_csv(args.input, args.to_csv)        
        # Presumably terminates the script here -- confirm against benchml;
        # otherwise execution would continue into the branches below.
        log.okquit()
    if args.from_csv != "":
        configs = csv_to_xyz(args.from_csv, args)
    else:
        log << "Reading" << args.input << log.endl
        configs = benchml.read(args.input)
    if args.filter != "":
        log << "Applying filter:" << args.filter << log.endl
        # SECURITY: --filter is evaluated as arbitrary Python code and the
        # result is used as the predicate for filter(); only pass
        # expressions from a trusted source.
        fct = eval(args.filter)
        log << "Reduced samples from" << len(configs) << log.flush
        configs = list(filter(fct, configs))
        log << "to" << len(configs) << log.endl
    benchml.write(args.output, configs) 

