#! /usr/bin/env python
import json
import os

import benchml

log = benchml.log


def load_models_and_dataset(args):
    """Load the datasets and compile the requested models.

    Args:
        args: Parsed command-line options. Reads ``args.meta`` (metadata json
            handed to ``benchml.data.DatasetIterator``), ``args.models`` (model
            filter handed to ``benchml.models.compile_and_filter``) and
            ``args.override`` (json string with parameter overrides).

    Returns:
        Tuple ``(models, dataset)`` where ``dataset`` is the fully materialised
        list of items produced by the dataset iterator.

    Raises:
        RuntimeError: If a compiled model reports that it is not available.
    """
    log << "Load dataset" << log.flush
    dataset = list(benchml.data.DatasetIterator(meta_json=args.meta))
    log << "... done" << log.endl

    models = benchml.models.compile_and_filter([".*"], args.models)
    for candidate in models:
        if candidate.check_available():
            continue
        raise RuntimeError("Model '%s' requested but not available" % candidate.tag)

    # Overrides (including those stored in the dataset metadata) are only
    # applied when an override string was given on the command line.
    if args.override == "":
        return models, dataset

    # Only the metadata of the first dataset is consulted for overrides.
    meta_override_json = "{}"
    if dataset and "override" in dataset[0].meta:
        meta_override_json = json.dumps(dataset[0].meta["override"])

    # NOTE that cmdline config takes precedence over meta config
    # (it is passed last, and later override strings win).
    configure_models(models, meta_override_json, args.override)
    return models, dataset


def train_models(models, dataset, args):
    """Fit every model on ``dataset`` and save each one to its archive file.

    Models whose ``hyper`` attribute is not ``None`` are trained via
    ``hyperfit``; all others via a plain ``fit``.

    Args:
        models: Iterable of models to train.
        dataset: Dataset passed to ``benchml.sopen``. For hyper-optimised
            models it must also provide ``dataset.meta`` and string indexing
            (``dataset["metrics"]``, optionally ``dataset["hypersplit"]``).
        args: Parsed command-line options. Reads ``args.verbose`` and
            ``args.archfile``; the latter is a format template that receives
            ``model=<model tag>``.
    """
    for model in models:
        with benchml.sopen(model, dataset) as stream:
            model.precompute(stream)
            if model.hyper is None:
                model.fit(stream, verbose=args.verbose)
            else:
                # Use the split settings from the dataset metadata if present,
                # otherwise fall back to a random 10-fold 75/25 split.
                if "hypersplit" in dataset.meta:
                    hyper_split = dataset["hypersplit"]
                else:
                    hyper_split = {"method": "random", "n_splits": 10, "train_fraction": 0.75}
                # The first metric listed for the dataset drives the optimisation.
                hyper_accu = {"metric": dataset["metrics"][0]}
                model.hyperfit(
                    stream=stream,
                    split_args=hyper_split,
                    accu_args=hyper_accu,
                    target="y",
                    target_ref="input.y",
                    log=benchml.log,
                    verbose=args.verbose,
                )
        archfile = args.archfile.format(model=model.tag)
        arch_dir = os.path.dirname(archfile)
        if arch_dir != "":
            # NOTE(review): ``log >> cmd`` presumably executes the shell command,
            # creating the target directory - confirm against benchml.log.
            log >> "mkdir -p %s" % arch_dir
        benchml.save(archfile, model)


def configure_models(models, *override_json_str):
    """Apply parameter overrides to the transforms of each model.

    Args:
        models: Iterable of models. Each model exposes ``tag`` and
            ``transforms``; each transform exposes ``tag`` and an ``args``
            mapping that is modified in place.
        *override_json_str: JSON strings, each encoding an object that maps
            ``"<transform>.<field>"`` addresses to the new value. The objects
            are merged in the order given, so later strings take precedence.
            ``<transform>`` may be either the tag or the class name of the
            target transform.

    Raises:
        ValueError: If an address is not of the form ``"<transform>.<field>"``.
            This is checked before any transform is modified.
        KeyError: If a matching transform has no parameter named ``<field>``.
    """
    log << "Configure models" << log.endl
    merged = {}
    for jstr in override_json_str:
        merged.update(json.loads(jstr))

    # Parse and validate every address once, before touching any transform,
    # so a malformed address cannot leave the models partially reconfigured.
    parsed = []
    for addr, val in merged.items():
        parts = addr.split(".")
        if len(parts) != 2:
            raise ValueError(
                "Invalid override address '%s': expected '<transform>.<field>'" % addr
            )
        parsed.append((parts[0], parts[1], val))

    for m in models:
        for tf in m.transforms:
            tf_class = tf.__class__.__name__
            for tf_tag, field, val in parsed:
                # Note that tf_tag could be both the name
                # or class name of the target transform:
                if tf_tag != tf.tag and tf_tag != tf_class:
                    continue
                if field not in tf.args:
                    raise KeyError(
                        "Invalid parameter field '%s' in transform '%s'" % (field, tf.tag)
                    )
                (
                    log
                    << " - Model %s: Override %s[%s].%s = %s"
                    % (m.tag, tf.tag, tf_class, field, val)
                    << log.endl
                )
                tf.args[field] = val


def run(args):
    """Execute the action selected by ``args.mode``.

    Modes:
        ``"analyse"``: Load the benchmark json file and hand it to
            ``benchml.analysis.analyse``.
        ``"map"``: Load a model archive, apply it to the structures read from
            ``args.extxyz`` and either store the predictions back into that
            file under ``args.store_as`` or log all outputs as json.
        ``"benchmark"``: Evaluate the requested models on the datasets and
            write the results to ``args.benchmark_json``.
        ``"train"``: Train the requested models and save them as archives.

    For the last two modes the split RNG is synchronised with ``args.seed``
    first. Any other mode falls through after that synchronisation without
    doing further work.

    Args:
        args: Parsed command-line options.

    Raises:
        ValueError: If an option required by the selected mode is missing.
            Explicit exceptions are used rather than ``assert`` so that the
            checks remain active under ``python -O``.
    """
    if args.mode == "analyse":
        if args.benchmark_json == "":
            raise ValueError("Require --benchmark_json for mode=analyse")
        with open(args.benchmark_json) as fin:
            bench = json.load(fin)
        benchml.analysis.analyse(bench)
        return
    if args.mode == "map":
        if args.extxyz == "":
            raise ValueError("Require --extxyz input for mode=map")
        # The default is a template for mode=train, not a loadable archive.
        if args.archfile == "{model}.arch":
            raise ValueError("Require --archfile input for mode=map")
        configs = benchml.read(args.extxyz)
        model = benchml.load(args.archfile)
        with benchml.sopen(model, configs) as stream:
            out = model.map(stream, verbose=args.verbose)
        if args.store_as != "":
            # Attach each prediction to its structure and rewrite the input file.
            for idx, y in enumerate(out["y"]):
                configs[idx].info[args.store_as] = y
            benchml.write(args.extxyz, configs)
        else:
            # Convert the outputs via ``tolist()`` so they can be json-serialised.
            for key in out.keys():
                out[key] = out[key].tolist()
            log << json.dumps(out) << log.endl
        return
    benchml.splits.synchronize(args.seed)
    if args.mode == "benchmark":
        if args.benchmark_json == "":
            raise ValueError("Require --benchmark_json for mode=benchmark")
        if os.path.dirname(args.benchmark_json):
            # NOTE(review): ``log >> cmd`` presumably executes the shell command,
            # creating the output directory - confirm against benchml.log.
            log >> "mkdir -p %s" % os.path.dirname(args.benchmark_json)
        models, dataset = load_models_and_dataset(args)
        bench = benchml.benchmark.evaluate(
            data=dataset, models=models, log=benchml.log, verbose=args.verbose, detailed=True
        )
        # Context manager guarantees the results are flushed and the file closed.
        with open(args.benchmark_json, "w") as fout:
            json.dump(bench, fout, indent=1, sort_keys=True)
    elif args.mode == "train":
        models, dataset = load_models_and_dataset(args)
        for data in dataset:
            # NOTE(review): trains on ``data[0]`` rather than ``data`` - confirm
            # that this is the intended training set for each dataset item.
            train_models(models, data[0], args)


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, validate the
    # requested mode and dispatch to run().
    log.Connect()
    # The help text lists exactly the modes accepted by the check below.
    log.AddArg("mode", str, help="Select from benchmark|train|map|analyse")
    log.AddArg("meta", str, default="meta.json", help="Input metadata file")
    log.AddArg("extxyz", str, default="", help="Input structure-file in ext-xyz format")
    log.AddArg(
        "models", (list, str), default=[], help="List of predefined models for mode=benchmark,train"
    )
    log.AddArg("archfile", str, default="{model}.arch", help="Input model file used when mode=map")
    log.AddArg(
        "store_as",
        str,
        default="",
        help="Key under which predictions are stored in ext-xyz file when mode=map",
    )
    log.AddArg("benchmark_json", str, default="", help="Output json-file storing benchmark results")
    log.AddArg(
        "override",
        str,
        default="",
        help="Json string with name-value pairs for parameter overrides",
    )
    log.AddArg("seed", int, default=971, help="RNG seed")
    log.AddArg("verbose", "toggle", default=False, help="Enable verbose output")
    log.AddArg("use_ase", "toggle", default=False, help="Use ASE parse")
    args = log.Parse()
    if args.use_ase:
        benchml.readwrite.configure(use_ase=args.use_ase)
    # Reject unknown modes here; run() itself does not report them.
    if args.mode not in {"benchmark", "train", "map", "analyse"}:
        raise ValueError("Unknown mode '%s'" % args.mode)
    run(args)
