#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division
__author__ = "Johan Hake (hake.dev@gmail.com)"
__date__ = "2013-03-13 -- 2014-05-14"
__copyright__ = "Copyright (C) 2013 " + __author__
__license__  = "GNU LGPL Version 3.0 or later"

from modelparameters.codegeneration import latex
from scipy.integrate import odeint
from itertools import cycle
import matplotlib.pyplot as plt
import numpy as np
import instant
from gotran.model.loadmodel import load_ode
from gotran.model.expressions import Expression
from gotran.common.options import parameters
from gotran.common import error
from gotran.codegeneration.compilemodule import compile_module

from goss.codegeneration import GossCodeGenerator
from goss.compilemodule import jit as goss_jit
from goss import goss_solvers
import goss

def _option_pairs(values, option_name):
    """Split a flat ``[name, value, name, value, ...]`` option list into pairs.

    The option default ``[""]`` (a single empty string) means "nothing given"
    and yields an empty list.

    Arguments
    ---------
    values : list of str
        Flat list of alternating names and values.
    option_name : str
        Name of the command line option; used in the error message.

    Returns
    -------
    list of ``(name, value)`` tuples, in the order given.

    An odd number of values is reported through gotran's ``error``.
    """
    values = list(values)
    if len(values) == 1 and values[0] == "":
        return []
    if len(values) % 2 != 0:
        error("Expected an even number of values for '{0}'".format(option_name))
    return list(zip(values[::2], values[1::2]))


def _solver_param_value(solver_name, param_name, current, value):
    """Convert the command line string *value* to the type of *current*.

    *current* is the solver parameter's present value, whose type decides the
    conversion. Booleans are parsed explicitly because ``bool("False")`` is
    ``True``.
    """
    if isinstance(current, bool):
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "yes", "on"):
            return True
        if lowered in ("0", "false", "no", "off"):
            return False
        error("Expected a boolean value for parameter '{0}' in '{1}', "
              "got '{2}'".format(param_name, solver_name, value))
    return type(current)(value)


def _line_styles():
    """Return an endless iterator of ``(color, linestyle)`` pairs.

    It runs through all colors of the active matplotlib color cycle with solid
    lines first, then dashed, dash-dotted and dotted lines.
    """
    rc = plt.rcParams
    if "axes.prop_cycle" in rc:
        # matplotlib >= 1.5; 'axes.color_cycle' was deprecated and later removed
        colors = [props["color"] for props in rc["axes.prop_cycle"]
                  if "color" in props]
    else:
        colors = list(rc["axes.color_cycle"])
    if not colors:
        colors = [None]  # let matplotlib choose
    return cycle([(color, style) for style in ["-", "--", "-.", ":"]
                  for color in colors])


def main(filename, params):
    """Load a gotran ODE, integrate it with a GOSS solver and plot the result.

    Arguments
    ---------
    filename : str
        Path to the gotran ODE file.
    params : ParameterDict
        Parsed parameters. The following are read: log_level, plot_y, plot_x,
        code, cppargs, parameters, init_conditions, solver,
        solver_parameters, tstop and dt.

    Invalid user input is reported through gotran's ``error``.
    """
    goss.set_log_level(getattr(goss, params.log_level))

    # Load the gotran ode
    ode = load_ode(filename)

    state_names = [state.name for state in ode.full_states]
    monitored = [expr.name for expr in ode.intermediates + ode.state_expressions]

    # Split the requested y values into states and monitored expressions.
    # New lists are built so that params.plot_y itself is not modified.
    requested = list(params.plot_y)
    plot_states = [name for name in requested if name in state_names]
    monitored_plot = [name for name in requested if name not in state_names]

    for mp in monitored_plot:
        if mp not in monitored:
            error("{} is not a state or intermediate in this ODE".format(mp))

    x_name = params.plot_x
    if x_name not in ["time"]+monitored+state_names:
        error("Expected plot_x to be either 'time' or one of the plotable "\
              "variables, got {}".format(x_name))

    # A non-time x axis variable must be collected too (but only once)
    if x_name != "time":
        x_group = plot_states if x_name in state_names else monitored_plot
        if x_name not in x_group:
            x_group.append(x_name)

    goss_ode = goss_jit(ode, monitored=monitored, code_params=params.code, \
                        cppargs=params.cppargs)

    # Validate both option lists before touching the compiled ODE
    parameter_pairs = _option_pairs(params.parameters, "parameters")
    ic_pairs = _option_pairs(params.init_conditions, "init_conditions")

    # Set parameters directly on the compiled goss ODE
    for param_name, param_value in parameter_pairs:
        goss_ode.set_parameter(param_name, float(param_value))

    user_ic = {name: float(value) for name, value in ic_pairs}

    # FIXME: Update GOSS to do this more easily within GOSS
    # Create python module to get state indices and initial values
    python_params = parameters.generation.copy()
    for name in python_params.functions:
        python_params.functions[name].generate = False

    module = compile_module(ode, language="Python", generation_params=python_params)

    # Indices of the plotted states / monitored expressions
    plot_inds = np.array([module.state_indices(state) for state in plot_states],
                         dtype=int)
    monitor_inds = np.array([monitored.index(monitor) for monitor in monitored_plot],
                            dtype=int)

    states = module.init_state_values(**user_ic)

    # Create the solver and apply user given solver parameters
    solver = getattr(goss, params.solver)(goss_ode)
    for param_name, param_value in _option_pairs(params.solver_parameters,
                                                 "solver_parameters"):
        if param_name not in solver.parameters:
            error("Parameter '{0}' is not a parameter in '{1}'".format(\
                param_name, params.solver))
        solver.parameters[param_name] = _solver_param_value(
            params.solver, param_name, solver.parameters[param_name], param_value)

    # Time grid: build it from dt itself so that the recorded times match the
    # steps the solver takes. Take at least one step.
    t0 = 0.
    dt = params.dt
    num_steps = max(1, int(round((params.tstop - t0) / dt)))
    tsteps = t0 + dt * np.arange(num_steps + 1)

    monitored_values = np.zeros(len(monitored), dtype=np.float64)
    num_plot_states = len(plot_states)
    collected_values = np.zeros((num_plot_states + len(monitored_plot), len(tsteps)))

    def collect(column, t):
        """Store the plotted states and monitored values at time *t* in *column*."""
        if plot_states:
            collected_values[:num_plot_states, column] = states[plot_inds]
        if monitored_plot:
            goss_ode.eval_monitored(states, t, monitored_values)
            collected_values[num_plot_states:, column] = monitored_values[monitor_inds]

    # Save the initial conditions
    collect(0, t0)

    # The progress total equals the number of increments performed below
    p = goss.Progress("Stepping {0} with {1}".format(ode, params.solver), num_steps)

    for ind in range(num_steps):
        solver.forward(states, tsteps[ind], dt)
        # states now hold the solution at tsteps[ind + 1], so evaluate
        # monitored expressions at that time
        collect(ind + 1, tsteps[ind + 1])
        p += 1

    del p

    if collected_values.shape[0] == 0:
        return

    # Plot data
    plt.rcParams["lines.linewidth"] = 2
    line_styles = _line_styles()

    plot_items = plot_states + monitored_plot
    if x_name != "time":
        x_values = collected_values[plot_items.index(x_name)]
    else:
        x_values = tsteps

    # Plot everything except the x axis variable. Keep track of what was
    # drawn so that legend/ylabel describe exactly the plotted lines.
    plotted = []
    for what, values in zip(plot_items, collected_values):
        if what == x_name:
            continue
        color, style = next(line_styles)
        plt.plot(x_values, values, color=color, linestyle=style)
        plotted.append(what)

    if len(plotted) > 1:
        plt.legend(["$\\mathrm{{{0}}}$".format(latex(value))
                    for value in plotted])
    elif plotted:
        plt.ylabel("$\\mathrm{{{0}}}$".format(latex(plotted[0])))

    plt.xlabel("$\\mathrm{{{0}}}$".format(latex(x_name)))
    plt.title(ode.name.replace("_", "\\_"))
    plt.show()

if __name__ == "__main__":
    import os
    import sys
    from modelparameters.parameterdict import ParameterDict, OptionParam, \
         Param, ScalarParam

    code_params = GossCodeGenerator.default_parameters()

    # Script parameters. They are filled from the command line by the
    # parse_args call below.
    params = ParameterDict(
        solver=OptionParam("ImplicitEuler", goss_solvers,
                           description="The ODE solver used to integrate "
                           "the ODE/DAE."),
        solver_parameters=Param([""], description="Parameters passed to the solver"),
        log_level=OptionParam("PROGRESS", ["DEBUG", "PROGRESS", "INFO", "WARNING"]),
        parameters=Param([""], description="Set parameter of model"),
        init_conditions=Param([""], description="Set initial condition of model"),
        tstop=ScalarParam(100., gt=0, description="Time for stopping simulation"),
        dt=ScalarParam(0.1, gt=0, description="Timestep for plotting."),
        plot_y=Param(["V"], description="States or monitored to plot on the y axis."),
        plot_x=Param("time", description="Values used for the x axis. Can be time "
                     "and any valid plot_y variable."),
        code=code_params,
        # description belongs to the Param; previously a misplaced parenthesis
        # passed it to ParameterDict as a separate entry
        cppargs=Param("-O2 -g", description="C++ compile arguments."),
    )

    params.parse_args(usage="usage: %prog FILE [options]")

    if len(sys.argv) < 2:
        error("Expected a single gotran file argument.")

    if not os.path.isfile(sys.argv[1]):
        error("Expected the argument to be a file.", exception=IOError)

    file_name = sys.argv[1]
    main(file_name, params)
