#!/usr/bin/env python3
"""
A template file to simulate a specific system of Ordinary Differential Equations (ODEs).

... or an autogenerated script from the *crnsimulator* Python package.

Note: If this file is executable, it is *autogenerated*.
    This means it contains a system of hardcoded ODEs together with some
    default parameters. It may be helpful to tweak, e.g., the ode_plotter
    function, but beware that this file may be overwritten by the next
    execution of the `crnsimulator` executable.
    For heavy edits, it is recommended to edit the source directly at
    "crnsimulator.odelib_template", or to provide your own template file.
    Alternatively, use the option --output to prevent the loss of your edits.

Usage: 
    python #<&>FILENAME<&># --help
"""

import logging
logger = logging.getLogger(__name__)

import argparse
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint

import seaborn as sns
sns.set(style="darkgrid", font_scale=1, rc={"lines.linewidth": 2.0})

class ODETemplateError(Exception):
    """Raised when the simulation is requested with invalid settings.

    Used by :func:`integrate`, e.g. for a missing end time (--t8) or an
    unusable time grid (--t-lin / --t-log).
    """


def ode_plotter(name, t, ny, svars, log = False, labels = None,
        xlim = None, ylim = None, plim = None, labels_strict = False):
    """ Plots the ODE trajectories and saves the figure to a file.

    Args:
      name (str): Name of the output file including extension (e.g. *.pdf).
        It is also used as the plot title.
      t (list[flt]) : time units plotted on the x-axis.
      ny (list[list[flt]]) : a list of trajectories plotted on the y-axis.
      svars (list[str]): A list of names for every trajectory in ny.
      log (bool, optional): Plot data on a logarithmic time scale.
      labels (set(), optional): Species drawn as solid, coloured, labelled
        lines. When given, *plim* is ignored.
      xlim ((float,float), optional): matplotlib xlim.
      ylim ((float,float), optional): matplotlib ylim.
      plim (float, optional): Only used without *labels*: trajectories whose
        maximum does not exceed plim are drawn as thin grey, unlabelled
        lines. Defaults to None (every trajectory is labelled).
      labels_strict (bool, optional): With *labels*, omit unlabelled
        trajectories entirely instead of drawing them in grey.

    Writes:
      A file containing the plot (format chosen by matplotlib from the
      extension of *name*, e.g. pdf or png).

    Returns:
      str: Name of the file containing the plot.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 4.5))

    # Colours for labelled species, used in order and recycled when exhausted;
    # the tail of black ('k') entries makes surplus labelled species black.
    mycolors = ['blue', 'red', 'green', 'orange', 'maroon',
                'springgreen', 'cyan', 'magenta', 'yellow']
    mycolors += list('kkkkkkkkkkk')

    if labels:
        i = 0
        for e, y in enumerate(ny):
            if svars[e] in labels:
                ax.plot(t, y, '-', label=svars[e],
                        color=mycolors[i % len(mycolors)])
                i += 1
            elif not labels_strict:
                ax.plot(t, y, '--', lw=0.1, color='gray', zorder=1)
    else:
        for e, y in enumerate(ny):
            if plim is None or max(y) > plim:
                ax.plot(t, y, '-', label=svars[e])
            else:
                ax.plot(t, y, '--', lw=0.1, color='gray', zorder=1)

    ax.set_title(name)
    if xlim:
        ax.set_xlim(xlim)
    # Optional tweak: ax.set_xticks(np.arange(0, 61, step=20))

    if ylim:
        ax.set_ylim(ylim)
    # Optional tweak: ax.set_yticks(np.arange(0, 51, step=10))

    ax.set_xlabel('Time', fontsize=16)
    ax.set_ylabel('Concentration', fontsize=16)
    ax.set_xscale('log' if log else 'linear')

    # Only add a legend when at least one trajectory is labelled; an empty
    # legend makes matplotlib emit a warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc='upper right')
    fig.tight_layout()
    fig.savefig(name)
    # Close exactly this figure so repeated calls do not accumulate figures.
    plt.close(fig)
    return name


# Default rate constants, inserted by the code generator at the RATES marker
# below (the marker line must stay unchanged for the template to work).
# NOTE(review): integrate() passes rates=None to the ODE function, and its
# comment says None selects these hard-coded defaults -- confirm against the
# generated ODECALL code.
rates = {
    #<&>RATES<&>#
}

#<&>ODECALL<&>#

#<&>JACOBIAN<&>#


def add_integrator_args(parser):
    """Add the ODE integration and plotting arguments to an argparse parser.

    Two argument groups are created on *parser*: 'odeint parameters' (start
    and end time, initial concentrations, solver tolerances and step limit)
    and 'plotting parameters' (time-grid resolution, species listing and
    labels, nxy and pyplot output options).

    Args:
      parser (:obj:`argparse.ArgumentParser`): The parser to extend in place.
    """
    solver = parser.add_argument_group('odeint parameters')
    plotter = parser.add_argument_group('plotting parameters')

    # Simulation time span and time-grid resolution (all have defaults).
    solver.add_argument("--t0", type=float, default=0, metavar='<flt>',
            help="First time point of the time-course.")
    solver.add_argument("--t8", type=float, default=100, metavar='<flt>',
            help="End point of simulation time.")
    plotter.add_argument("--t-lin", type=int, default=500, metavar='<int>',
            help="Returns --t-lin evenly spaced numbers on a linear scale from --t0 to --t8.")
    plotter.add_argument("--t-log", type=int, default=None, metavar='<int>',
            help="Returns --t-log evenly spaced numbers on a logarithmic scale from --t0 to --t8.")

    # Initial concentration vector (defaults come from the generated file).
    solver.add_argument("--p0", nargs='+', metavar='<int/str>=<flt>',
            help="""Vector of initial species concentrations. 
            E.g. \"--p0 1=0.5 3=0.7\" stands for 1st species at a concentration of 0.5 
            and 3rd species at a concentration of 0.7. You may choose to address species
            directly by name, e.g.: --p0 C=0.5.""")
    # Advanced: scipy.integrate.odeint parameters (None/0 = solver defaults).
    solver.add_argument("-a", "--atol", type=float, default=None, metavar='<flt>',
            help="Specify absolute tolerance for the solver.")
    solver.add_argument("-r", "--rtol", type=float, default=None, metavar='<flt>',
            help="Specify relative tolerance for the solver.")
    solver.add_argument("--mxstep", type=int, default=0, metavar='<int>',
            help="Maximum number of steps allowed for each integration point in t.")

    # Optional: choose output formats.
    plotter.add_argument("--list-labels", action='store_true',
            help="Print all species and exit.")
    plotter.add_argument("--labels", nargs='+', default=[], metavar='<str>+',
            help="""Specify the (order of) species which should appear in the pyplot legend, 
            as well as the order of species for nxy output format.""")
    plotter.add_argument("--labels-strict", action='store_true',
            help="""When using pyplot, only plot trajectories corresponding to labels,
            when using nxy, only print the trajectories corresponding to labels.""")

    plotter.add_argument("--nxy", action='store_true',
            help="Print time course to STDOUT in nxy format.")
    plotter.add_argument("--header", action='store_true',
            help="Print header for trajectories.")

    plotter.add_argument("--pyplot", default='', metavar='<str>',
            help="Specify a filename to plot the ODE simulation.")
    plotter.add_argument("--pyplot-xlim", nargs=2, type=float, default=None, metavar='<flt>',
            help="Specify the limits of the x-axis.")
    plotter.add_argument("--pyplot-ylim", nargs=2, type=float, default=None, metavar='<flt>',
            help="Specify the limits of the y-axis.")
    # Deprecated alias kept for backward compatibility; hidden from --help.
    plotter.add_argument("--pyplot-labels", nargs='+', default=[], metavar='<str>+',
            help=argparse.SUPPRESS)

def flint(inp):
    """Convert *inp* to an ``int`` if it is a whole number, else to a ``float``.

    Accepts numbers or numeric strings, e.g. ``'3'`` -> 3, ``'1.0'`` -> 1,
    ``'1e3'`` -> 1000, ``'0.5'`` -> 0.5. Non-finite values (``'inf'``,
    ``'nan'``) are returned as floats.

    Raises:
      ValueError: if *inp* is not a valid number.
    """
    if isinstance(inp, str):
        # Exact integer parse first, so long integer literals keep full
        # precision instead of being rounded through a float.
        try:
            return int(inp)
        except ValueError:
            pass
    value = float(inp)
    # is_integer() is False for inf/nan, so int() can never overflow here.
    return int(value) if value.is_integer() else value

def set_logger(verbose, logfile):
    """Attach a formatted handler to this module's logger.

    The logger itself is opened up to DEBUG; filtering is done by the new
    handler, whose threshold follows *verbose*: 0 -> WARNING, 1 -> INFO,
    2 -> DEBUG, 3 or more -> NOTSET (pass everything).

    Args:
      verbose (int): Verbosity level, e.g. the number of ``-v`` flags.
      logfile (str): Path of a log file; if empty, a StreamHandler is used.
    """
    logger.setLevel(logging.DEBUG)

    if logfile:
        handler = logging.FileHandler(logfile)
    else:
        handler = logging.StreamHandler()

    if verbose >= 3:
        handler.setLevel(logging.NOTSET)
    else:
        threshold = {
            0: logging.WARNING,
            1: logging.INFO,
            2: logging.DEBUG,
        }.get(verbose)
        if threshold is not None:
            handler.setLevel(threshold)

    handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
    logger.addHandler(handler)


def integrate(args, setlogger = False):
    """Main interface to solve the ODE-system.

    Args:
      args (:obj:`argparse.Namespace`): Parsed arguments containing all of the
        options defined by :obj:`add_integrator_args()` (plus ``verbose`` and
        ``logfile`` when *setlogger* is True).
      setlogger (bool, optional): Configure this module's logger with
        :obj:`set_logger()` from ``args.verbose`` and ``args.logfile``.

    Prints:
      - the species list and initial concentrations (with --list-labels),
      - the time-course in nxy format (with --nxy).

    Writes:
      - a plot file (with --pyplot).

    Raises:
      SystemExit: after listing the species (--list-labels, or when no
        initial concentrations are set at all).
      ODETemplateError: for malformed --p0 terms or invalid time settings.

    Returns:
      zip: A lazy iterator of tuples ``(t, y_1, ..., y_n)``, one per time point.
    """
    if setlogger:
        set_logger(args.verbose, args.logfile)

    if args.pyplot_labels:
        logger.warning('Deprecated argument: --pyplot-labels (use --labels instead).')

    #<&>SORTEDVARS<&>#

    p0 = [0] * len(svars)
    #<&>DEFAULTCONCENTRATIONS<&>#
    const = None
    #<&>CONSTANT_SPECIES_INFO<&>#
    if args.p0:
        for term in args.p0:
            p, sep, o = term.partition('=')
            if not sep:
                raise ODETemplateError(
                    f'Invalid --p0 term "{term}": expected <int/str>=<flt>.')
            if p in svars:
                pi = svars.index(p)
            else:
                # Not a species name: interpret it as a 1-based species index.
                try:
                    pi = int(p) - 1
                except ValueError:
                    raise ODETemplateError(
                        f'Unknown species in --p0: "{p}".') from None
                # Reject 0/negative indices, which would silently wrap around.
                if not 0 <= pi < len(svars):
                    raise ODETemplateError(
                        f'Species index {p} in --p0 is out of range 1..{len(svars)}.')
            p0[pi] = flint(o)
    else:
        msg = ('Specify a vector of initial concentrations: '
               'e.g. --p0 1=0.1 2=0.005 3=1e-6 (see --help)')
        if sum(p0) == 0:
            # Nothing to simulate: fall through to listing the species.
            logger.warning(msg)
            args.list_labels = True
        else:
            logger.info(msg)

    if args.list_labels:
        print('List of variables and initial concentrations:')
        for e, v in enumerate(svars, 1):
            if args.labels_strict and e > len(args.labels):
                break
            print(f'{e} {v} {p0[e-1]} {"constant" if const and const[e-1] else ""}')
        raise SystemExit('Initial concentrations can be overwritten by --p0 argument')

    if not args.nxy and not args.pyplot:
        logger.warning('Use --pyplot and/or --nxy to plot your results.')

    if not args.t8:
        raise ODETemplateError('Specify a valid end-time for the simulation: --t8 <flt>')

    if args.t_log:
        # log10 of a non-positive time is -inf/NaN, so both ends must be > 0.
        if args.t0 <= 0 or args.t8 <= 0:
            raise ODETemplateError('--t0 and --t8 must be positive when using log-scale!')
        time = np.logspace(np.log10(args.t0), np.log10(args.t8), num=args.t_log)
    elif args.t_lin:
        time = np.linspace(args.t0, args.t8, num=args.t_lin)
    else:
        raise ODETemplateError('Please specify either --t-lin or --t-log. (see --help)')

    # It would be nice if it is possible to read alternative rates from a file instead.
    # None triggers the default-rates that are hard-coded in the (this) library file.
    rates = None

    logger.info(f'Initial concentrations: {list(zip(svars, p0))}')
    # TODO: logging should report more info on parameters.

    ny = odeint(#<&>ODENAME<&>#,
        np.array(p0), time, (rates, ), #<&>JCALL<&>#,
        atol=args.atol, rtol=args.rtol, mxstep=args.mxstep).T

    # Output
    if args.nxy:
        # With --labels-strict only the first len(labels) species are printed;
        # this assumes svars lists the labelled species first, presumably
        # arranged by the SORTEDVARS code (TODO confirm).
        end = len(args.labels) if args.labels_strict else None
        if args.header:
            print(' '.join(['{:15s}'.format(x) for x in ['time'] + svars[:end]]))
        for i in zip(time, *ny[:end]):
            print(' '.join(map("{:.9e}".format, i)))

    if args.pyplot:
        plotfile = ode_plotter(args.pyplot, time, ny, svars,
                               log=bool(args.t_log),
                               labels=set(args.labels),
                               xlim = args.pyplot_xlim,
                               ylim = args.pyplot_ylim,
                               labels_strict = args.labels_strict)
        logger.info(f"Plotting successfull. Wrote plot to file: {plotfile}")

    return zip(time, *ny)

if __name__ == '__main__':
    # Command-line entry point: logging options plus every integrator and
    # plotting option, then run the simulation with logging configured.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-v", "--verbose",
        action='count',
        default=0,
        help="Print logging output. (-vv increases verbosity.)")
    parser.add_argument(
        '--logfile',
        action='store',
        default='',
        metavar='<str>',
        help='Redirect logging information to a file.')
    add_integrator_args(parser)
    args = parser.parse_args()
    integrate(args, setlogger=True)

