"""
Author: Damien GUEHO
Copyright: Copyright (C) 2021 Damien GUEHO
License: Public Domain
Version: 19
Date: November 2021
Python: 3.7.7
"""



import numpy as np


from ClassesDynamics.ClassSystemWithAStableOriginDynamics import SystemWithAStableOriginDynamics
from ClassesDynamics.ClassPointMassInRotatingTubeDynamics import PointMassInRotatingTubeDynamics
from ClassesDynamics.ClassMassSpringDamperDynamics import MassSpringDamperDynamics
from ClassesGeneral.ClassSystem import DiscreteLinearSystem
from SystemIDAlgorithms.GetOptimizedHankelMatrixSize import getOptimizedHankelMatrixSize
from ClassesGeneral.ClassSignal import DiscreteSignal, OutputSignal, subtract2Signals, addSignals
from ClassesGeneral.ClassExperiments import Experiments
from Plotting.PlotSignals import plotSignals
from Plotting.PlotEigenValues import plotHistoryEigenValues2Systems
from ClassesSystemID.ClassMarkovParameters import TVOKIDWithObserver
from ClassesSystemID.ClassERA import TVERA, TVERADC, TVERADCFromInitialConditionResponse, TVERAFromInitialConditionResponse
from SystemIDAlgorithms.CorrectSystemForEigenvaluesCheck import correctSystemForEigenvaluesCheck
from SystemIDAlgorithms.GetTimeVaryingObserverGainMatrix import getTimeVaryingObserverGainMatrix



# Dynamics
# Constants handed to the dynamics model constructed just below.
dt = 0.1  # presumably the sampling period; `frequency` further down is 10, i.e. 1 / dt -- keep the two in sync
mass = 1
spring_constant = 10
damping_coefficient = 0.1  # only referenced by the commented-out MassSpringDamperDynamics alternative
def theta_dot(t):
    """Return ``3 * sin(t / 2)`` for a scalar or array-like time ``t``.

    Passed to ``PointMassInRotatingTubeDynamics`` as its time-dependent
    parameter (presumably the tube's angular rate, going by the name).
    """
    half_time = t / 2
    return np.sin(half_time) * 3
# dynamics = MassSpringDamperDynamics(dt, mass, spring_constant, damping_coefficient, ['mass'], ['position', 'velocity'])
dynamics = PointMassInRotatingTubeDynamics(
    dt,
    mass,
    spring_constant,
    theta_dot,
)

# Initial Condition
# Small random state (standard normal scaled by 0.1); unseeded, so every run differs.
x0 = 0.1 * np.random.randn(dynamics.state_dimension)


# Frequency and total time
frequency = 10
total_time = 50
# One sample per 1 / frequency over [0, total_time], plus the sample at t = 0.
number_steps = frequency * total_time + 1


# System
# Reference ("true") system built from the dynamics matrices, started from x0 at time 0.
nominal_system = DiscreteLinearSystem(
    frequency,
    dynamics.state_dimension,
    dynamics.input_dimension,
    dynamics.output_dimension,
    [(x0, 0)],
    'Nominal System',
    dynamics.A,
    dynamics.B,
    dynamics.C,
    dynamics.D,
)


# Parameters for identification
assumed_order = 3
# q sizes the number of free-decay experiments below; p is only referenced by
# commented-out code (the active identification calls pass p=10 explicitly).
hankel_matrix_size = getOptimizedHankelMatrixSize(
    assumed_order,
    dynamics.output_dimension,
    dynamics.input_dimension,
)
p, q = hankel_matrix_size
deadbeat_order = 10


# Free Decay Experiments
# The same experiments are stored twice: once in a single flat Experiments object
# and once grouped into `number_batches` Experiments objects. Each collection gets
# its own system/signal instances, but both share the same initial states.
number_batches = 5
number_free_decay_experiments = q * dynamics.input_dimension * 2
number_free_decay_experiments_total = number_free_decay_experiments * number_batches


def _new_free_decay_system(initial_state, index):
    """Build one system on the nominal dynamics, started from `initial_state` at time 0."""
    return DiscreteLinearSystem(
        frequency,
        dynamics.state_dimension,
        dynamics.input_dimension,
        dynamics.output_dimension,
        [(initial_state, 0)],
        'Free Decay Experiments System ' + str(index),
        dynamics.A,
        dynamics.B,
        dynamics.C,
        dynamics.D,
    )


def _new_free_decay_input():
    """Build one default-shaped input signal (presumably all zeros -- confirm in DiscreteSignal)."""
    return DiscreteSignal(dynamics.input_dimension, total_time, frequency)


free_decay_systems = []
free_decay_input_signals = []
free_decay_experiments_batch = []
for _batch in range(number_batches):
    free_decay_systems_batch = []
    free_decay_input_signals_batch = []
    for experiment_index in range(number_free_decay_experiments):
        # One random draw per experiment, shared by the flat and the batched copy.
        initial_state = 0.1 * np.random.randn(dynamics.state_dimension)
        # Construction order (flat system, batch system, flat input, batch input) is kept as-is.
        free_decay_systems.append(_new_free_decay_system(initial_state, experiment_index))
        free_decay_systems_batch.append(_new_free_decay_system(initial_state, experiment_index))
        free_decay_input_signals.append(_new_free_decay_input())
        free_decay_input_signals_batch.append(_new_free_decay_input())
    free_decay_experiments_batch.append(
        Experiments(free_decay_systems_batch, free_decay_input_signals_batch)
    )
free_decay_experiments = Experiments(free_decay_systems, free_decay_input_signals)


## Add noise
# Each experiment gets its own white-noise realisation (covariance 0.02 * I). The
# same realisation is added to the flat collection and to the matching slot of the
# batched collection, so both describe identical noisy measurements.
for i in range(number_free_decay_experiments_total):
    print(i)  # progress indicator, one line per experiment
    noise = DiscreteSignal(
        dynamics.output_dimension,
        total_time,
        frequency,
        signal_shape='White Noise',
        covariance=0.02 * np.eye(dynamics.output_dimension),
    )
    free_decay_experiments.output_signals[i] = addSignals([free_decay_experiments.output_signals[i], noise])

    # Flat index i -> (batch, position inside the batch); integer arithmetic only.
    batch_index, index_in_batch = divmod(i, number_free_decay_experiments)
    batch = free_decay_experiments_batch[batch_index]
    batch.output_signals[index_in_batch] = addSignals([batch.output_signals[index_in_batch], noise])



# Full Experiment
# Single experiment on the nominal system with a default-shaped input signal.
full_system = [nominal_system]
full_input_signal = [
    DiscreteSignal(dynamics.input_dimension, total_time, frequency),
]
full_experiment = Experiments(full_system, full_input_signal)
# full_experiment.output_signals[0] = addSignals([full_experiment.output_signals[0], DiscreteSignal(dynamics.output_dimension, total_time, frequency, signal_shape='White Noise', covariance=0.1 * np.eye(dynamics.output_dimension))])



# TVERA
# tvera = TVERA(free_decay_experiments, okid.hki, okid.D, full_experiment, dynamics.state_dimension, p, q, apply_transformation=True)
# The plain variant consumes the flat experiment set, the DC variant the batched one.
tvera_ic = TVERAFromInitialConditionResponse(
    free_decay_experiments,
    full_experiment,
    dynamics.state_dimension,
    p=10,
)
tveradc_ic = TVERADCFromInitialConditionResponse(
    free_decay_experiments_batch,
    full_experiment,
    dynamics.state_dimension,
    p=10,
    xi=4,
    tau=10,
)


# Identified System
def _as_discrete_system(realization):
    """Wrap an identification result (x0, A, B, C, D) in a DiscreteLinearSystem named 'System ID'."""
    return DiscreteLinearSystem(
        frequency,
        dynamics.state_dimension,
        dynamics.input_dimension,
        dynamics.output_dimension,
        [(realization.x0, 0)],
        'System ID',
        realization.A,
        realization.B,
        realization.C,
        realization.D,
    )


# system_id_tvera = DiscreteLinearSystem(frequency, dynamics.state_dimension, dynamics.input_dimension, dynamics.output_dimension, [(tvera.x0, 0)], 'System ID', tvera.A, tvera.B, tvera.C, tvera.D)
system_id_tvera_ic = _as_discrete_system(tvera_ic)
system_id_tveradc_ic = _as_discrete_system(tveradc_ic)
# G = getTimeVaryingObserverGainMatrix(tvera.A, tvera.C, okid.hkio, p + 2, 1 / frequency)
# system_id_observer = DiscreteLinearSystem(frequency, dynamics.state_dimension, dynamics.input_dimension, dynamics.output_dimension, [(np.zeros(dynamics.state_dimension), 0)], 'System ID', tvera.A, tvera.B, tvera.C, tvera.D, observer_gain=G)
# system_id_no_observer = DiscreteLinearSystem(frequency, dynamics.state_dimension, dynamics.input_dimension, dynamics.output_dimension, [(np.zeros(dynamics.state_dimension), 0)], 'System ID', tvera.A, tvera.B, tvera.C, tvera.D)



# Test Signal
# Default-shaped input, fed to the nominal and to both identified systems.
test_signal = DiscreteSignal(dynamics.input_dimension, total_time, frequency)


# True Output
output_signal = OutputSignal(test_signal, nominal_system)
# output_signal = full_experiment.output_signals[0]


# Identified Output
# output_signal_tvera = OutputSignal(test_signal, system_id_tvera)
# Responses are computed in the same order as before: TVERA first, then TVERA/DC.
output_signal_tvera_ic, output_signal_tveradc_ic = [
    OutputSignal(test_signal, identified_system)
    for identified_system in (system_id_tvera_ic, system_id_tveradc_ic)
]
# output_signal_id_no_observer = OutputSignal(test_signal, system_id_no_observer)
# output_signal_id_observer = OutputSignal(test_signal, system_id_observer, observer=True, reference_output_signal=output_signal_id)


# Plotting Output Signals
# plotSignals([[output_signal, output_signal_tvera_ic, output_signal_tveradc_ic], [subtract2Signals(output_signal, output_signal_tvera_ic), subtract2Signals(output_signal, output_signal_tveradc_ic)]], 2, percentage=0.9)


# True Corrected System
# corrected_system = correctSystemForEigenvaluesCheck(nominal_system, full_input_signal[0].number_steps - q, p)


# Identified Corrected System
# corrected_system_tvera = correctSystemForEigenvaluesCheck(system_id_tvera, full_input_signal[0].number_steps - q, p)
# corrected_system_tveradc = correctSystemForEigenvaluesCheck(system_id_tveradc, full_input_signal[0].number_steps - q, p)


# Plot Eigenvalues
# plotHistoryEigenValues2Systems([corrected_system, corrected_system_tvera], full_input_signal[0].number_steps - q, 3)
# plotHistoryEigenValues2Systems([corrected_system, corrected_system_tveradc], full_input_signal[0].number_steps - q, 3)





## Plotting

import matplotlib.pyplot as plt
from matplotlib import rc
# Plot palette, written as 0-255 RGB triples and scaled to the 0-1 floats used for plotting.
_RGB_BYTES = (
    (11, 36, 251),
    (27, 161, 252),
    (77, 254, 193),
    (224, 253, 63),
    (253, 127, 35),
    (221, 10, 22),
    (255, 0, 127),
    (127, 0, 255),
    (255, 0, 255),
    (145, 145, 145),
)
colors = [tuple(channel / 255 for channel in rgb) for rgb in _RGB_BYTES]
# Black stays a plain integer triple, exactly as it was written originally.
colors.append((0, 0, 0))


# Render all figure text through LaTeX.
rc('text', usetex=True)
# NOTE(review): the active family is sans-serif, yet only the *serif* font list is
# customised -- confirm whether "font.family": "serif" was intended.
rc('font', family='sans-serif', serif=['Computer Modern Roman'])

# The last `end` time units are left out of every trace
# (presumably the identified models are not valid over that final window -- confirm).
end = 10
samples_dropped = end * frequency
tspan_test = np.linspace(0, total_time - end, number_steps - samples_dropped)

# (signal, plot style) for each identified model; the true output is drawn with colors[0].
identified_traces = (
    (output_signal_tvera_ic, dict(color=colors[5], label='TVERA')),
    (output_signal_tveradc_ic, dict(color=colors[7], label='TVERA/DC')),
)
output_labels = (r'$y_1$', r'$y_2$')
error_labels = (r'Error $y_1$', r'Error $y_2$')

fig = plt.figure(1, figsize=(12, 6))

# Top row (subplots 1-2): true output against both identified outputs, one panel per output row.
for channel, axis_label in enumerate(output_labels):
    ax = plt.subplot(2, 2, channel + 1)
    ax.plot(tspan_test, output_signal.data[channel, :-samples_dropped], color=colors[0], label='True')
    for identified, trace_style in identified_traces:
        ax.plot(tspan_test, identified.data[channel, :-samples_dropped], **trace_style)
    plt.xlabel(r'Time')
    plt.ylabel(axis_label)
    ax.legend(loc='lower left')

# Bottom row (subplots 3-4): absolute output error of each identified model, log scale.
for channel, axis_label in enumerate(error_labels):
    ax = plt.subplot(2, 2, channel + 3)
    for identified, trace_style in identified_traces:
        absolute_error = np.abs(
            output_signal.data[channel, :-samples_dropped] - identified.data[channel, :-samples_dropped]
        )
        ax.semilogy(tspan_test, absolute_error, **trace_style)
    plt.xlabel(r'Time')
    plt.ylabel(axis_label)
    ax.legend(loc='lower left')

plt.tight_layout()
# plt.savefig('TVERADC_IC_1.eps', format='eps')
plt.show()


## RMSE
# Error metrics over the same window as the plots: the last `end` time units
# (end * frequency samples) are excluded. Each error signal is computed once and
# reused for both the RMSE and the mean absolute error.
samples_excluded = end * frequency
error_tvera_ic = subtract2Signals(output_signal, output_signal_tvera_ic).data[:, :-samples_excluded]
error_tveradc_ic = subtract2Signals(output_signal, output_signal_tveradc_ic).data[:, :-samples_excluded]

print('RMSE TVERA =', np.sqrt(np.mean(error_tvera_ic ** 2)))
# Raw string: '\D' is not a valid escape sequence; the printed text is unchanged.
print(r'RMSE TVERA\DC =', np.sqrt(np.mean(error_tveradc_ic ** 2)))
print('Absolute error TVERA =', np.mean(np.abs(error_tvera_ic)))
print('Absolute error TVERADC =', np.mean(np.abs(error_tveradc_ic)))