#
# Copyright 2021 Lars Pastewka (U. Freiburg)
#           2020, 2022 Thomas Reichenbach (Fraunhofer IWM)
#
# matscipy - Materials science with Python at the atomic-scale
# https://github.com/libAtoms/matscipy
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
from matplotlib import pyplot as pp
from matscipy import pressurecoupling as pc
from scipy.ndimage import uniform_filter
from ase.units import fs
import sys
import numpy as np
from io import open

# Command line: [log file [averaging window in fs]]
fn = sys.argv[1] if len(sys.argv) > 1 else 'log.txt'
dt_mean = (float(sys.argv[2]) if len(sys.argv) > 2 else 100.0) * fs

# Read the sliding log from the chosen file.
with open(fn, 'r', encoding='utf-8') as logfile:
    log = pc.SlideLog(logfile)

# Spacing between the first two log entries, scaled by fs like dt_mean.
sample_gap = log.time[1] - log.time[0]
dt = sample_gap * fs


def cummean(a):
    """Return the cumulative mean of `a`.

    Element ``i`` of the result is the mean of the first ``i + 1`` samples.

    Parameters
    ----------
    a : array_like
        Samples to average. As with ``numpy.cumsum`` without an axis,
        multi-dimensional input is flattened.

    Returns
    -------
    numpy.ndarray
        Running mean, one entry per input sample.
    """
    running_sum = np.cumsum(a)
    # Take the sample count from the input itself rather than from the
    # global log, so the function works for any array length.
    return running_sum / np.arange(1, running_sum.size + 1)


def leftmwmean(a, time, timestep=None):
    """Return the trailing (left-sided) moving-window mean of `a`.

    Each output sample is the mean over a window that ends at the current
    sample and extends `time` into the past.

    Parameters
    ----------
    a : array_like
        Samples, taken to be equally spaced by `timestep`.
    time : float
        Length of the averaging window, in the same units as `timestep`.
    timestep : float, optional
        Spacing between consecutive samples. Defaults to the module-level
        ``dt``.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as `a`. The leading entries, whose
        window would reach before the first sample, are NaN.
    """
    if timestep is None:
        timestep = dt
    # A window shorter than half a sample spacing rounds to zero samples,
    # which uniform_filter rejects; fall back to a one-sample window.
    steps = max(1, int(np.around(time / timestep)))
    # uniform_filter keeps the input dtype, so integer input would truncate
    # the means and could not store the NaN padding below.
    values = np.asarray(a, dtype=float)
    # The largest allowed positive origin shifts the window fully to the
    # left, so it covers the current sample and the steps - 1 before it.
    ret = uniform_filter(values, steps, mode="constant",
                         cval=0.0, origin=(steps - 1) // 2)
    # These windows include zero padding and would be biased; mask them.
    ret[:steps - 1] = np.nan
    return ret


def plot(attrs, labels, dt_mean, ylabel):
    """Open a new figure showing the given log quantities over time.

    Every quantity is drawn three times: the raw values (dotted), the
    cumulative mean (dashed) and the trailing moving-window mean (solid).

    Parameters
    ----------
    attrs : sequence of str
        Attribute names looked up on the global ``log``.
    labels : sequence of str
        Legend prefixes, paired with `attrs`.
    dt_mean : float
        Window length passed to ``leftmwmean``; shown in the legend
        divided by ``fs``.
    ylabel : str
        Label of the y axis.
    """
    pp.figure()
    window_suffix = ' last %.0f fs mean' % (dt_mean / fs, )
    # (line format, transformation applied to the data, legend suffix)
    curve_styles = (
        (':', lambda values: values, ' current'),
        ('--', cummean, ' last $t$ mean'),
        ('-', lambda values: leftmwmean(values, dt_mean), window_suffix),
    )
    for position, (fmt, transform, suffix) in enumerate(curve_styles):
        if position:
            # Reset the property cycle so each group of curves starts
            # again from the first colour.
            pp.gca().set_prop_cycle(None)
        for name, text in zip(attrs, labels):
            pp.plot(log.time, transform(getattr(log, name)),
                    fmt, label=text + suffix)
    pp.xlabel(r'$t / \mathrm{fs}$')
    pp.ylabel(ylabel)
    pp.legend(loc='upper left')
    pp.tight_layout()


# One figure per entry: (log attributes, legend labels, y-axis label).
figure_specs = (
    (['T_thermostat'], ['thermostat dof'], r'$T / \mathrm{K}$'),
    (['P_top', 'P_bottom'], ['top', 'bottom'], r'$P / \mathrm{GPa}$'),
    (['tau_top', 'tau_bottom'], ['top', 'bottom'],
     r'$\tau / \mathrm{GPa}$'),
    (['h'], ['lid normal pos.'], r'$h / \mathrm{\AA}$'),
    (['v'], ['lid normal vel.'],
     r'$v_{\mathrm{normal}} / \mathrm{\AA}\mathrm{fs}^{-1}$'),
    (['a'], ['lid normal acc.'],
     r'$a_{\mathrm{normal}} / \mathrm{\AA}\mathrm{fs}^{-2}$'),
)
for attr_names, legend_names, axis_label in figure_specs:
    plot(attr_names, legend_names, dt_mean, axis_label)

# friction coefficient: ratio of averaged negated tau to averaged P, drawn
# for the cumulative mean (dashed) and the moving-window mean (solid)
pp.figure()
wall_data = (
    (-log.tau_top, log.P_top),
    (-log.tau_bottom, log.P_bottom),
)
cumulative_labels = ('top last $t$ mean', 'bottom last $t$ mean')
window_labels = ('top last %.0f fs mean', 'bottom last %.0f fs mean')
for (shear, pressure), text in zip(wall_data, cumulative_labels):
    pp.plot(log.time, cummean(shear) / cummean(pressure),
            '--', label=text)
# Reset the property cycle so the second pair of curves starts again from
# the first colour.
pp.gca().set_prop_cycle(None)
for (shear, pressure), template in zip(wall_data, window_labels):
    mu_window = leftmwmean(shear, dt_mean) / leftmwmean(pressure, dt_mean)
    pp.plot(log.time, mu_window, '-', label=template % (dt_mean / fs, ))
pp.xlabel(r'$t / \mathrm{fs}$')
pp.ylabel(r'$\mu$')
pp.legend(loc='upper left')
pp.tight_layout()

# Last value of the moving-window mean for each wall, averaged over both.
# NOTE(review): the top shear is negated and the bottom shear is not --
# presumably the two walls carry opposite-signed shear; confirm.
tau_window_top = leftmwmean(-log.tau_top, dt_mean)[-1]
tau_window_bottom = leftmwmean(log.tau_bottom, dt_mean)[-1]
tau_mean = 0.5 * (tau_window_top + tau_window_bottom)
report = ("Final tau last %.0f fs mean top/bottom: %.1f GPa"
          % (dt_mean / fs, tau_mean))
print(report)

# Display all figures created above.
pp.show()
