#!/usr/bin/env python3
#
# Copyright (C) 2012, 2013, 2014 Ben Elliston
# Copyright (C) 2014, 2015, 2016 The University of New South Wales
# Copyright (C) 2021 Ben Elliston
#
# This file is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.

"""Evolutionary programming applied to NEM optimisations."""

import argparse
import csv
import json
import sys
import warnings
from argparse import ArgumentDefaultsHelpFormatter as HelpFormatter

import numpy as np
import wx
from deap import algorithms, base, cma, creator, tools
from gooey import Gooey

try:
    from scoop import futures
except ImportError:  # pragma: no cover
    print('WARNING: scoop not loaded')

import nemo
from nemo import configfile as cf
from nemo import costs, demand, penalties, scenarios

# SCOOP can emit RuntimeWarnings; this filter ignores every
# RuntimeWarning for the rest of the process.
warnings.simplefilter('ignore', RuntimeWarning)

# pylint: disable=no-member
# If a display is available but arguments were supplied on the command
# line, append '--ignore-gooey' (unless already present) so the script
# runs from the command line arguments.
if (wx.PyApp.IsDisplayAvailable() and len(sys.argv) > 1
        and '--ignore-gooey' not in sys.argv):
    sys.argv.append('--ignore-gooey')


def conditional_gooey(*pargs, **kwargs):
    """Return a decorator that applies Gooey only if a display exists.

    All positional and keyword arguments are forwarded unchanged to
    ``Gooey``.  The display check happens when the decorator is applied
    to a function; without a display the function is returned as is.
    """
    def decorator(func):
        # pylint: disable=no-member
        display_found = wx.PyApp.IsDisplayAvailable()
        return Gooey(*pargs, **kwargs)(func) if display_found else func
    return decorator


@conditional_gooey(monospaced_font=True,
                   program_name="NEMO evolution",
                   richtext_controls=True,
                   show_success_modal=False,
                   disable_progress_bar_animation=True)
def process_options():
    """Build the command line parser and parse the arguments.

    Options are organised into four argument groups: common, costs,
    limits and optimiser.  Many defaults are looked up through
    ``cf.get``.  The parser is created with ``add_help=False`` and the
    ``-h/--help`` option is registered by hand in the common group.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(
        epilog='Bug reports via https://nemo.ozlabs.org/',
        formatter_class=HelpFormatter,
        add_help=False)

    common = parser.add_argument_group('common', 'Commonly used options')
    cost_opts = parser.add_argument_group('costs', 'Cost-related options')
    limits = parser.add_argument_group(
        'limits', 'Limits/constraints for the model')
    optimiser = parser.add_argument_group('optimiser', 'CMA-ES controls')

    # -- common options --
    common.add_argument(
        "-d", "--demand-modifier", action="append",
        help='demand modifier')
    common.add_argument(
        "-h", "--help", action="help",
        help="show this help message and exit")
    common.add_argument(
        "--list-scenarios", action="store_true",
        help='list supply scenarios and exit')
    common.add_argument(
        "-o", "--output", type=str, default='results.json',
        help='output filename (will overwrite)')
    common.add_argument(
        "-p", "--plot", action="store_true",
        help='plot hourly energy balance on completion')
    common.add_argument(
        "--reliability-std", type=float, default=0.002,
        help='reliability standard (%% unserved)')
    common.add_argument(
        "--reserves", type=int,
        default=cf.get('limits', 'minimum-reserves-mw'),
        help='minimum operating reserves (MW)')
    common.add_argument(
        "-s", "--supply-scenario", type=str,
        default='ccgt', metavar='SCENARIO',
        choices=sorted(scenarios.supply_scenarios),
        help='generation mix scenario')

    # -- cost options --
    cost_opts.add_argument(
        "-c", "--carbon-price", type=int,
        default=cf.get('costs', 'co2-price-per-t'),
        help='carbon price ($/t)')
    cost_opts.add_argument(
        "--costs", type=str, metavar='cost_class',
        default=cf.get('costs', 'technology-cost-class'),
        choices=sorted(costs.cost_scenarios),
        help='technology cost class')
    cost_opts.add_argument(
        "--ccs-storage-costs", type=float,
        default=cf.get('costs', 'ccs-storage-costs-per-t'),
        help='CCS storage costs ($/t)')
    cost_opts.add_argument(
        "--coal-price", type=float,
        default=cf.get('costs', 'coal-price-per-gj'),
        help='black coal price ($/GJ)')
    cost_opts.add_argument(
        "--gas-price", type=float,
        default=cf.get('costs', 'gas-price-per-gj'),
        help='gas price ($/GJ)')
    cost_opts.add_argument(
        "-r", "--discount-rate", type=float,
        default=cf.get('costs', 'discount-rate'),
        help='discount rate')

    # -- limits and constraints --
    limits.add_argument(
        "--bioenergy-limit", type=float,
        default=cf.get('limits', 'bioenergy-twh-per-yr'),
        help='Limit on annual bioenergy use (TWh/y)')
    limits.add_argument(
        "--emissions-limit", type=float, default=np.inf,
        help='CO2 emissions limit (Mt/y)')
    limits.add_argument(
        "--fossil-limit", type=float, default=1.0,
        help='Maximum share of energy from fossil fuel')
    limits.add_argument(
        "--hydro-limit", type=float,
        default=cf.get('limits', 'hydro-twh-per-yr'),
        help='Limit on annual energy from hydro (TWh/y)')
    limits.add_argument(
        "--min-regional-generation", type=float, default=0.0,
        help='minimum share of intra-region generation')
    limits.add_argument(
        "--nsp-limit", type=float,
        default=cf.get('limits', 'nonsync-penetration'),
        help='Non-synchronous penetration limit')

    # -- optimiser controls --
    optimiser.add_argument(
        "--lambda", type=int, dest='lambda_',
        help='override CMA-ES lambda value')
    # The seed is optional in the configuration: fall back to None when
    # the option is not present.
    seed_default = (cf.get('optimiser', 'seed')
                    if cf.has_option_p('optimiser', 'seed') else None)
    optimiser.add_argument(
        "--seed", type=int, default=seed_default,
        help='seed for random number generator')
    optimiser.add_argument(
        "--sigma", type=float,
        default=cf.get('optimiser', 'sigma'),
        help='CMA-ES sigma value')
    optimiser.add_argument(
        "-g", "--generations", type=int,
        default=cf.get('optimiser', 'generations'),
        help='generations')
    optimiser.add_argument(
        "--trace-file", type=str,
        help='Filename for evaluation trace (CSV format)')
    optimiser.add_argument(
        "-v", "--verbose", action="store_true",
        help="be verbose")

    return parser.parse_args()


def setup_context():
    """Set up the context object based on command line arguments.

    Reads the module-level ``args`` namespace and copies the
    reliability standard, the non-synchronous penetration limit, the
    minimum regional generation share and the cost settings onto a
    newly created ``nemo.Context``.

    Returns:
        The configured ``nemo.Context``.

    Raises:
        AssertionError: if the non-synchronous penetration limit or the
            minimum regional generation share lies outside [0, 1].
    """
    ctx = nemo.Context()
    ctx.relstd = args.reliability_std

    # The range checks raise AssertionError explicitly rather than via
    # ``assert`` so that they still run under ``python -O``; the
    # exception type is the same one the assert statements produced.

    # Set the system non-synchronous penetration limit.
    ctx.nsp_limit = args.nsp_limit
    if not 0 <= ctx.nsp_limit <= 1:
        raise AssertionError(
            "Non-synchronous penetration limit must be in the "
            "interval [0,1]")

    # Likewise for the minimum share of regional generation.
    ctx.min_regional_generation = args.min_regional_generation
    if not 0 <= ctx.min_regional_generation <= 1:
        raise AssertionError(
            "Minimum regional generation must be in the interval [0,1]")

    # Instantiate the selected cost class, then apply the carbon price.
    cost_class = costs.cost_scenarios[args.costs]
    ctx.costs = cost_class(args.discount_rate, args.coal_price,
                           args.gas_price, args.ccs_storage_costs)
    ctx.costs.carbon = args.carbon_price
    return ctx


def list_scenarios():
    """Print each supply scenario with a one-line description, then exit.

    The description is the first non-blank line of the scenario
    function's docstring, stripped of surrounding whitespace.  A
    scenario with a missing or blank docstring is listed with an empty
    description.

    Raises:
        SystemExit: always, with exit status 0.
    """
    for key in sorted(scenarios.supply_scenarios):
        # ``__doc__`` is None for an undocumented function.
        doc = scenarios.supply_scenarios[key].__doc__ or ''
        stripped = (line.strip() for line in doc.split('\n'))
        # The '' default avoids StopIteration when no line has content.
        description = next((line for line in stripped if line), '')
        print(f'{key:>20}', '\t', description)
    sys.exit(0)


# The command line is parsed at import time; the module-level code
# below and the functions in this file read ``args`` as a global.
args = process_options()

if __name__ == '__main__':
    # list_scenarios() exits the process, so the options are echoed
    # only when the scenario listing was not requested.
    if args.list_scenarios:
        list_scenarios()
    print(vars(args))

context = setup_context()

# Call the selected supply scenario function on the context.
scenarios.supply_scenarios[args.supply_scenario](context)

# Apply the demand modifiers (if any) in the order they were given.
for arg in args.demand_modifier or []:
    demand.switch(arg)(context)

if __name__ == '__main__' and args.verbose:
    docstring = scenarios.supply_scenarios[args.supply_scenario].__doc__
    assert docstring is not None
    # Keep only the first line, dropping any doctest that follows.
    docstring = docstring.partition('\n')[0]
    print(f"supply scenario: {args.supply_scenario} ({docstring})")
    print(context.generators)

if args.trace_file is not None:
    # Start the trace file afresh with just the header row; eval_func
    # appends one row per evaluation.
    with open(args.trace_file, 'w', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['# score', 'penalty', 'reasoncode',
                         'parameter values'])


# Penalty functions used by cost().  The first three are always
# applied; each of the others is added only when the corresponding
# command line option (or context setting) makes its constraint active.
penaltyfns = [penalties.unserved, penalties.bioenergy, penalties.hydro]
penaltyfns += [fn for enabled, fn in (
    (args.reserves > 0, penalties.reserves),
    (args.emissions_limit < np.inf, penalties.emissions),
    (args.fossil_limit < 1, penalties.fossil),
    (context.min_regional_generation > 0, penalties.min_regional),
) if enabled]


def cost(ctx):
    """Return the score, penalty and reason code for a context.

    The score adds, for every generator, its capital cost (divided by
    ``ctx.costs.annuityf`` and multiplied by ``ctx.years``) and its
    operating cost.  Every function in the module-level ``penaltyfns``
    list is called with ``(ctx, args)``; the penalty values are summed
    and the reason codes are OR-ed together.  Score and penalty are
    both divided by ``ctx.total_demand()`` before being returned.

    Returns:
        A ``(score, penalty, reason)`` tuple.
    """
    score = 0
    for gen in ctx.generators:
        capital = gen.capcost(ctx.costs) / ctx.costs.annuityf * ctx.years
        score += capital + gen.opcost(ctx.costs)

    # Accumulate the penalties and merge the reason bit flags.
    penalty = reason = 0
    for penaltyfn in penaltyfns:
        amount, code = penaltyfn(ctx, args)
        penalty += amount
        reason |= code

    # Express $/yr as an average $/MWh over the period
    return score / ctx.total_demand(), penalty / ctx.total_demand(), reason


def eval_func(chromosome):
    """Average cost of energy (in $/MWh)."""
    # NOTE: run() prints this docstring as the objective description,
    # so its wording is part of the program's output.
    context.set_capacities(chromosome)
    nemo.run(context)
    score, penalty, reason = cost(context)

    trace_path = args.trace_file
    if trace_path is not None:
        # Append one row: score, penalty, reason code, then the
        # parameter values of this individual.
        with open(trace_path, 'a', encoding='utf-8') as tracefile:
            csv.writer(tracefile).writerow(
                [score, penalty, reason, *chromosome])

    # One-element tuple, matching the single weight of FitnessMin.
    return (score + penalty,)


# Fitness class with a single weight of -1.0, and a list-based
# individual type that carries it.
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", list, fitness=creator.FitnessMin)
toolbox = base.Toolbox()
try:
    # ``futures`` only exists if the optional SCOOP import succeeded.
    toolbox.register("map", futures.map)
except NameError:  # pragma: no cover
    pass

# See:
# https://deap.readthedocs.org/en/master/api/algo.html#deap.cma.Strategy
# for additional parameters that can be passed to cma.Strategy.
# One parameter per setter across all generators in the context.
numparams = sum(len(g.setters) for g in context.generators)
if args.lambda_ is not None:
    strategy = cma.Strategy(centroid=[0] * numparams, sigma=args.sigma,
                            lambda_=args.lambda_)
else:
    # No override on the command line: let DEAP choose lambda.
    strategy = cma.Strategy(centroid=[0] * numparams, sigma=args.sigma)

toolbox.register("evaluate", eval_func)
toolbox.register("generate", strategy.generate, creator.Individual)
toolbox.register("update", strategy.update)


def run():
    """Run the evolution and report the best individual found.

    Seeds NumPy's random number generator from ``args.seed``, runs
    ``algorithms.eaGenerateUpdate`` for ``args.generations`` generations
    and then re-simulates the best individual in the hall of fame.  The
    resulting context, score and any violated constraints are printed,
    and a JSON bundle (options, parameters, score, penalty, violated
    constraints) is written to ``args.output``.  If ``args.plot`` is
    set, ``nemo.plot`` is called at the end.

    A KeyboardInterrupt during the evolution stops it early; the best
    individual found so far is still reported.  If the hall of fame is
    empty (for example, interrupted before any individual was
    recorded), a message is printed and nothing is written.
    """
    if args.verbose and __name__ == '__main__':
        print("objective: minimise", eval_func.__doc__)

    np.random.seed(args.seed)
    hof = tools.HallOfFame(1)
    stats_fit = tools.Statistics(lambda ind: ind.fitness.values)
    stats_hof = tools.Statistics(lambda _: hof[0].fitness.values)
    mstats = tools.MultiStatistics(fitness=stats_fit, hallfame=stats_hof)
    mstats.register("min", np.min)

    try:
        algorithms.eaGenerateUpdate(toolbox, ngen=args.generations,
                                    stats=mstats, halloffame=hof, verbose=True)
    except KeyboardInterrupt:  # pragma: no cover
        print('user terminated early')

    if not hof:
        # Without this guard, hof[0] below raises IndexError when no
        # individual was ever recorded.
        print('No solution found: hall of fame is empty')
        return

    # Re-run the simulation with the best individual so that the
    # context reflects the reported solution.
    context.set_capacities(hof[0])
    nemo.run(context)
    context.verbose = True
    print()
    print(context)
    score, penalty, reason = cost(context)
    print(f'Score: {score:.2f} $/MWh')
    constraints_violated = []
    if reason > 0:
        print(f'Penalty: {penalty:.2f} $/MWh')
        print('Constraints violated:', end=' ')
        # ``reason`` is a bit mask; report every label whose code is set.
        for label, code in penalties.reasons.items():
            if reason & code:
                constraints_violated.append(label)
                print(label, end=' ')
        print()

    with open(args.output, 'w', encoding='utf-8') as filehandle:
        bundle = {'options': vars(args),
                  # Negative parameter values are recorded as zero.
                  'parameters': [max(0, cap) for cap in hof[0]],
                  'score': score, 'penalty': penalty,
                  'constraints_violated': constraints_violated}
        json.dump(bundle, filehandle)
    print('Done')

    if args.plot:
        nemo.plot(context)


# Start the optimisation only when this file is executed as a script;
# importing the module does not call run().
if __name__ == '__main__':
    run()
