#! /usr/bin/env python
import json

import benchml

# Module-wide handle to benchml's logger object; the rest of this file uses it
# for console output (via "<<") and for command-line argument parsing.
log = benchml.log
# NOTE(review): ">>" presumably passes the string to the shell, i.e. creates a
# "tmp" directory as a side effect of loading this module — confirm against
# benchml's logger implementation.
log >> "mkdir -p tmp"


def prompt(mssg, options=(), default=None, check_allowed=lambda i, o: i in o, convert=str):
    """Interactively ask the user for a value until a valid one is entered.

    The prompt line (the message, plus the default and the allowed options
    when given) is written through the module-level ``log``; answers are read
    from standard input with ``input``.

    Args:
        mssg: Text shown to the user; a trailing ":" is appended.
        options: Sized collection of allowed values. When empty, any input
            that ``convert`` accepts is valid.
        default: Value returned when the user submits an empty line. It is
            returned as-is, without going through ``convert`` or
            ``check_allowed``. With ``None``, an empty line is rejected.
        check_allowed: Callable ``(value, options) -> bool`` applied to the
            converted input when ``options`` is non-empty.
        convert: Callable applied to the raw input string. Any exception it
            raises marks the input as invalid.

    Returns:
        The converted user input, or ``default`` for an empty line.
    """
    prompt_str = "%-20s" % (mssg + ":")
    if default is not None:
        prompt_str += "%-40s" % ("   default=" + str(default))
    if len(options) > 0:
        prompt_str += "%-50s" % ("   options=[ " + " | ".join(map(str, options)) + " ]")
    log << prompt_str << log.endl
    while True:
        raw = input(">>> ")
        if raw == "":
            if default is not None:
                return default
            valid = False
        else:
            try:
                value = convert(raw)
                valid = len(options) == 0 or bool(check_allowed(value, options))
            except Exception:
                # convert and check_allowed are caller-supplied callables, so
                # any failure inside them simply means the input is unusable.
                valid = False
        if valid:
            return value
        # Echo what the user typed rather than the converted value, so the
        # message stays readable when convert() produced e.g. a list.
        log << log.mr << "  Invalid input '%s'" % raw << log.endl


def write_meta(configs, extxyz, meta_json):
    """Interactively assemble dataset metadata and write it to a JSON file.

    The candidate target fields are the keys of the ``info`` mapping of the
    first structure only. The user is asked (via ``prompt``) for the target
    field, task, scaling, metrics, conversion, extra chemical elements, a
    name and a comment. Split and hypersplit settings are fixed defaults.

    Args:
        configs: Sequence of structures; each must expose ``info`` (mapping)
            and ``symbols`` (iterable of element symbols).
        extxyz: Dataset file reference(s); stored under "datasets" and used
            as the default for the dataset name.
        meta_json: Path of the JSON file to write.

    Returns:
        The metadata dictionary that was written, or ``None`` when the
        dataset is empty or its first structure carries no info fields.
    """
    if len(configs) < 1:
        log << log.my << "ERROR - No structures in dataset" << log.endl
        return None
    fields = sorted(configs[0].info.keys())
    if not fields:
        # Without this guard any integer would pass the (empty) options check
        # in prompt() and the target lookup below would raise IndexError.
        log << log.my << "ERROR - No info fields available as target" << log.endl
        return None
    log << log.mg << "Adding metadata:" << log.endl
    log << "Add target:" << log.endl
    for idx, field in enumerate(fields):
        log << "    Enter '%d' to select field '%s'" % (idx, field) << log.endl
    target_idx = prompt(
        "Select target field by index (see above)",
        options=range(len(fields)),
        default=None,
        convert=int,
    )
    task = prompt("Define task", options=["regression", "classification"], default="regression")
    scaling = prompt(
        "Estimate scaling", options=["extensive", "intensive", "unknown"], default="unknown"
    )
    metrics = prompt(
        "Select metrics",
        options=["mae", "rmse", "mse", "rhop", "r2", "auc", "rhor"],
        default=["mae", "rmse", "rhop", "r2"] if (task == "regression") else ["auc"],
        # Every whitespace-separated token must be one of the known metrics.
        check_allowed=lambda i, o: len(set(i).intersection(set(o))) == len(i),
        convert=lambda s: s.split(),
    )
    per_atom = False
    conv = prompt("Set conversion", options=["''", "log", "log10", "-log10"], default="")
    if conv == "''":
        # "''" is only the printable stand-in for "no conversion"; store the
        # same empty string that the default produces.
        conv = ""
    elements = set()
    for config in configs:
        elements.update(config.symbols)
    elements_add = prompt(
        "Add chemical elements", default=sorted(elements), convert=lambda s: s.split()
    )
    elements = sorted(elements.union(elements_add))
    # These flags are derived purely from the names of the info fields.
    has_smiles = "smiles" in fields or "SMILES" in fields
    periodic = "cell" in fields or "lattice" in fields
    datasets = extxyz
    name = prompt("Enter name", default=extxyz)
    comment = prompt("Enter comment", default="")
    meta = {
        "comment": comment,
        "name": name,
        "targets": {
            fields[target_idx]: {
                "task": task,
                "scaling": scaling,
                "metrics": metrics,
                "per_atom": per_atom,
                "convert": conv,
            }
        },
        "splits": [
            {
                "method": "sequential",
                "train_fraction": "np.arange(0.1, 1.0, 0.1)",
                "repeat_fraction_fct": "lambda s,t,p,f: 2*int(1./(f*(1-f))**0.5)",
            }
        ],
        "hypersplit": {"method": "random", "n_splits": 10, "train_fraction": 0.75},
        "elements": elements,
        "has_smiles": has_smiles,
        "periodic": periodic,
        "datasets": datasets,
    }
    log << "Writing metadata to" << meta_json << log.endl
    # Context manager guarantees the file is flushed and closed.
    with open(meta_json, "w") as handle:
        json.dump(meta, handle, indent=1, sort_keys=True)
    return meta


if __name__ == "__main__":
    # Declare the two command-line arguments on benchml's logger and parse them:
    # "extxyz" (declared as a list of str) and "meta" (str).
    log.Connect()
    log.AddArg("extxyz", (list, str))
    log.AddArg("meta", str)
    args = log.Parse()

    # Concatenate the structures read from every input file, in argument order,
    # then run the interactive metadata dialogue on the combined dataset.
    configs = [config for path in args.extxyz for config in benchml.read(path)]
    write_meta(configs, args.extxyz, args.meta)
