#!/usr/bin/env python
""" I hate this script and it should be rewritten, but
this should read data from a run3 convergence test and plot it.
"""
from matador.scrapers.castep_scrapers import castep2dict
from matador.plotting.plotting import plotting_function
from matador.utils.chem_utils import get_formula_from_stoich
from os import walk, chdir
from os.path import isdir
from decimal import Decimal, ROUND_UP
from collections import defaultdict
from traceback import print_exc
import numpy as np


def round(n, prec):
    """ Round `n` away from zero to two decimal places, unless `prec` is None.

    Note that this shadows the builtin `round` inside this module.

    Parameters:
        n: value to round; it is converted via `str(n)` before being
            handed to `Decimal`.
        prec: only tested against None. When it is None, `n` is returned
            untouched; any other value enables rounding, but the value
            itself does not change the number of decimal places.

    Returns:
        `n` unchanged if `prec` is None, otherwise a float.

    """
    if prec is None:
        return n

    # quantize() only takes the exponent from its argument, so '0.05'
    # means "two decimal places" rather than "nearest multiple of 0.05".
    template = Decimal('0.05')
    as_decimal = Decimal(str(n))
    # ROUND_UP rounds away from zero whenever discarded digits are non-zero.
    return float(as_decimal.quantize(template, rounding=ROUND_UP))


def get_files(path, only=None):
    """ Find and scrape all CASTEP files below the given directory.

    The working directory is changed to `path` while scraping and is
    always restored to where it started afterwards, even if `path` is
    nested (e.g. 'a/b') or scraping raises.

    Parameters:
        path (str): directory to search recursively (symlinks are followed).

    Keyword arguments:
        only (str): if given, only files whose name contains this
            substring are read.

    Returns:
        defaultdict: maps a structure label to the list of scraped
            documents for that structure. The label is the file name
            without '.castep' and without its final '_'-separated token,
            with the remaining tokens concatenated.

    """
    from os import getcwd

    structure_files = defaultdict(list)
    start_dir = getcwd()
    chdir(path)
    try:
        for root, _, files in walk('.', topdown=True, followlinks=True):
            for file in files:
                if not file.endswith('.castep'):
                    continue
                if only is not None and only not in file:
                    continue
                castep_dict, success = castep2dict(root + '/' + file, db=False)
                if not success:
                    print('Failure to read castep')
                    continue
                source = castep_dict['source'][0].split('/')[-1]
                source = source.replace('.castep', '')
                # drop the trailing token (presumably the convergence
                # parameter value -- TODO confirm against the run3 naming)
                source = ''.join(source.split('_')[:-1])
                structure_files[source].append(castep_dict)
    finally:
        # chdir('..') would only be correct for a single-level, non-symlinked path
        chdir(start_dir)
    return structure_files


def _convergence_value(doc, conv_field, rounding=None):
    """ Return the (optionally rounded) convergence parameter of one calculation. """
    if conv_field == 'kpoints_mp_spacing':
        # the spacing is read from the file name rather than the document:
        # last '_'-separated token of the basename, up to the first 'A'
        value = float(doc['source'][0].split('/')[-1].split('_')[-1].split('A')[0])
    else:
        value = doc[conv_field]
    return round(value, rounding)


def get_data(structure_files, conv_field='cut_off_energy'):
    """ Collect energies and forces from cutoff energy/kpt spacing convergence calculations.

    If every structure has the same formula, raw energies per atom are
    used ("single stoichiometry mode"). Otherwise single-species
    calculations are treated as chemical potentials and subtracted from
    the energies of the remaining structures at the same value of the
    convergence parameter.

    Parameters:
        structure_files (dict): maps a structure label to the list of
            scraped documents for that structure, as returned by `get_files`.

    Keyword arguments:
        conv_field (str): field for convergence parameter. For
            'kpoints_mp_spacing' the value is parsed from the file name
            instead of being read from the document.

    Returns:
        dict: with keys
            'form': label -> array with rows [value, energy per atom],
                sorted by value (descending for 'kpoints_mp_spacing'),
            'forces': label -> list of [value, array of force magnitudes],
                sorted like 'form',
            'stoich_list': label -> TeX formula,
            and, only when chemical potentials are in use,
            'chempots': label -> array with rows [value, energy per atom],
                sorted by ascending value,
            'chempot_list': label -> first species of the stoichiometry.

    Raises:
        SystemExit: if a chemical potential calculation cannot be processed.

    """
    formulae = {get_formula_from_stoich(structure_files[structure][0]['stoichiometry'])
                for structure in structure_files}
    single = len(formulae) == 1
    chempot_mode = not single
    if single:
        print('Working in single stoichiometry mode..')
    else:
        print('Searching for chemical potentials')

    # rounding is currently disabled for every convergence field
    rounding = None

    chempots_dict = defaultdict(dict)
    chempots = defaultdict(list)
    chempot_list = dict()

    if chempot_mode:
        for key in structure_files:
            for doc in structure_files[key]:
                # bound before the try so the error message below can always be formatted
                rounded_field = None
                try:
                    if len(doc['stoichiometry']) != 1:
                        continue
                    doc['formation_energy_per_atom'] = 0
                    rounded_field = _convergence_value(doc, conv_field, rounding)
                    if 'total_energy_per_atom' in doc:
                        energy = doc['total_energy_per_atom']
                    else:
                        energy = doc['enthalpy_per_atom']
                    chempots_dict[str(rounded_field)][doc['atom_types'][0]] = energy
                    chempots[key].append([rounded_field, energy])
                    chempot_list[key] = doc['stoichiometry'][0][0]
                except Exception:
                    print('Error with {} and {} = {}'.format(key, conv_field, rounded_field))
                    print_exc()
                    raise SystemExit(1)

    form = defaultdict(list)
    forces = defaultdict(list)
    stoich_list = dict()

    if chempot_mode:
        elems = set()
        for value in chempots_dict:
            elems.update(chempots_dict[value])

        for value in chempots_dict:
            for elem in elems:
                if elem not in chempots_dict[value]:
                    print('WARNING: {} chemical potential missing at {} = {} eV, skipping this value.'.format(elem, conv_field, value))

    for key in structure_files:
        for doc in structure_files[key]:
            # single-species calculations were consumed as chemical potentials above
            if chempot_mode and len(doc['stoichiometry']) == 1:
                continue
            try:
                if 'total_energy_per_atom' in doc:
                    energy = doc['total_energy_per_atom']
                else:
                    energy = doc['enthalpy_per_atom']
                rounded_field = _convergence_value(doc, conv_field, rounding)
                if chempot_mode:
                    available = chempots_dict.get(str(rounded_field), {})
                    missing = sorted(set(doc['atom_types']) - set(available))
                    if missing:
                        # explicitly skip, rather than relying on a KeyError traceback
                        print('Skipping {} at {} = {}: no chemical potential for {}'
                              .format(key, conv_field, rounded_field, ', '.join(missing)))
                        continue
                    num_atoms = len(doc['atom_types'])
                    for atom in doc['atom_types']:
                        energy -= available[atom] / num_atoms
                doc['formation_energy_per_atom'] = energy
                form[key].append([rounded_field, energy])
                stoich_list[key] = get_formula_from_stoich(doc['stoichiometry'], tex=True)
                if 'forces' in doc:
                    # Euclidean norm over the last axis
                    # (presumably one Cartesian force vector per atom -- TODO confirm)
                    magnitudes = np.sqrt(np.sum(np.asarray(doc['forces'])**2, axis=-1))
                    forces[key].append([rounded_field, magnitudes])
                    assert len(magnitudes) == len(doc['atom_types'])
            except Exception:
                # best effort: report the problem and carry on with the other calculations
                print_exc()

    reverse = conv_field == 'kpoints_mp_spacing'
    for key in form:
        form[key] = np.asarray(sorted(form[key], key=lambda x: x[0], reverse=reverse))
        forces[key] = sorted(forces[key], key=lambda x: x[0], reverse=reverse)

    if chempot_mode:
        for key in chempots:
            # sort the chemical potentials themselves; looking up form[key] here
            # would discard them and add empty entries to `form`
            chempots[key] = np.asarray(sorted(chempots[key], key=lambda x: x[0]))

    data = dict()
    data['form'] = form
    data['forces'] = forces
    data['stoich_list'] = stoich_list
    if chempot_mode:
        data['chempots'] = chempots
        data['chempot_list'] = chempot_list

    return data


def _relative_energies(energies, reference, log):
    """ Return 1000*|energies - reference| (labelled as meV/atom on the axes), as log10 if requested. """
    relative = 1000 * np.abs(energies - reference)
    return np.log10(relative) if log else relative


def _plot_force_convergence(ax, force_data, colour, to_x):
    """ Plot force magnitudes relative to the last entry of `force_data`.

    `force_data` is a list of [value, magnitudes] pairs; `to_x` maps each
    value onto the x axis. All differences are drawn as a faint line at
    each x, and their mean as a marker.
    """
    if len(force_data) == 0:
        return
    reference = np.asarray(force_data[-1][1])
    for value, magnitudes in force_data:
        relative_forces = np.abs(np.asarray(magnitudes) - reference)
        x = to_x(value)
        ax.plot(len(magnitudes)*[x], relative_forces, alpha=0.2, c=colour)
        ax.scatter(x, np.mean(relative_forces), alpha=0.5, c=colour)


def _shade_convergence_envelope(ax, xpoints, ypoints, energy_target=None):
    """ Shade the area under the largest y found at each x across all structures.

    With an `energy_target`, the part at or below the target is shaded
    green and the part above it red; otherwise everything is grey.
    """
    max_y = dict()
    for xs, ys in zip(xpoints, ypoints):
        for x, y in zip(xs, ys):
            if max_y.get(x) is None or y > max_y[x]:
                max_y[x] = y

    max_x_vals = np.asarray(sorted(max_y.keys()))
    max_y_vals = np.asarray([max_y[x] for x in max_x_vals])
    if energy_target:
        below = np.where(max_y_vals <= energy_target)
        above = np.where(max_y_vals > energy_target)
        ax.fill_between(max_x_vals[below], max_y_vals[below], alpha=0.2, color='green')
        ax.fill_between(max_x_vals[above], max_y_vals[above], alpha=0.2, color='red')
    else:
        ax.fill_between(max_x_vals, max_y_vals, alpha=0.2, color='grey')


@plotting_function
def plot_both(plot_cutoff=False, plot_kpt=False, cutoff_data=None, kpt_data=None, log=True, show_chempots=False, **kwargs):
    """ Plot convergence of either cutoff/kpts or both.

    The figure is saved to 'conv.png' and then shown.

    Keyword arguments:
        plot_cutoff (bool): plot the cutoff energy convergence.
        plot_kpt (bool): plot the k-point spacing convergence.
        cutoff_data (dict): output of `get_data` for the cutoff runs.
        kpt_data (dict): output of `get_data` for the k-point runs.
        log (bool): plot log10 of the energy differences.
        show_chempots (bool): also plot the chemical potential convergence,
            if the data contains a 'chempots' entry.
        energy_range (float): upper y limit of the energy axes.
        force_range (float): upper y limit of the force axes.
        energy_target (float): energy convergence target; controls the
            shading and is marked on the k-point plot.
        force_target (float): force convergence target to mark.

    Force panels are only drawn when both plots are requested and both
    data sets contain a 'forces' entry.

    Raises:
        ValueError: if neither `plot_cutoff` nor `plot_kpt` is set.

    """
    import matplotlib.pyplot as plt

    cutoff_data = {} if cutoff_data is None else cutoff_data
    kpt_data = {} if kpt_data is None else kpt_data

    # options are optional when called as a library function
    energy_range = kwargs.get('energy_range')
    force_range = kwargs.get('force_range')
    energy_target = kwargs.get('energy_target')
    force_target = kwargs.get('force_target')

    num_structures = max(len(cutoff_data.get('stoich_list', [1])), len(kpt_data.get('stoich_list', [1])))
    print('Number of structures: {}'.format(num_structures))

    if not plot_cutoff and not plot_kpt:
        raise ValueError('Nothing to plot: neither plot_cutoff nor plot_kpt was requested.')

    # must be defined for every layout, not just the 2x2 one
    forces = False
    if plot_cutoff and plot_kpt and 'forces' in cutoff_data and 'forces' in kpt_data:
        fig = plt.figure()
        forces = True
        ax_cutoff = fig.add_subplot(221)
        ax_cutoff_forces = fig.add_subplot(223, sharex=ax_cutoff)
        ax_kpt = fig.add_subplot(222)
        ax_kpt_forces = fig.add_subplot(224, sharex=ax_kpt)
    elif plot_cutoff and plot_kpt:
        fig = plt.figure()
        ax_cutoff = fig.add_subplot(211)
        ax_kpt = fig.add_subplot(212)
    elif plot_cutoff:
        fig, ax_cutoff = plt.subplots(1, 1)
    else:
        fig, ax_kpt = plt.subplots(1, 1)

    if plot_cutoff:
        xpoints = []
        ypoints = []
        lines = []
        labels = []
        cutoff_form = cutoff_data['form']
        for key in cutoff_form:
            try:
                # energies are relative to the last row; the x axis is -1/cutoff
                x = -1/cutoff_form[key][:, 0]
                y = _relative_energies(cutoff_form[key][:, 1], cutoff_form[key][-1, 1], log)
            except Exception:
                print('Issue with {}: {}'.format(key, cutoff_form[key]))
                continue

            line, = ax_cutoff.plot(x, y,
                                   'o', markersize=5, alpha=1, label=key, lw=0, zorder=1000)
            ax_cutoff.plot(x, y,
                           '-', alpha=0.2, label=key, lw=1, zorder=1000, c=line.get_color())

            xpoints.append(x)
            ypoints.append(y)
            lines.append(line)
            labels.append(key)

            if forces:
                try:
                    _plot_force_convergence(ax_cutoff_forces, cutoff_data['forces'][key],
                                            line.get_color(), lambda value: -1/value)
                except Exception:
                    print('Issue with {}: {}'.format(key, cutoff_form[key]))

        if show_chempots and 'chempots' in cutoff_data:
            # plotted once, with its own loop variable so it cannot clobber `key` above
            cutoff_chempots = cutoff_data['chempots']
            for chempot_key in cutoff_chempots:
                try:
                    chempot = cutoff_chempots[chempot_key]
                    ax_cutoff.plot(-1/chempot[:, 0], _relative_energies(chempot[:, 1], chempot[-1, 1], log),
                                   'o-', markersize=5, alpha=1, label=chempot_key, lw=1, c='grey')
                except Exception:
                    print('Issue with {}: {}'.format(chempot_key, cutoff_chempots[chempot_key]))

        if forces:
            ax_cutoff_forces.set_ylabel('Force eV/A')
            if force_range:
                ax_cutoff_forces.set_ylim(0, force_range)
            if force_target:
                ax_cutoff_forces.axhline(force_target, lw=0.5, c='r', ls='--')

        if log:
            ax_cutoff.set_ylabel('log(Relative energy difference (meV/atom))')
        else:
            ax_cutoff.set_ylabel('Relative energy difference (meV/atom)')
        ax_cutoff.set_xlabel('1 / plane wave cutoff (eV)')
        try:
            min_ = 1e10
            max_ = 0
            for key in cutoff_form:
                try:
                    min_ = min(min_, np.min(cutoff_form[key][:, 0]))
                    max_ = max(max_, np.max(cutoff_form[key][:, 0]))
                except Exception:
                    # unusable entries were already reported in the loop above
                    continue
            # ticks every 100 from the smallest cutoff; only every other tick that
            # does not exceed the largest cutoff is labelled
            xlabels = np.arange(min_, 2*max_, step=100)
            xlabels_str = ['' if (i % 2 == 1 or val > max_) else '1/{:3.0f}'.format(val)
                           for i, val in enumerate(xlabels)]
            ax_cutoff.set_xticks(-1/xlabels)
            ax_cutoff.set_xticklabels(xlabels_str)
            ax_cutoff.set_xlim(-1.1/min_, -0.9/max_)
            if energy_range:
                ax_cutoff.set_ylim(0, energy_range)
        except Exception:
            print_exc()
            print('No cutoff.conv file found, axis labels may be ugly...')

        _shade_convergence_envelope(ax_cutoff, xpoints, ypoints, energy_target)

    if plot_kpt:
        # NOTE: this resets the legend entries, so with both plots the figure
        # legend is built from the k-point lines only
        lines = []
        labels = []
        xpoints = []
        ypoints = []
        kpt_form = kpt_data['form']
        for key in kpt_form:
            try:
                x = kpt_form[key][:, 0]
                y = _relative_energies(kpt_form[key][:, 1], kpt_form[key][-1, 1], log)
            except Exception:
                print('Issue with {}: {}'.format(key, kpt_form[key]))
                continue

            xpoints.append(x)
            ypoints.append(y)

            line, = ax_kpt.plot(x, y,
                                'o', markersize=5, alpha=1, label=key, lw=0, zorder=1000)
            ax_kpt.plot(x, y,
                        '-', alpha=0.2, label=key, c=line.get_color(), lw=1, zorder=1000)
            lines.append(line)
            labels.append(key)

            if forces:
                try:
                    _plot_force_convergence(ax_kpt_forces, kpt_data['forces'][key],
                                            line.get_color(), lambda value: value)
                except Exception:
                    print('Issue with {}: {}'.format(key, kpt_form[key]))

        if show_chempots and 'chempots' in kpt_data:
            kpt_chempots = kpt_data['chempots']
            for chempot_key in kpt_chempots:
                try:
                    chempot = kpt_chempots[chempot_key]
                    # the first row is used as the reference here
                    ax_kpt.plot(chempot[:, 0], _relative_energies(chempot[:, 1], chempot[0, 1], log),
                                'o-', markersize=5, alpha=1, label=chempot_key, lw=1, c='grey')
                except Exception:
                    print('Issue with {}: {}'.format(chempot_key, kpt_chempots[chempot_key]))

        if energy_target:
            ax_kpt.axhline(energy_target, lw=0.5, c='r', ls='--')
        _shade_convergence_envelope(ax_kpt, xpoints, ypoints, energy_target)

        if forces:
            ax_kpt_forces.set_ylabel('Force eV/A')
            if force_range:
                ax_kpt_forces.set_ylim(0, force_range)
            if force_target:
                ax_kpt_forces.axhline(force_target, lw=0.5, c='r', ls='--')

        if plot_cutoff:
            # mirror the k-point panel when it sits next to the cutoff panel
            ax_kpt.set_xlim(ax_kpt.get_xlim()[1], ax_kpt.get_xlim()[0])
            ax_kpt.yaxis.tick_right()
            ax_kpt.yaxis.set_label_position('right')
            if forces:
                ax_kpt_forces.yaxis.tick_right()
                ax_kpt_forces.yaxis.set_label_position('right')

        # label and limit the energy axis whether or not the cutoff plot is present
        if log:
            ax_kpt.set_ylabel('log(Relative energy difference (meV/atom))')
        else:
            ax_kpt.set_ylabel('Relative energy difference (meV/atom)')
        if energy_range:
            ax_kpt.set_ylim(0, energy_range)

    plt.figlegend(lines, labels, loc='upper center', fontsize=10, ncol=2, frameon=True, fancybox=True, shadow=True)
    plt.savefig('conv.png', bbox_inches='tight')
    plt.show()


if __name__ == '__main__':
    # Command-line entry point: parse whichever of the 'completed_cutoff' and
    # 'completed_kpts' folders exist in the current directory and plot them.
    import argparse
    parser = argparse.ArgumentParser(prog='plot_convergence')
    parser.add_argument('--log', action='store_true', help='Plot log energies')
    parser.add_argument('--show_chempots', action='store_true', help='Include chempots in plot')
    parser.add_argument('--only', type=str, help='Show only convergence of this seedname')
    parser.add_argument('--energy-range', '--energy_range', type=float, help='Plot up to this energy value')
    parser.add_argument('--force-range', '--force_range', type=float, help='Plot up to this force value')
    parser.add_argument('--energy-target', '--energy_target', type=float, help='Mark this convergence target on the plot')
    parser.add_argument('--force-target', '--force_target', type=float, help='Mark this convergence target on the plot')
    args = parser.parse_args()
    kwargs = vars(args)
    if kwargs.get('only') is not None:
        # keep only the part of the seedname before the first '.'
        kwargs['only'] = kwargs['only'].split('.')[0]

    cutoff = isdir('completed_cutoff')
    kpts = isdir('completed_kpts')
    if not cutoff:
        print('Did not find completed_cutoff folder, skipping cutoffs...')
    if not kpts:
        print('Did not find completed_kpts folder, skipping kpts...')
    if not cutoff and not kpts:
        # raised outside the try below so that it is a clean exit, not a reported crash
        raise SystemExit('Could not find any completed_$x folders!')

    cutoff_data = {}
    kpt_data = {}
    try:
        if cutoff:
            print('Parsing cutoffs...')
            cutoff_structure_files = get_files('completed_cutoff', only=kwargs.get('only'))
            cutoff_data = get_data(cutoff_structure_files, conv_field='cut_off_energy')
        if kpts:
            print('Parsing kpts...')
            kpt_structure_files = get_files('completed_kpts', only=kwargs.get('only'))
            kpt_data = get_data(kpt_structure_files, conv_field='kpoints_mp_spacing')
        plot_both(plot_cutoff=cutoff, plot_kpt=kpts, cutoff_data=cutoff_data, kpt_data=kpt_data, **kwargs)
    except Exception:
        print_exc()
        print('This script is rubbish, please contact me388@cam.ac.uk and tell him to fix it.')
