#! /usr/bin/env python
import numpy as np
import json
import pandas as pd
import benchml
import csv
# Module-wide alias for benchml's logger object; everything below writes to
# the console through it with the `<<` operator.
log = benchml.log
# NOTE(review): `>>` presumably hands the string to a shell, i.e. creates a
# ./tmp directory as an import-time side effect -- confirm against benchml.log.
log >> 'mkdir -p tmp'

def prompt(mssg, options=(), default=None, check_allowed=lambda i, o: i in o, convert=str):
    """Ask the user for a value on stdin, repeating until the answer is valid.

    The prompt line (message, plus the default and the options when given) is
    written through the module-level ``log``; answers are read with
    ``input('>>> ')``.

    Args:
        mssg: Text shown to the user; a colon is appended.
        options: Allowed values. When empty, any successfully converted
            answer is accepted. Must support ``len()``.
        default: Value returned as-is (neither converted nor checked) when
            the user just presses enter. With ``None`` an empty answer is
            rejected and the user is asked again.
        check_allowed: ``check_allowed(value, options)`` decides whether a
            converted answer is acceptable; only called when ``options`` is
            non-empty.
        convert: Callable turning the raw input string into the result value.

    Returns:
        The converted answer, or ``default`` for an empty answer.
    """
    prompt_str = "%-20s" % (mssg + ":")
    if default is not None:
        prompt_str += "%-40s" % ("   default=" + str(default))
    if len(options) > 0:
        opt_str = "   options=[ " + " | ".join(map(str, options)) + " ]"
        prompt_str += "%-50s" % opt_str
    log << prompt_str << log.endl
    while True:
        answer = input('>>> ')
        if answer == "":
            if default is not None:
                return default
            valid = False
        else:
            try:
                answer = convert(answer)
                valid = len(options) == 0 or check_allowed(answer, options)
            except Exception:
                # Any failure in convert/check_allowed means "bad input".
                # Deliberately not a bare except, so that Ctrl-C and
                # SystemExit still terminate the prompt.
                valid = False
        if valid:
            return answer
        log << log.mr << "  Invalid input '%s'" % str(answer) << log.endl

def write_meta(configs, extxyz, meta_json):
    """Interactively assemble dataset metadata and write it to a JSON file.

    The user is asked (via ``prompt``) for the target field, task type,
    scaling, metrics, value conversion, extra chemical elements, a dataset
    name and a comment. The split and hypersplit settings and
    ``per_atom=False`` are fixed values that are not prompted for.

    Args:
        configs: Sequence of structures. Each item must expose ``.symbols``
            (iterable of element symbols); the first item's ``.info`` keys
            provide the candidate target fields.
        extxyz: Dataset path; stored in the ``"datasets"`` list and offered
            as the default name.
        meta_json: Path of the JSON file to write.

    Returns:
        The metadata dict that was written, or ``None`` when ``configs`` is
        empty or the first structure has no info fields (nothing is written
        in that case).
    """
    if len(configs) < 1:
        # Color tokens are streamed *through* ``log`` like everywhere else in
        # this file (``log << log.mg``, ``log << log.mr``).
        log << log.my << "ERROR - No structures in dataset" << log.endl
        return
    fields = sorted(configs[0].info.keys())
    if not fields:
        # Without any field there is nothing to select as target; bail out
        # here instead of failing with an IndexError after all the prompts.
        log << log.mr << "ERROR - No info fields available as target" << log.endl
        return
    log << log.mg << "Adding metadata:" << log.endl
    log << "Add target:" << log.endl
    for fidx, f in enumerate(fields):
        log << "    Enter '%d' to select field '%s'" % (fidx, f) << log.endl
    fidx = prompt("Select target field by index (see above)",
        options=range(len(fields)), default=None, convert=int)
    task = prompt("Define task",
        options=["regression", "classification"], default="regression")
    scaling = prompt("Estimate scaling",
        options=["extensive", "intensive", "unknown"], default="unknown")
    # Metrics are entered as a whitespace-separated list; every entry must be
    # a known metric and entries must not repeat.
    metrics = prompt("Select metrics",
        options=["mae", "rmse", "mse", "rhop", "r2", "auc", "rhor"],
        default=["mae", "rmse", "rhop", "r2"] if (task == "regression") else ["auc"],
        check_allowed=lambda i,o: len(set(i).intersection(set(o))) == len(i),
        convert=lambda s: s.split())
    per_atom = False
    conv = prompt("Set conversion",
        options=["''", "log", "log10", "plog"],
        default="")
    if conv == "''":
        # "''" is only the on-screen spelling of "no conversion"; store the
        # same empty string that the default yields.
        conv = ""
    # Collect the elements present in the data, then let the user add more.
    elements = set()
    for config in configs:
        elements.update(config.symbols)
    elements = sorted(elements)
    elements_add = prompt("Add chemical elements",
        default=elements, convert=lambda s: s.split())
    elements = sorted(set(elements).union(elements_add))
    has_smiles = "smiles" in fields or "SMILES" in fields
    periodic = "cell" in fields or "lattice" in fields
    datasets = [ extxyz ]
    name = prompt("Enter name",
        default=extxyz)
    comment = prompt("Enter comment",
        default="")
    meta = {
        "comment": comment,
        "name": name,
        "targets": {
            fields[fidx]: {
                "task": task,
                "scaling": scaling,
                "metrics": metrics,
                "per_atom": per_atom,
                "convert": conv
            }
        },
        "splits": [ {
             "method": "sequential",
             "train_fraction": "np.arange(0.1, 1.0, 0.1)",
             "repeat_fraction_fct": "lambda s,t,p,f: 2*int(1./(f*(1-f))**0.5)"
           }
        ],
        "hypersplit": {
            "method": "random",
            "n_splits": 10,
            "train_fraction": 0.75 },
        "elements": elements,
        "has_smiles": has_smiles,
        "periodic": periodic,
        "datasets": datasets
    }
    log << "Writing metadata to" << meta_json << log.endl
    # Context manager so the file is flushed and closed deterministically.
    with open(meta_json, "w") as meta_file:
        json.dump(meta, meta_file, indent=1, sort_keys=True)
    return meta

if __name__ == "__main__":
    # Command line: two string arguments, the extended-XYZ input and the
    # metadata JSON output, registered with and parsed by the benchml logger.
    log.Connect()
    for arg_name in ("extxyz", "meta"):
        log.AddArg(arg_name, str)
    cli = log.Parse()

    # Load the structures and run the interactive metadata dialogue.
    write_meta(benchml.read(cli.extxyz), cli.extxyz, cli.meta)
    
