#! /usr/bin/env python
import json
import os
import benchml
from benchml.transforms import *
# Module-wide alias for benchml's logger; it is used for console output and
# also provides the command-line argument parser (AddArg/Parse) used below.
log = benchml.log

def load_models_and_dataset(args):
    """Load the dataset(s) described by ``args.meta`` and compile the requested models.

    Parameter overrides are gathered from two sources and applied via
    ``configure_models``:

    1. the ``"override"`` entry of the first dataset's metadata (if present);
    2. the ``--override`` command-line JSON string (if non-empty).

    Command-line values take precedence over metadata values for the same key.

    :param args: parsed CLI arguments; reads ``meta``, ``models`` and ``override``.
    :returns: tuple ``(models, dataset)``, where ``dataset`` is the list of
        datasets yielded by ``benchml.data.DatasetIterator``.
    :raises RuntimeError: if a requested model reports itself as unavailable.
    """
    log << "Load dataset" << log.flush
    dataset = list(benchml.data.DatasetIterator(meta_json=args.meta))
    log << "... done" << log.endl
    models = benchml.models.compile_and_filter([".*"], args.models)
    for model in models:
        if not model.check_available():
            raise RuntimeError("Model '%s' requested but not available" % model.tag)
    override_json_strs = []
    # Metadata overrides are applied regardless of whether --override was given
    # (previously they were silently dropped when --override was empty).
    if dataset and "override" in dataset[0].meta:
        override_json_strs.append(json.dumps(dataset[0].meta["override"]))
    if args.override != "":
        # Passed last so that cmdline config takes precedence over meta config.
        override_json_strs.append(args.override)
    if override_json_strs:
        configure_models(models, *override_json_strs)
    return models, dataset

def train_models(models, dataset, args):
    """Fit each model on ``dataset`` and save it to its archive file.

    Models with ``model.hyper is not None`` are optimised with
    ``model.hyperfit`` using the dataset's ``"hypersplit"`` metadata when
    present (otherwise a random split with ``n_splits=10`` and
    ``train_fraction=0.75``) and the first entry of its ``"metrics"``
    metadata; all other models are fitted directly with ``model.fit``.

    :param models: compiled benchml models to train.
    :param dataset: a single benchml dataset; metadata is read via
        ``dataset[key]`` and ``dataset.meta``.
    :param args: parsed CLI arguments; reads ``verbose`` and ``archfile``
        (the ``{model}`` placeholder in the latter is replaced by each
        model's tag).
    :returns: the last model trained, or ``None`` if ``models`` is empty.
    """
    # Pre-bind so an empty model list returns None instead of raising
    # UnboundLocalError at the final return.
    model = None
    for model in models:
        with benchml.sopen(model, dataset) as stream:
            model.precompute(stream)
            if model.hyper is not None:
                if "hypersplit" in dataset.meta:
                    split_args = dataset["hypersplit"]
                else:
                    # Fresh dict per model in case hyperfit mutates its input.
                    split_args = {"method": "random", "n_splits": 10, "train_fraction": 0.75}
                model.hyperfit(
                    stream=stream,
                    split_args=split_args,
                    accu_args={"metric": dataset["metrics"][0]},
                    target="y",
                    target_ref="input.y",
                    log=benchml.log,
                    verbose=args.verbose)
            else:
                model.fit(stream, verbose=args.verbose)
        archfile = args.archfile.format(model=model.tag)
        archdir = os.path.dirname(archfile)
        if archdir:
            # Create the directory directly instead of shelling out to
            # `mkdir -p`, which breaks on paths containing spaces or shell
            # metacharacters.
            os.makedirs(archdir, exist_ok=True)
        benchml.save(archfile, model)
    return model

def configure_models(models, *override_json_str):
    """Apply parameter overrides to the transforms of ``models`` in place.

    Each positional argument is a JSON object string mapping
    ``"<tf_tag>.<field>"`` addresses to values. The strings are merged in
    order, so later strings take precedence over earlier ones for the same
    address. ``<tf_tag>`` may match either a transform's tag or its class
    name; an address that matches no transform is ignored.

    :param models: models whose ``transforms`` are updated via ``tf.args``.
    :param override_json_str: JSON object strings with the overrides.
    :raises ValueError: if an address is not of the form ``"<tf_tag>.<field>"``.
    :raises KeyError: if a matching transform has no argument named ``field``.
    """
    log << "Configure models" << log.endl
    override = {}
    for jstr in override_json_str:
        override.update(json.loads(jstr))
    # Parse every address once up front (instead of once per model/transform)
    # and fail with a clear message on malformed keys.
    parsed = []
    for addr, val in override.items():
        parts = addr.split(".")
        if len(parts) != 2:
            raise ValueError(
                "Invalid override address '%s' (expected '<transform>.<field>')" % addr)
        parsed.append((parts[0], parts[1], val))
    for m in models:
        for tf in m.transforms:
            for tf_tag, field, val in parsed:
                # tf_tag may be either the name or the class name of the target transform
                if tf_tag != tf.tag and tf_tag != tf.__class__.__name__:
                    continue
                if field not in tf.args:
                    raise KeyError(
                        "Invalid parameter field '%s' in transform '%s'" % (
                        field, tf.tag))
                log << " - Model %s: Override %s[%s].%s = %s" % (
                    m.tag, tf.tag, tf.__class__.__name__, field, val) << log.endl
                tf.args[field] = val

def run(args):
    """Execute the CLI mode selected by ``args.mode``.

    - ``analyse``: load the JSON file given by ``--benchmark_json`` and pass it
      to ``benchml.analysis.analyse``.
    - ``map``: load the model from ``--archfile`` and apply it to the
      structures read from ``--extxyz``. If ``--store_as`` is set, each
      prediction in ``out["y"]`` is stored in the corresponding structure's
      ``info`` and the ext-xyz file is rewritten. Otherwise the full output is
      logged as JSON.
    - ``benchmark``: evaluate the models on the dataset(s) and write the
      results to ``--benchmark_json``.
    - ``train``: train the models on the first dataset and save them to
      ``--archfile``. An empty dataset list is a no-op.

    The RNG seed (``--seed``) is synchronised only for modes other than
    ``analyse`` and ``map``.

    :raises ValueError: if an argument required by the selected mode is missing.
    """
    if args.mode == "analyse":
        # Explicit checks rather than `assert`, which is stripped under -O.
        if args.benchmark_json == "":
            raise ValueError("Require --benchmark_json for mode=analyse")
        with open(args.benchmark_json) as f:
            bench = json.load(f)
        benchml.analysis.analyse(bench)
        return
    if args.mode == "map":
        if args.extxyz == "":
            raise ValueError("Require --extxyz input for mode=map")
        if args.archfile == "{model}.arch":
            raise ValueError("Require --archfile input for mode=map")
        configs = benchml.read(args.extxyz)
        model = benchml.load(args.archfile)
        with benchml.sopen(model, configs) as stream:
            out = model.map(stream, verbose=args.verbose)
        if args.store_as != "":
            for idx, y in enumerate(out["y"]):
                configs[idx].info[args.store_as] = y
            benchml.write(args.extxyz, configs)
        else:
            for key in out.keys():
                out[key] = out[key].tolist()
            log << json.dumps(out) << log.endl
        return
    benchml.splits.synchronize(args.seed)
    if args.mode == "benchmark":
        if args.benchmark_json == "":
            raise ValueError("Require --benchmark_json for mode=benchmark")
        outdir = os.path.dirname(args.benchmark_json)
        if outdir:
            # Create the directory directly instead of shelling out to `mkdir -p`.
            os.makedirs(outdir, exist_ok=True)
        models, dataset = load_models_and_dataset(args)
        bench = benchml.benchmark.evaluate(
            data=dataset,
            models=models,
            log=benchml.log,
            verbose=args.verbose,
            detailed=True)
        # Context manager guarantees the results are flushed and the file closed.
        with open(args.benchmark_json, "w") as f:
            json.dump(bench, f, indent=1, sort_keys=True)
    elif args.mode == "train":
        models, dataset = load_models_and_dataset(args)
        # Only the first dataset is used for training.
        if dataset:
            train_models(models, dataset[0], args)

if __name__ == "__main__":
    # Command-line entry point: declare and parse the CLI arguments, then
    # validate the mode and dispatch to run().
    log.Connect()
    log.AddArg("mode", str, help="Select from benchmark|train|map|analyse")
    log.AddArg("meta", str, default="meta.json", help="Input metadata file")
    log.AddArg("extxyz", str, default="", help="Input structure-file in ext-xyz format (mode=map)")
    log.AddArg("models", (list,str), default=[], help="List of predefined models for mode=benchmark,train")
    log.AddArg("archfile", str, default="{model}.arch", help="Model archive file: written per model in mode=train ('{model}' is replaced by the model tag), read in mode=map")
    log.AddArg("store_as", str, default="", help="Key under which predictions are stored in ext-xyz file when mode=map")
    log.AddArg("benchmark_json", str, default="", help="Output json-file storing benchmark results")
    log.AddArg("override", str, default="", help="Json string with name-value pairs for parameter overrides")
    log.AddArg("seed", int, default=971, help="RNG seed")
    log.AddArg("verbose", "toggle", default=False, help="Enable verbose output")
    log.AddArg("use_ase", "toggle", default=False, help="Use ASE parse")
    args = log.Parse()
    if args.use_ase:
        benchml.readwrite.configure(use_ase=args.use_ase)
    if args.mode not in {"benchmark", "train", "map", "analyse"}:
        raise ValueError("Unknown mode '%s'" % args.mode)
    run(args)
