#!/usr/bin/env python3
#
# Copyright (C) 2012, 2013, 2014 Ben Elliston
# Copyright (C) 2014, 2015, 2016 The University of New South Wales
# Copyright (C) 2021, 2023 Ben Elliston
#
# This file is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.

"""Evolutionary programming applied to NEM optimisations."""

import argparse
import csv
import json
import sys
from argparse import ArgumentDefaultsHelpFormatter as HelpFormatter
from multiprocessing import set_start_method
from multiprocessing.pool import Pool

import numpy as np
import wx
from deap import algorithms, base, cma, creator, tools
from gooey import Gooey

import nemo
from nemo import configfile as cf
from nemo import costs, demand, penalties, scenarios

if __name__ == '__main__':
    # If a display is available and the user supplied command line
    # arguments, add --ignore-gooey (once) to the argument list.
    # NOTE(review): presumably this makes Gooey skip its GUI and use the
    # given arguments directly -- confirm against the Gooey documentation.
    if (wx.PyApp.IsDisplayAvailable()
            and len(sys.argv) > 1
            and '--ignore-gooey' not in sys.argv):
        sys.argv.append('--ignore-gooey')


def init_worker(arguments):
    """Populate the per-process globals needed by the evaluation code.

    This is passed to the multiprocessing pool as its initializer.  It
    stores the parsed command line arguments in the module-level name
    ``args``, builds a simulation context from them (``context``) and
    derives the list of penalty functions (``penaltyfns``).
    """
    # pylint: disable=global-statement
    # pylint: disable=global-variable-undefined
    global args, context, penaltyfns
    # args must be assigned first: penaltyfn_list() reads the global.
    args = arguments
    context = setup_context(arguments)
    penaltyfns = penaltyfn_list(context)


def conditional_gooey(*pargs, **kwargs):
    """Return a decorator that applies Gooey only when a display exists.

    The positional and keyword arguments are forwarded unchanged to
    ``Gooey``.  When no display is available, the decorated function is
    returned untouched and ``Gooey`` is never called.
    """
    def decorator(func):
        if wx.PyApp.IsDisplayAvailable():  # pylint: disable=no-member
            return Gooey(*pargs, **kwargs)(func)
        return func
    return decorator


@conditional_gooey(monospaced_font=True,
                   program_name="NEMO evolution",
                   richtext_controls=True,
                   show_success_modal=False,
                   disable_progress_bar_animation=True)
def process_options():
    """Define the command line interface and parse the arguments.

    Options are arranged in four groups: common, costs, limits and
    optimiser.  Many defaults are read from the configuration module via
    ``cf.get``.  Returns the namespace from ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        epilog='Bug reports via https://nemo.ozlabs.org/',
        formatter_class=HelpFormatter,
        # -h/--help is registered by hand below, in the 'common' group.
        add_help=False)

    # Groups and options are registered in a fixed order; that order is
    # the order in which they are presented to the user.
    common = parser.add_argument_group('common', 'Commonly used options')
    cost_opts = parser.add_argument_group('costs', 'Cost-related options')
    limit_opts = parser.add_argument_group(
        'limits', 'Limits/constraints for the model')
    cma_opts = parser.add_argument_group('optimiser', 'CMA-ES controls')

    common.add_argument('-d', '--demand-modifier', action='append',
                        help='demand modifier')
    common.add_argument('-h', '--help', action='help',
                        help='show this help message and exit')
    common.add_argument('--list-scenarios', action='store_true',
                        help='list supply scenarios and exit')
    common.add_argument('-o', '--output', type=str, default='results.json',
                        help='output filename (will overwrite)')
    common.add_argument('-p', '--plot', action='store_true',
                        help='plot hourly energy balance on completion')
    common.add_argument('--reliability-std', type=float, default=0.002,
                        help='reliability standard (%% unserved)')
    common.add_argument('--reserves', type=int,
                        default=cf.get('limits', 'minimum-reserves-mw'),
                        help='minimum operating reserves (MW)')
    common.add_argument('-s', '--supply-scenario', type=str,
                        default='ccgt', metavar='SCENARIO',
                        choices=sorted(scenarios.supply_scenarios),
                        help='generation mix scenario')

    cost_opts.add_argument('-c', '--carbon-price', type=int,
                           default=cf.get('costs', 'co2-price-per-t'),
                           help='carbon price ($/t)')
    cost_opts.add_argument('--costs', type=str, metavar='cost_class',
                           default=cf.get('costs', 'technology-cost-class'),
                           choices=sorted(costs.cost_scenarios),
                           help='technology cost class')
    cost_opts.add_argument('--ccs-storage-costs', type=float,
                           default=cf.get('costs', 'ccs-storage-costs-per-t'),
                           help='CCS storage costs ($/t)')
    cost_opts.add_argument('--coal-price', type=float,
                           default=cf.get('costs', 'coal-price-per-gj'),
                           help='black coal price ($/GJ)')
    cost_opts.add_argument('--gas-price', type=float,
                           default=cf.get('costs', 'gas-price-per-gj'),
                           help='gas price ($/GJ)')
    cost_opts.add_argument('-r', '--discount-rate', type=float,
                           default=cf.get('costs', 'discount-rate'),
                           help='discount rate')

    limit_opts.add_argument('--bioenergy-limit', type=float,
                            default=cf.get('limits', 'bioenergy-twh-per-yr'),
                            help='Limit on annual bioenergy use (TWh/y)')
    limit_opts.add_argument('--emissions-limit', type=float, default=np.inf,
                            help='CO2 emissions limit (Mt/y)')
    limit_opts.add_argument('--fossil-limit', type=float, default=1.0,
                            help='Maximum share of energy from fossil fuel')
    limit_opts.add_argument('--hydro-limit', type=float,
                            default=cf.get('limits', 'hydro-twh-per-yr'),
                            help='Limit on annual energy from hydro (TWh/y)')
    limit_opts.add_argument('--min-regional-generation', type=float,
                            default=0.0,
                            help='minimum share of intra-region generation')
    limit_opts.add_argument('--nsp-limit', type=float,
                            default=cf.get('limits', 'nonsync-penetration'),
                            help='Non-synchronous penetration limit')

    cma_opts.add_argument('--lambda', type=int, dest='lambda_',
                          help='override CMA-ES lambda value')
    # The seed is optional in the configuration; fall back to None.
    seed_default = (cf.get('optimiser', 'seed')
                    if cf.has_option_p('optimiser', 'seed') else None)
    cma_opts.add_argument('-n', '--ncpus', type=int,
                          default=cf.get('optimiser', 'num-cpus'),
                          help='number of CPUs to use for parallel execution')
    cma_opts.add_argument('--seed', type=int,
                          default=seed_default,
                          help='seed for random number generator')
    cma_opts.add_argument('--sigma', type=float,
                          default=cf.get('optimiser', 'sigma'),
                          help='CMA-ES sigma value')
    cma_opts.add_argument('-g', '--generations', type=int,
                          default=cf.get('optimiser', 'generations'),
                          help='generations')
    cma_opts.add_argument('--trace-file', type=str,
                          help='Filename for evaluation trace (CSV format)')
    cma_opts.add_argument('-v', '--verbose', action='store_true',
                          help='be verbose')
    return parser.parse_args()


def setup_context(args):
    """Create a ``nemo.Context`` configured from the parsed arguments.

    The context receives, in this order: the reliability standard, the
    non-synchronous penetration limit, the minimum regional generation
    share, the cost model (with the carbon price), the selected supply
    scenario and finally each demand modifier in the order given.

    Both share-type limits are asserted to lie in [0, 1].  Returns the
    configured context.
    """
    # pylint: disable=redefined-outer-name
    ctx = nemo.Context()
    ctx.relstd = args.reliability_std

    # System-wide non-synchronous penetration limit (a fraction).
    ctx.nsp_limit = args.nsp_limit
    assert 0 <= ctx.nsp_limit <= 1

    # Minimum share of regional generation (also a fraction).
    ctx.min_regional_generation = args.min_regional_generation
    assert 0 <= ctx.min_regional_generation <= 1, (
        "Minimum regional generation must be in the interval [0,1]")

    # Instantiate the chosen cost class, then set the carbon price on it.
    make_costs = costs.cost_scenarios[args.costs]
    ctx.costs = make_costs(args.discount_rate, args.coal_price,
                           args.gas_price, args.ccs_storage_costs)
    ctx.costs.carbon = args.carbon_price

    # Apply the supply scenario to the context.
    build_supply = scenarios.supply_scenarios[args.supply_scenario]
    build_supply(ctx)

    # demand_modifier is None when no -d option was given.
    for modifier in args.demand_modifier or ():
        demand.switch(modifier)(ctx)
    return ctx


def list_scenarios():
    """Print every supply scenario with a one-line description and exit.

    The description is the first non-blank line of the scenario
    function's docstring, stripped of surrounding whitespace.  A scenario
    with a missing or entirely blank docstring is listed with an empty
    description rather than aborting the listing.

    Always terminates the program via ``sys.exit(0)``.
    """
    for key in sorted(scenarios.supply_scenarios):
        # __doc__ is None for an undocumented function.
        doc = scenarios.supply_scenarios[key].__doc__ or ''
        # Strip before testing so whitespace-only lines count as blank.
        stripped = (line.strip() for line in doc.split('\n'))
        description = next((line for line in stripped if line), '')
        print(f'{key:>20}', '\t', description)
    sys.exit(0)


def penaltyfn_list(ctx):
    """Return the penalty functions that apply to this run.

    The unserved energy, bioenergy and hydro penalties are always
    included.  The others are added only when the matching constraint is
    active, judged from the module-level ``args`` (reserves, emissions
    limit, fossil limit) and from ``ctx`` (minimum regional generation).
    """
    always = [penalties.unserved, penalties.bioenergy, penalties.hydro]
    # (constraint is active, penalty function) pairs, in append order.
    conditional = (
        (args.reserves > 0, penalties.reserves),
        (args.emissions_limit < np.inf, penalties.emissions),
        (args.fossil_limit < 1, penalties.fossil),
        (ctx.min_regional_generation > 0, penalties.min_regional),
    )
    return always + [func for active, func in conditional if active]


def cost(ctx):
    """Return ``(score, penalty, reason)`` for a simulated context.

    ``score`` adds up, over all generators, the capital cost (divided by
    the annuity factor and multiplied by the number of years) plus the
    operating cost.  ``penalty`` is the sum of the values returned by the
    module-level ``penaltyfns`` and ``reason`` is the bitwise OR of their
    reason codes.  Score and penalty are both divided by total demand.
    """
    total = 0
    for generator in ctx.generators:
        capital = generator.capcost(ctx.costs) \
            / ctx.costs.annuityf * ctx.years
        total += capital + generator.opcost(ctx.costs)

    # Accumulate every penalty and merge the reason codes.
    fines, flags = 0, 0
    for func in penaltyfns:
        value, code = func(ctx, args)
        fines += value
        flags |= code

    # Express $/yr as an average $/MWh over the period
    return total / ctx.total_demand(), fines / ctx.total_demand(), flags


def eval_func(chromosome):
    """Average cost of energy (in $/MWh)."""
    # NOTE: run() prints this function's docstring as the description of
    # the objective, so the docstring text is program output.
    #
    # Applies the candidate capacities to the module-level context, runs
    # the simulation and returns a one-element tuple holding the score
    # plus the penalty.  When a trace file was requested, one CSV row of
    # (score, penalty, reason, *chromosome) is appended per evaluation.
    context.set_capacities(chromosome)
    nemo.run(context)
    score, penalty, reason = cost(context)
    if args.trace_file is not None:
        # The csv module requires newline='' on files it writes to;
        # without it, platforms that translate '\n' emit '\r\r\n'.
        with open(args.trace_file, 'a', encoding='utf-8',
                  newline='') as tracefile:
            tracer = csv.writer(tracefile)
            tracer.writerow([score, penalty, reason, *chromosome])
    return (score + penalty,)


def run_final(best):
    """Simulate the best candidate, report on it and save the results.

    The capacities in ``best`` are applied to the module-level
    ``main_context``, which is simulated and printed along with its
    score.  If any constraint was violated, the penalty and the labels of
    the violated constraints are printed as well.  Finally the options,
    parameters (negative values clamped to zero), score, penalty and
    violated constraints are written as JSON to ``args.output``.
    """
    main_context.set_capacities(best)
    nemo.run(main_context)
    main_context.verbose = True
    print()
    print(main_context)

    score, penalty, reason = cost(main_context)
    print(f'Score: {score:.2f} $/MWh')

    violated = []
    if reason > 0:
        print(f'Penalty: {penalty:.2f} $/MWh')
        # Each bit set in reason identifies one violated constraint.
        violated = [label for label, code in penalties.reasons.items()
                    if reason & code]
        print('Constraints violated:', *violated, end=' ')
        print()

    with open(args.output, 'w', encoding='utf-8') as outfile:
        results = {'options': vars(args),
                   'parameters': [max(0, cap) for cap in best],
                   'score': score, 'penalty': penalty,
                   'constraints_violated': violated}
        json.dump(results, outfile)


def run():
    """Run the evolution.

    Seeds NumPy's random number generator, runs DEAP's generate/update
    loop for ``args.generations`` generations while tracking the single
    best individual, then reports on that individual with run_final()
    and optionally plots the result.

    A KeyboardInterrupt stops the evolution early; the best individual
    found so far is still reported.  If the interrupt arrives before any
    individual has been recorded, there is nothing to report and the
    function returns after printing a message.
    """
    if args.verbose:
        docstring = scenarios.supply_scenarios[args.supply_scenario].__doc__
        assert docstring is not None
        # Prune off any doctest test from the docstring.
        docstring = docstring.split('\n')[0]
        print(f"supply scenario: {args.supply_scenario} ({docstring})")
        print("objective: minimise", eval_func.__doc__)

    np.random.seed(args.seed)
    hof = tools.HallOfFame(1)
    # Report the minimum fitness of each generation alongside the
    # fitness of the best individual seen so far.
    stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
    stats_hof = tools.Statistics(lambda _: hof[0].fitness.values)
    mstats = tools.MultiStatistics(fitness=stats_fit, hallfame=stats_hof)
    mstats.register("min", np.min)

    try:
        algorithms.eaGenerateUpdate(toolbox, ngen=args.generations,
                                    stats=mstats, halloffame=hof, verbose=True)
    except KeyboardInterrupt:  # pragma: no cover
        print('user terminated early')

    if len(hof) == 0:  # pragma: no cover
        # Interrupted before the hall of fame was first updated, so
        # hof[0] would raise IndexError.
        print('no candidate was evaluated; nothing to report')
        return

    run_final(hof[0])
    print('Done')

    if args.plot:
        nemo.plot(main_context)


# These DEAP classes (a single objective with weight -1.0, and a list-based
# individual carrying that fitness) are created at import time, outside the
# __main__ guard below, so they also exist in any process that imports this
# module without running the guarded code.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", list, fitness=creator.FitnessMin)

if __name__ == '__main__':
    args = process_options()
    if args.list_scenarios:
        list_scenarios()
    print(vars(args))

    # setup_context() already applies the selected supply scenario (and
    # then any demand modifiers), exactly as init_worker() does for each
    # worker process.  The scenario is therefore not applied a second
    # time here, keeping the main context identical in construction to
    # the contexts the candidates are evaluated against.
    main_context = setup_context(args)
    penaltyfns = penaltyfn_list(main_context)

    # One optimisation parameter per capacity setter of each generator.
    numparams = sum(len(g.setters) for g in main_context.generators)

    # See:
    # https://deap.readthedocs.org/en/master/api/algo.html#deap.cma.Strategy
    # for additional parameters that can be passed to cma.Strategy.
    if args.lambda_ is None:
        # let DEAP choose
        strategy = cma.Strategy(centroid=[0] * numparams, sigma=args.sigma)
    else:
        strategy = cma.Strategy(centroid=[0] * numparams, sigma=args.sigma,
                                lambda_=args.lambda_)

    toolbox = base.Toolbox()
    toolbox.register("generate", strategy.generate, creator.Individual)
    toolbox.register("update", strategy.update)
    toolbox.register("evaluate", eval_func)

    # Workers are started with the 'spawn' method and receive their state
    # (args, context, penalty functions) through init_worker().
    set_start_method('spawn')
    # An ncpus value of 0 (or None) lets Pool pick the process count.
    with Pool(args.ncpus or None,
              initializer=init_worker, initargs=(args,)) as pool:
        toolbox.register("map", pool.map)
        run()
        pool.close()
        pool.join()
