#!/usr/bin/env python
""" This script computes, plots and optionally saves PXRD patterns from cif input files. """
import argparse
import sys
import time
from traceback import print_exc

import matplotlib.pyplot as plt

from matador import script_epilog
from matador.scrapers import cif2dict
from matador.plotting.pxrd_plotting import plot_pxrd
from matador.crystal import Crystal


def compute_pxrd(**kwargs):
    """ Load CIF files, calculate their PXRD patterns, then plot and/or save them.

    Keyword arguments (normally the parsed command-line options):
        seeds (str or list of str): path(s) of the CIF files to load.
        wavelength (float): passed as ``wavelength`` to ``Crystal.calculate_pxrd``.
        theta_m (float): passed as ``theta_m`` to ``Crystal.calculate_pxrd``.
        broadening_width (float): passed as ``lorentzian_width`` to
            ``Crystal.calculate_pxrd``.
        plot (bool): show the patterns in a matplotlib window.
        savefig (str): if set, save the plot to this filename instead of
            showing it (takes precedence over ``plot``).
        save_patterns (bool): write ``<root_source>_pxrd_pattern.dat`` for
            each structure.
        save_peaks (bool): write ``<root_source>_pxrd_peaks.dat`` for each
            structure.

    Raises:
        SystemExit: with status 1 if any file fails to parse or cannot be
            turned into a ``Crystal``. All files are loaded before any PXRD
            is computed, so nothing is calculated in that case.

    """
    seeds = kwargs.get('seeds')
    if isinstance(seeds, str):
        seeds = [seeds]
    broadening_width = kwargs.get('broadening_width')
    wavelength = kwargs.get('wavelength')
    theta_m = kwargs.get('theta_m')

    # Load every structure up front, so a bad input aborts the run before
    # any (potentially slow) PXRD calculation starts.
    strucs = []
    for _file in seeds:
        start = time.time()
        struc, success = cif2dict(_file, fail_fast=True)
        elapsed = time.time() - start
        print(f"Loaded CIF {_file} in {elapsed:.3f} s")
        if not success:
            # On failure, print whatever cif2dict returned in place of the
            # structure (presumably a description of the parse error).
            print(f"Error parsing {_file}:", file=sys.stderr)
            print(struc, file=sys.stderr)
            # Non-zero status so that shells/pipelines see the failure.
            sys.exit(1)
        try:
            crystal = Crystal(struc)
        except Exception:
            # Top-level CLI boundary: report the full traceback and stop.
            print(f"Error loading {_file} as Crystal object:", file=sys.stderr)
            print_exc()
            sys.exit(1)
        strucs.append(crystal)

    for doc in strucs:
        start = time.time()
        doc.calculate_pxrd(
            wavelength=wavelength,
            theta_m=theta_m,
            lorentzian_width=broadening_width,
            progress=doc.num_atoms > 100  # show progress bar if cell is large
        )
        elapsed = time.time() - start
        print(f"Computed PXRD in {elapsed:.3f} s")

    if kwargs.get('plot') or kwargs.get('savefig'):
        plot_pxrd([doc.pxrd for doc in strucs])
        # Saving to file takes precedence over showing the interactive window.
        if kwargs.get('savefig'):
            plt.savefig(kwargs.get('savefig'))
        else:
            plt.show()

    if kwargs.get('save_patterns'):
        for doc in strucs:
            doc.pxrd.save_pattern(doc.root_source + '_pxrd_pattern.dat')

    if kwargs.get('save_peaks'):
        for doc in strucs:
            doc.pxrd.save_peaks(doc.root_source + '_pxrd_peaks.dat')


if __name__ == '__main__':
    # Command-line entry point: parse options and forward them all, as
    # keyword arguments, to compute_pxrd.
    parser = argparse.ArgumentParser(
        description="Compute, plot and export PXRD patterns from CIF file inputs.",
        epilog=script_epilog
    )
    parser.add_argument('-l', '--wavelength', type=float, default=1.5406,
                        help='X-ray wavelength used for the PXRD calculation (default: %(default)s)')
    parser.add_argument('-bw', '--broadening_width', type=float, default=0.03,
                        help='width of the Lorentzian broadening of the pattern (default: %(default)s)')
    parser.add_argument('-tm', '--theta_m', type=float, default=0.0,
                        help='theta_m parameter of the PXRD calculation (default: %(default)s)')
    parser.add_argument('--plot', action='store_true',
                        help='show a plot of the PXRD patterns (ignored if --savefig is given)')
    parser.add_argument('--savefig', type=str, help='save a plot to this file, e.g. "pxrd.pdf"')
    parser.add_argument('--save_patterns', action='store_true',
                        help='save a .dat file with the xy pattern for each structure')
    parser.add_argument('--save_peaks', action='store_true',
                        help='save a .dat file per structure with a list of peaks')
    parser.add_argument('seeds', nargs='+', type=str, help='list of structures to compute')
    parsed_kwargs = vars(parser.parse_args())
    compute_pxrd(**parsed_kwargs)
    print('Done!')
