#!/tmp_mnt/filer1/unix/cpoelking/nomad/ccdc_rev/venv/bin/python
import csv

import numpy as np

import benchml

log = benchml.log


def csv_to_xyz(csv_file, args, xyz_file=""):
    """Build one structure per CSV row from the row's SMILES string.

    Each row of ``csv_file`` is read as a dict. The SMILES string is taken
    from the column named ``args.smiles_from``. Depending on ``args.gen3d``
    the structures are either zero-coordinate placeholders
    (``smiles_to_pseudo_xyz``) or generated through the external tools given
    by ``args.corina`` and ``args.babel`` (``smiles_to_xyz``).

    The full CSV row is attached to each structure as ``config.info``, with
    the original column order stored under the extra key ``"csv_fieldnames"``.

    Args:
        csv_file: Path of the CSV file to read (must have a header row).
        args: Parsed options; the attributes ``smiles_from``, ``gen3d``,
            ``corina`` and ``babel`` are read.
        xyz_file: If non-empty, the structures are also written to this path
            via ``benchml.write``.

    Returns:
        The list of structures, in CSV row order.

    Raises:
        KeyError: If a row has no ``args.smiles_from`` column.
        AssertionError: If the number of structures differs from the number
            of CSV rows.
    """
    log << log.mg << "Converting csv to extxyz format" << log.endl
    # newline="" lets the csv module do its own newline handling (needed for
    # quoted fields containing line breaks); the with-block guarantees the
    # file handle is closed, which the bare open() previously did not.
    with open(csv_file, newline="") as ifs:
        reader = csv.DictReader(ifs)
        data = list(reader)
        # Read while the file is still open: fieldnames is resolved lazily.
        fieldnames = reader.fieldnames
    smiles = [row[args.smiles_from] for row in data]
    if args.gen3d:
        configs = smiles_to_xyz(smiles, args.corina, args.babel)
    else:
        configs = list(smiles_to_pseudo_xyz(smiles))
    assert len(configs) == len(data), "Expected one structure per csv row"
    for config, meta in zip(configs, data):
        config.info = meta
        # Remember the column order so it can be restored by xyz_to_csv.
        config.info["csv_fieldnames"] = fieldnames
    if xyz_file != "":
        log << "Writing structures to" << xyz_file << log.endl
        benchml.write(xyz_file, configs)
    return configs


def smiles_to_pseudo_xyz(smiles):
    """Yield a placeholder structure (all-zero coordinates) per SMILES string.

    Each SMILES string is parsed with RDKit; the element symbols of the atoms
    returned by ``mol.GetAtoms()`` are combined with an ``(n_atoms, 3)`` array
    of zeros into a ``benchml.readwrite.ExtendedXyz`` object.

    Args:
        smiles: Iterable of SMILES strings.

    Yields:
        One ``ExtendedXyz`` structure per input string, in input order.

    Raises:
        ValueError: If RDKit cannot parse one of the SMILES strings.
    """
    import rdkit.Chem as Chem

    for smi in smiles:
        mol = Chem.MolFromSmiles(smi)
        # RDKit signals a parse failure by returning None rather than raising;
        # fail here with the offending string instead of an opaque
        # AttributeError on None further down.
        if mol is None:
            raise ValueError("Could not parse SMILES string: %r" % (smi,))
        symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
        pos = np.zeros((len(symbols), 3))
        yield benchml.readwrite.ExtendedXyz(pos=pos, symbols=symbols)


def smiles_to_xyz(smiles, corina, babel):
    """Generate structures for SMILES strings using external corina and babel.

    The SMILES strings are written one per line to ``tmp/tmp.smi``, converted
    to ``tmp/tmp.sdf`` with the ``corina`` executable and then to
    ``tmp/tmp.xyz`` with the ``babel`` executable. The xyz file is read back
    with ``benchml.read`` and the intermediate files are removed afterwards.

    NOTE(review): ``log >> cmd`` appears to run ``cmd`` as a shell command
    through the benchml logger -- confirm against the benchml log API.

    Args:
        smiles: Sequence of SMILES strings (``len()`` is taken of it).
        corina: Path or command name of the corina executable.
        babel: Path or command name of the babel executable.

    Returns:
        Whatever ``benchml.read`` returns for the generated xyz file.

    Raises:
        AssertionError: If the number of structures read back differs from
            the number of input SMILES strings.
    """
    purge_cmd = "rm -f tmp/tmp.smi tmp/tmp.xyz tmp/tmp.sdf"

    # Start from an existing, clean scratch directory.
    log >> "mkdir -p tmp"
    log >> purge_cmd

    with open("tmp/tmp.smi", "w") as smi_out:
        smi_out.writelines("%s\n" % entry for entry in smiles)

    # smiles -> sdf (corina), then sdf -> xyz (babel).
    corina_cmd = "%s -d wh -i t=smiles tmp/tmp.smi tmp/tmp.sdf" % corina
    log >> corina_cmd
    babel_cmd = "%s -isdf tmp/tmp.sdf -oxyz tmp/tmp.xyz" % babel
    log >> babel_cmd

    structures = benchml.read("tmp/tmp.xyz")
    assert len(structures) == len(smiles)

    # Remove intermediates and move the corina trace file out of the cwd.
    log >> purge_cmd
    log >> "mv corina.trc tmp/."
    return structures


def xyz_to_csv(xyz_file, csv_file):
    """Write the per-structure metadata of an xyz file to a CSV file.

    The CSV columns are determined from the FIRST structure only: the column
    order stored under ``info["csv_fieldnames"]`` (if present) comes first,
    followed by any remaining ``info`` keys in sorted order. The bookkeeping
    key ``"csv_fieldnames"`` itself is never written.

    Nothing is written (and no file is created) if ``xyz_file`` contains no
    structures.

    Args:
        xyz_file: Path of the structure file, read with ``benchml.read``.
        csv_file: Path of the CSV file to create/overwrite.

    Raises:
        ValueError: Raised by ``csv.DictWriter`` if a later structure carries
            an ``info`` key that the first structure did not have.
    """
    configs = benchml.read(xyz_file)
    if len(configs) < 1:
        return
    first_info = configs[0].info
    stored = first_info.get("csv_fieldnames")
    if stored is None:
        # File was not produced from a csv: fall back to sorted keys only.
        stored = []
    elif isinstance(stored, str):
        # A bare string would otherwise be split into characters by list().
        stored = [stored]
    # Copy so that extending the column list does not mutate the metadata
    # stored on the first structure.
    fields = list(stored)
    for field in sorted(first_info.keys()):
        if field not in fields:
            fields.append(field)
    if "csv_fieldnames" in fields:
        fields.remove("csv_fieldnames")
    # newline="" is required by the csv module for correct line endings; the
    # with-block guarantees the output is flushed and closed.
    with open(csv_file, "w", newline="") as ofs:
        writer = csv.DictWriter(ofs, fieldnames=fields)
        writer.writeheader()
        for config in configs:
            row = dict(config.info)
            # Tolerate structures that lack the bookkeeping key.
            row.pop("csv_fieldnames", None)
            writer.writerow(row)


def validate_args(args):
    """Check that the parsed options form a usable combination.

    Rules enforced:
      * at least one of ``args.input`` / ``args.from_csv`` must be non-empty;
      * when ``args.to_csv`` is empty, ``args.output`` must be non-empty;
      * when ``args.to_csv`` is non-empty, ``args.input`` must be non-empty.

    Args:
        args: Object with string attributes ``input``, ``from_csv``,
            ``to_csv`` and ``output``.

    Returns:
        The same ``args`` object, unchanged.

    Raises:
        ValueError: If one of the rules above is violated.
    """
    has_input = args.input != ""
    if not has_input and args.from_csv == "":
        raise ValueError("Need to specify either --input or --from_csv")

    exporting_csv = args.to_csv != ""
    if exporting_csv:
        # csv export reads structures from --input.
        if not has_input:
            raise ValueError("Require --input")
    elif args.output == "":
        raise ValueError("Require --output")
    return args


# Command-line entry point. Three modes, selected by the options given:
#   --to_csv:   dump the metadata of --input to a csv file, then quit;
#   --from_csv: build structures from a csv file and write them to --output;
#   otherwise:  read --input and write it to --output.
# In the last two modes an optional --filter is applied before writing.
if __name__ == "__main__":
    log.Connect()
    log.AddArg("input", str, default="", help="Input structure file")
    log.AddArg("filter", str, default="", help="Lambda filter expression")
    log.AddArg("output", str, default="", help="Output structure file")
    # Options for converting between csv and xyz files
    log.AddArg("from_csv", str, default="", help="Input csv file")
    log.AddArg("to_csv", str, default="", help="Output csv file")
    log.AddArg("gen3d", "toggle", default=False, help="Generate 3D coords")
    log.AddArg("smiles_from", str, default="SMILES", help="Key from which to read smiles")
    log.AddArg("corina", str, default="/software/corina/bin/corina")
    log.AddArg("babel", str, default="/software/babel/bin/babel")
    args = validate_args(log.Parse())

    configs = None
    if args.to_csv != "":
        xyz_to_csv(args.input, args.to_csv)
        # NOTE(review): presumably terminates the process here; if okquit()
        # returned, execution would continue into the branches below - confirm.
        log.okquit()
    if args.from_csv != "":
        # No xyz_file is passed, so csv_to_xyz only returns the structures;
        # they are written once at the end, after optional filtering.
        configs = csv_to_xyz(args.from_csv, args)
    else:
        log << "Reading" << args.input << log.endl
        configs = benchml.read(args.input)
    if args.filter != "":
        log << "Applying filter:" << args.filter << log.endl
        # SECURITY: eval() executes the --filter string as arbitrary Python
        # code. Only acceptable while the command line comes from a trusted
        # user; never feed this option from untrusted input.
        # The evaluated object is used as the predicate for filter() below,
        # so the expression must evaluate to a callable taking one structure.
        fct = eval(args.filter)
        log << "Reduced samples from" << len(configs) << log.flush
        configs = list(filter(fct, configs))
        log << "to" << len(configs) << log.endl
    benchml.write(args.output, configs)
