import numpy as np
from pytest import raises
from scipy.integrate import ode

from pydmd import HAVOK


def lorenz_system(t, state, par):
    """
    Right-hand side of the Lorenz equations, y'(t) = f(t, y, params).

    :param t: current time; accepted for the solver's call signature but
        not used, since the system is autonomous.
    :param state: the three state components (x, y, z).
    :param par: the three parameters (sigma, rho, beta).
    :return: numpy array holding (x', y', z').
    """
    sigma, rho, beta = par
    x, y, z = state
    return np.array(
        (
            sigma * (y - x),
            (x * (rho - z)) - y,
            (x * y) - (beta * z),
        )
    )


def generate_lorenz_data(t):
    """
    Integrate the Lorenz system over the time vector t = t1, t2, ... and
    return the snapshots as the columns of a (3, len(t)) matrix.

    The integration uses scipy's "dopri5" integrator (an explicit
    Runge-Kutta scheme), starting from a fixed initial condition with the
    chaotic parameter set.

    :param t: sequence of time points; the first entry is the initial time.
    :return: array of shape (3, len(t)) whose i-th column is the state at
        t[i].
    """
    # Chaotic Lorenz parameters (sigma, rho, beta)
    params = (10, 28, 8 / 3)

    # Fixed starting point of the trajectory
    start = np.array((-8, 8, 27))

    # The first column is the initial condition itself
    snapshots = np.empty((3, len(t)))
    snapshots[:, 0] = start

    solver = ode(lorenz_system).set_integrator("dopri5")
    solver.set_initial_value(start, t[0])
    # The tuple is passed as one extra argument, matching lorenz_system's
    # single ``par`` parameter
    solver.set_f_params(params)

    # Advance the solver to every remaining time point in order
    for col, time_point in enumerate(t[1:], start=1):
        solver.integrate(time_point)
        snapshots[:, col] = solver.y

    return snapshots


# Shared test data: simulate the chaotic Lorenz system once at import time
dt = 0.001
t = np.arange(0, 100, dt)
lorenz_xyz = generate_lorenz_data(t)
# First row of the snapshot matrix, i.e. the x-coordinate as a 1-D series
lorenz_x = lorenz_xyz[0, :]


def test_shape():
    """
    Fit a HAVOK model with default parameters and verify the shapes of
    linear_embeddings, forcing_input, A, and B against the values implied
    by the model's d and r.
    """
    model = HAVOK()
    model.fit(lorenz_x, dt)

    # Expected dimensions, expressed through the fitted model's d and r
    n_rows = len(t) - model.d + 1
    n_linear = model.r - 1

    assert model.linear_embeddings.shape == (n_rows, n_linear)
    assert model.forcing_input.shape == (n_rows,)
    assert model.A.shape == (n_linear, n_linear)
    assert model.B.shape == (n_linear, 1)


def test_error_fitted():
    """
    Before fit() has been called, reading any of linear_embeddings,
    forcing_input, A, B, or r must raise a RuntimeError.
    """
    model = HAVOK()
    fit_dependent = ("linear_embeddings", "forcing_input", "A", "B", "r")
    for attribute in fit_dependent:
        with raises(RuntimeError):
            getattr(model, attribute)


def test_error_1d():
    """
    fit() must raise a ValueError when handed data that is not
    one-dimensional (here the full 3-row Lorenz snapshot matrix).
    """
    model = HAVOK()
    with raises(ValueError):
        model.fit(lorenz_xyz, dt)


def test_error_reconstructions_of_timeindex():
    """
    reconstructions_of_timeindex must raise a NotImplementedError when
    called on a HAVOK instance.
    """
    model = HAVOK()
    with raises(NotImplementedError):
        model.reconstructions_of_timeindex()


def test_error_small_r():
    """
    fit() must raise a RuntimeError when r is too small; the condition is
    provoked here by constructing the model with d=1.
    """
    model = HAVOK(d=1)
    with raises(RuntimeError):
        model.fit(lorenz_x, dt)


def test_r():
    """
    Check the r property both without svd truncation (for several values
    of d) and with a positive integer svd_rank.
    """
    # Without svd truncation, r is the min of the dimensions of the hankel
    # matrix. Checked for the default d, a larger d, and an even larger d.
    for overrides in ({}, {"d": 500}, {"d": len(t) - 20}):
        model = HAVOK(svd_rank=-1, **overrides)
        model.fit(lorenz_x, dt)
        hankel_dims = (model.d, len(t) - model.d + 1)
        assert model.r == min(*hankel_dims)

    # With a positive integer svd truncation, r should equal svd_rank
    model = HAVOK(svd_rank=3)
    model.fit(lorenz_x, dt)
    assert model.r == model.operator._svd_rank


def test_reconstruction():
    """
    Check that the HAVOK reconstruction of the Lorenz x-coordinate stays
    within a relative error of 0.2. The parameters used here are ones
    that have been successful in reconstructing the Lorenz System.
    """
    model = HAVOK(svd_rank=15, d=100)
    model.fit(lorenz_x, dt)

    # Relative 2-norm error between the data and the real part of the
    # reconstruction
    residual = lorenz_x - model.reconstructed_data.real
    relative_error = np.linalg.norm(residual) / np.linalg.norm(lorenz_x)
    assert relative_error < 0.2
