#!/usr/bin/env python3
#
# Tests the EasyML module.
#
# This file is part of Myokit.
# See http://myokit.org for copyright, sharing, and licensing details.
#
from __future__ import absolute_import, division
from __future__ import print_function, unicode_literals

import os
import unittest

import myokit
import myokit.formats
import myokit.formats.easyml

from shared import TemporaryDirectory, WarningCollector, DIR_DATA

# Unit testing in Python 2 and 3
# When ``assertRaisesRegex`` is missing (Python 2, where only the older
# ``assertRaisesRegexp`` spelling exists), alias it so the tests below can use
# the Python 3 name on both versions.
try:
    unittest.TestCase.assertRaisesRegex
except AttributeError:  # pragma: no python 3 cover
    unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp

# Strings in Python 2 and 3
# On Python 3 the ``basestring`` builtin does not exist, so bind it to ``str``.
# NOTE(review): ``basestring`` is not used in the visible part of this file;
# it may be kept for consistency with other test modules — confirm.
try:
    basestring
except NameError:   # pragma: no python 2 cover
    basestring = str

# Model that requires unit conversion: time is in [s], ``membrane.V`` is in
# [V] and the currents are in [pA], so exporting it exercises the EasyML
# exporter's unit conversion. Used by ``test_unit_conversion``.
units_model = """
[[model]]
membrane.V = -0.08
hh.x = 0.1
hh.y = 0.9
mm.C = 0.9

[engine]
time = 0 [s]
    in [s]
    bind time

[membrane]
dot(V) = (hh.I1 + mm.I2) / C
    in [V]
C = 20 [pF]
    in [pF]

[hh]
dot(x) = (inf - x) / tau
    inf = 0.8
    tau = 3 [s]
        in [s]
dot(y) = alpha * (1 - y) - beta * y
    alpha = 0.1 [1/s]
        in [1/s]
    beta = 0.2 [1/s]
        in [1/s]
I1 = 3 [pS] * x * y * (membrane.V - 0.05 [V])
    in [pA]

[mm]
dot(C) = beta * O - alpha * C
alpha = 0.3 [1/s]
    in [1/s]
beta = 0.4 [1/s]
    in [1/s]
O = 1 - C
I2 = 2 [pS] * O * (membrane.V + 0.02 [V])
    in [pA]
"""
# Expected EasyML export of ``units_model``. ``test_unit_conversion`` strips
# surrounding whitespace from both this text and the exported file, then
# compares them line by line.
units_output = """
/*
This file was generated by Myokit.
*/

V; .nodal(); .external(Vm);
Iion; .nodal(); .external();

V_init = -80.0;
x_init = 0.1;
y_init = 0.9;
C_init = 0.9;

// hh
I1 = 3.0 * x * y * (V / 1000.0 - 0.05) * 0.05;
x_inf = 0.8;
tau_x = 3.0 * 1000.0;
alpha_y = 0.1 * 0.001;
beta_y = 0.2 * 0.001;

// mm
diff_C = (mm_beta * O - mm_alpha * C) * 0.001;
I2 = 2.0 * O * (V / 1000.0 + 0.02) * 0.05;
O = 1.0 - C;
mm_alpha = 0.3;
mm_beta = 0.4;

// Sum of currents
Iion = I1 + I2;

// Markov model
group {
  C;
}.method(markov_be);

// Trace all currents and state variables
group {
  I1;
  I2;
  V;
  x;
  y;
  C;
}.trace();

// Parameters
group {
  mm_alpha;
  mm_beta;
}.param();
"""


class EasyMLExporterTest(unittest.TestCase):
    """ Tests EasyML export. """

    def _assert_lines_equal(self, observed, expected):
        """
        Asserts that two sequences of lines are identical.

        The lines are compared one pair at a time, so a failure points at the
        first differing line instead of a diff of the whole file. The lengths
        are compared afterwards, which catches missing or extra trailing
        lines that ``zip`` would otherwise ignore.
        """
        for ob, ex in zip(observed, expected):
            self.assertEqual(ob, ex)
        self.assertEqual(len(observed), len(expected))

    def test_easyml_exporter(self):
        # Tests exporting a model

        model = myokit.load_model('example')
        with TemporaryDirectory() as d:
            path = d.path('easy.model')

            # Test with simple model
            e = myokit.formats.easyml.EasyMLExporter()
            e.model(path, model)

            # Test with extra bound variables
            model.get('membrane.C').set_binding('hello')
            e.model(path, model)

            # Test without V being a state variable
            v = model.get('membrane.V')
            v.demote()
            v.set_rhs(3)
            e.model(path, model)

            # Test with invalid model: V now depends on itself
            v.set_rhs('2 * V')
            self.assertRaisesRegex(
                myokit.ExportError, 'valid model', e.model, path, model)

    def test_easyml_exporter_static(self):
        # Tests exporting a model (with HH and markov states) and compares
        # against reference output.

        # Export model
        m = myokit.load_model(os.path.join(DIR_DATA, 'decker-2009.mmt'))
        e = myokit.formats.easyml.EasyMLExporter()
        with TemporaryDirectory() as d:
            path = d.path('decker.model')
            e.model(path, m)
            with open(path, 'r') as f:
                observed = f.readlines()

        # Load expected output
        with open(os.path.join(DIR_DATA, 'decker.model'), 'r') as f:
            expected = f.readlines()

        # Compare (line by line, for readable output)
        self._assert_lines_equal(observed, expected)

    def test_unit_conversion(self):
        # Tests exporting a model that requires unit conversion

        # Export model
        m = myokit.parse_model(units_model)
        e = myokit.formats.easyml.EasyMLExporter()
        with TemporaryDirectory() as d:
            path = d.path('easy.model')
            e.model(path, m)
            with open(path, 'r') as f:
                observed = f.read().strip().splitlines()

        # Get expected output
        expected = units_output.strip().splitlines()

        # Compare (line by line, for readable output)
        self._assert_lines_equal(observed, expected)

        # Test warnings are raised if conversion fails: without the
        # capacitance, the currents can no longer be converted
        m.get('membrane.V').set_rhs('hh.I1 + mm.I2')
        m.get('membrane').remove_variable(m.get('membrane.C'))
        with TemporaryDirectory() as d:
            path = d.path('easy.model')
            with WarningCollector() as c:
                e.model(path, m)
            warnings = c.text()
            self.assertIn('Unable to convert hh.I1', warnings)
            self.assertIn('Unable to convert mm.I2', warnings)

        # Test a warning is raised if time has incompatible units
        m.get('engine.time').set_unit(myokit.units.cm)
        with TemporaryDirectory() as d:
            path = d.path('easy.model')
            with WarningCollector() as c:
                e.model(path, m)
            self.assertIn('Unable to convert time units [cm]', c.text())

    def test_export_reused_variable(self):
        # Tests exporting when an `inf` or other special variable is used twice

        # Create model re-using tau and inf
        m = myokit.parse_model(
            """
            [[model]]
            m.V = -80
            c.x = 0.1
            c.y = 0.1

            [m]
            time = 0 bind time
            i_ion = c.I
            dot(V) = -i_ion

            [c]
            inf = 0.5
            tau = 3
            dot(x) = (inf - x) / tau
            dot(y) = (inf - y) / tau
            I = x * y * (m.V - 50)
            """)

        # Export, and read back in
        e = myokit.formats.easyml.EasyMLExporter()
        with TemporaryDirectory() as d:
            path = d.path('easy.model')
            e.model(path, m)
            with open(path, 'r') as f:
                x = f.read()

        # Each state should get its own inf and tau variable
        self.assertIn('x_inf =', x)
        self.assertIn('y_inf =', x)
        self.assertIn('tau_x =', x)
        self.assertIn('tau_y =', x)

    def test_easyml_exporter_fetching(self):
        # Tests getting an EasyML exporter via the 'exporter' interface

        e = myokit.formats.exporter('easyml')
        self.assertIsInstance(e, myokit.formats.easyml.EasyMLExporter)

    def test_capability_reporting(self):
        # Tests if the correct capabilities are reported
        e = myokit.formats.easyml.EasyMLExporter()
        self.assertTrue(e.supports_model())


class EasyMLExpressionWriterTest(unittest.TestCase):
    """ Tests EasyML expression writer functionality. """

    def test_all(self):
        # Tests converting each supported expression type to EasyML

        w = myokit.formats.ewriter('easyml')

        model = myokit.Model()
        component = model.add_component('c')
        avar = component.add_variable('a')

        # Name
        a = myokit.Name(avar)
        self.assertEqual(w.ex(a), 'c.a')
        # Number with unit
        b = myokit.Number('12', 'pF')
        self.assertEqual(w.ex(b), '12.0')
        # Integer
        c = myokit.Number(1)
        self.assertEqual(w.ex(c), '1.0')

        # Prefix plus
        x = myokit.PrefixPlus(b)
        self.assertEqual(w.ex(x), '12.0')
        # Prefix minus
        x = myokit.PrefixMinus(b)
        self.assertEqual(w.ex(x), '(-12.0)')

        # Plus
        x = myokit.Plus(a, b)
        self.assertEqual(w.ex(x), 'c.a + 12.0')
        # Minus
        x = myokit.Minus(a, b)
        self.assertEqual(w.ex(x), 'c.a - 12.0')
        # Multiply
        x = myokit.Multiply(a, b)
        self.assertEqual(w.ex(x), 'c.a * 12.0')
        # Divide
        x = myokit.Divide(a, b)
        self.assertEqual(w.ex(x), 'c.a / 12.0')

        # 1 - exp() and exp() - 1
        x = myokit.Minus(myokit.Exp(myokit.Number(2)), myokit.Number(1))
        self.assertEqual(w.ex(x), 'expm1(2.0)')
        x = myokit.Minus(myokit.Number(1), myokit.Exp(myokit.Number(3)))
        self.assertEqual(w.ex(x), '-expm1(3.0)')

        # The WarningCollector blocks below collect any warnings emitted while
        # writing. They are deliberately not bound to a name: binding them to
        # ``c`` (as before) silently clobbered the integer ``c`` defined above.

        # Quotient
        x = myokit.Quotient(a, b)
        with WarningCollector():
            self.assertEqual(w.ex(x), 'floor(c.a / 12.0)')
        # Remainder
        x = myokit.Remainder(a, b)
        with WarningCollector():
            self.assertEqual(w.ex(x), 'c.a - 12.0 * (floor(c.a / 12.0))')

        # Power
        x = myokit.Power(a, b)
        self.assertEqual(w.ex(x), 'pow(c.a, 12.0)')
        # Sqrt
        x = myokit.Sqrt(b)
        self.assertEqual(w.ex(x), 'sqrt(12.0)')
        # Exp
        x = myokit.Exp(a)
        self.assertEqual(w.ex(x), 'exp(c.a)')
        # Log(a)
        x = myokit.Log(b)
        self.assertEqual(w.ex(x), 'log(12.0)')
        # Log(a, b)
        x = myokit.Log(a, b)
        self.assertEqual(w.ex(x), '(log(c.a) / log(12.0))')
        # Log10
        x = myokit.Log10(b)
        self.assertEqual(w.ex(x), 'log10(12.0)')

        # Trigonometric functions
        with WarningCollector():
            # Sin
            x = myokit.Sin(b)
            self.assertEqual(w.ex(x), 'sin(12.0)')
            # Cos
            x = myokit.Cos(b)
            self.assertEqual(w.ex(x), 'cos(12.0)')
            # Tan
            x = myokit.Tan(b)
            self.assertEqual(w.ex(x), 'tan(12.0)')
            # ASin
            x = myokit.ASin(b)
            self.assertEqual(w.ex(x), 'asin(12.0)')
            # ACos
            x = myokit.ACos(b)
            self.assertEqual(w.ex(x), 'acos(12.0)')
            # ATan
            x = myokit.ATan(b)
            self.assertEqual(w.ex(x), 'atan(12.0)')

        # Rounding and absolute value
        with WarningCollector():
            # Floor
            x = myokit.Floor(b)
            self.assertEqual(w.ex(x), 'floor(12.0)')
            # Ceil
            x = myokit.Ceil(b)
            self.assertEqual(w.ex(x), 'ceil(12.0)')
            # Abs
            x = myokit.Abs(b)
            self.assertEqual(w.ex(x), 'fabs(12.0)')

        # Equal
        x = myokit.Equal(a, b)
        self.assertEqual(w.ex(x), '(c.a == 12.0)')
        # NotEqual
        x = myokit.NotEqual(a, b)
        self.assertEqual(w.ex(x), '(c.a != 12.0)')
        # More
        x = myokit.More(a, b)
        self.assertEqual(w.ex(x), '(c.a > 12.0)')
        # Less
        x = myokit.Less(a, b)
        self.assertEqual(w.ex(x), '(c.a < 12.0)')
        # MoreEqual
        x = myokit.MoreEqual(a, b)
        self.assertEqual(w.ex(x), '(c.a >= 12.0)')
        # LessEqual
        x = myokit.LessEqual(a, b)
        self.assertEqual(w.ex(x), '(c.a <= 12.0)')

        # Not
        cond1 = myokit.parse_expression('5 > 3')
        cond2 = myokit.parse_expression('2 < 1')
        x = myokit.Not(cond1)
        self.assertEqual(w.ex(x), '!((5.0 > 3.0))')
        # And
        x = myokit.And(cond1, cond2)
        self.assertEqual(w.ex(x), '((5.0 > 3.0) and (2.0 < 1.0))')
        # Or
        x = myokit.Or(cond1, cond2)
        self.assertEqual(w.ex(x), '((5.0 > 3.0) or (2.0 < 1.0))')

        # If
        x = myokit.If(cond1, a, b)
        self.assertEqual(w.ex(x), '((5.0 > 3.0) ? c.a : 12.0)')
        # Piecewise (``c`` is still the integer literal 1 defined above)
        x = myokit.Piecewise(cond1, a, cond2, b, c)
        self.assertEqual(
            w.ex(x),
            '((5.0 > 3.0) ? c.a : ((2.0 < 1.0) ? 12.0 : 1.0))')

        # Test without a Myokit expression
        self.assertRaisesRegex(
            ValueError, 'Unknown expression type', w.ex, 7)

    def test_easyml_ewriter_fetching(self):
        # Tests fetching an EasyML expression writer via the 'ewriter' method

        w = myokit.formats.ewriter('easyml')
        self.assertIsInstance(w, myokit.formats.easyml.EasyMLExpressionWriter)


# Allow running this test module directly (``python test_file.py``).
if __name__ == '__main__':
    unittest.main()

