from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Union, Optional

import matplotlib.pyplot as plt
import numpy as np

from plotagain.iddict import IDDict
from plotagain.pyplotcall import PyplotCall
from plotagain.utils import write_pickle

# Directory containing this module. The default script template (script_template.txt) is looked up here, and it is
# resolved from __file__ rather than Path('.') so that the default works regardless of the current working directory.
here = Path(__file__).resolve().parent
# One line of the generated script per saved variable. The pickle path is relative, so the generated script
# presumably has to be run from inside the save directory; `load_pickle` is presumably provided by the script
# template — TODO confirm both against script_template.txt.
pickle_load_template = "{var_name} = load_pickle('{var_name}.pkl')"


class SavePlotContext:
    """
    Context manager which wraps matplotlib.pyplot. Usage:

    with SavePlotContext("/.../save-dir/", locals(), ...) as pla:
        pla.plot( ... )
        ...
        pla.title( ... )
        pla.show()

    Any calls to pla are passed to matplotlib.pyplot and all args and kwargs are stored. On exiting the 'with' block,
    all args/kwargs are saved to save-dir/ as pickle files and a script is autogenerated to recreate the produced plots.
    Using the locals() dict, the variable names of the objects passed in as arguments are inferred and re-used in the
    autogenerated script. Any objects which don't appear in the locals() dict are given the name 'unnamed_arg'
    post-fixed with an integer if multiple unnamed args are present.

    Attributes
    ----------
    save_dir
        The path to a directory into which the argument pickles and autogenerated script are saved. If the directory
        does not exist it is created, together with any missing parent directories. If the directory is not empty and
        self.overwrite is False, a FileExistsError is raised
    outer_locals_dict
        The locals() dict containing the variables in the local scope of the code which created the SavePlotContext
    outer_globals_dict
        The globals() dict containing the variables in the global scope of the code which created the SavePlotContext
    overwrite
        If False, the directory self.save_dir must be empty, otherwise an exception is raised
    regenerate_script_name
        The name of the autogenerated script
    script_template_path
        A path to the template used to autogenerate the script which recreates the plots created using this
        SavePlotContext. See script_template.txt for an example
    used_variables
        A mapping from the variable name to the objects passed into any matplotlib.pyplot call
    calls
        A list of PyplotCall objects, one for each call made to matplotlib.pyplot. E.g. 'plot', 'show', 'hist'
    """
    save_dir: Path
    outer_locals_dict: Dict[str, Any]
    outer_globals_dict: Dict[str, Any]
    overwrite: bool
    regenerate_script_name: str
    script_template_path: Path

    used_variables: Dict[str, Any]
    calls: List[PyplotCall]

    def __init__(
        self,
        save_dir: Union[Path, str],
        locals_dict: Dict[str, Any],
        overwrite: bool = False,
        globals_dict: Optional[Dict[str, Any]] = None,
        script_name: str = 'make_plot.py',
        script_template: Union[Path, str] = here / 'script_template.txt',
    ):
        """
        Create the context and prepare the save directory.

        Parameters
        ----------
        save_dir
            The path to a directory into which the argument pickles and autogenerated script are saved. If the directory
            does not exist it is created, together with any missing parent directories
        locals_dict
            The dict returned by 'locals()' called in the scope of the plotting code
        overwrite
            If False, the directory self.save_dir must be empty, otherwise an exception is raised. If True, files of
            the same name already in the directory are overwritten on save
        globals_dict
            The dict returned by 'globals()' called in the scope of the plotting code. Defaults to an empty dict
        script_name
            The name of the autogenerated script
        script_template
            A path to the template used to autogenerate the script which recreates the plots created using this
            SavePlotContext. Defaults to 'script_template.txt' in the module-level `here` directory. See
            script_template.txt for an example

        Raises
        ------
        FileExistsError
            If save_dir exists but is not a directory, or if it is not empty and overwrite is False
        """
        self.save_dir = Path(save_dir)
        self.outer_locals_dict = locals_dict
        # `is None` rather than `or`: an explicitly supplied dict is kept (by identity) even when it is empty.
        self.outer_globals_dict = globals_dict if globals_dict is not None else {}
        self.overwrite = overwrite
        self.regenerate_script_name = script_name
        self.script_template_path = Path(script_template)

        self.used_variables = {}
        self.calls = []

        # parents=True honours "created if it does not exist" for nested paths; exist_ok=True still raises
        # FileExistsError when save_dir exists but is a file, so that case fails here rather than at save time.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # any() stops at the first entry instead of listing the whole directory.
        if not self.overwrite and any(self.save_dir.iterdir()):
            raise FileExistsError(f"Save directory '{self.save_dir.absolute()}' is not empty and overwrite is False")

    def __getattr__(self, fn_name: str) -> Any:
        """
        Intended use is to wrap the matplotlib.pyplot module and return a dummy wrapper which stores the arguments of
        the call and return the return value from matplotlib.pyplot . If matplotlib.pyplot.fn_name is not callable, the
        bare attribute is returned

        Only invoked for names not found on the instance or class, so the methods and attributes defined here are
        never forwarded to matplotlib.pyplot.

        Parameters
        ----------
        fn_name
            The name of the called matplotlib.pyplot function. E.g. 'plot', 'show', 'hist'

        Returns
        -------
        Any
            A wrapper function which stores the arguments of the call and returns the return value from
            matplotlib.pyplot . If matplotlib.pyplot.fn_name is not callable, the
            bare attribute is returned

        Raises
        ------
        AttributeError
            If matplotlib.pyplot has no attribute called fn_name
        """
        plt_fn = getattr(plt, fn_name)
        if not callable(plt_fn):
            return plt_fn

        @wraps(plt_fn)
        def wrapper(*args, **kwargs):
            """
            Wrapper function which simply calls the intended matplotlib.pyplot function, stores the arguments, and
            returns the result
            """
            # Call pyplot first, so that a call which raises is never recorded for the regenerate script.
            ret = plt_fn(*args, **kwargs)
            call = PyplotCall(fn_name, args, kwargs)
            self.calls.append(call)
            # Rebuilt on every call in case the caller's scope dicts have gained names since the previous call.
            # NOTE(review): globals are merged after locals; if IDDict.update overwrites existing entries, a global
            # name would win over a local one bound to the same object — confirm against IDDict.
            calling_scope = IDDict(self.outer_locals_dict)
            calling_scope.update(self.outer_globals_dict)
            call.find_or_name_args(calling_scope, self.used_variables)
            return ret

        return wrapper

    def save_used_variables(self) -> None:
        """
        Saves a pickle file '<var_name>.pkl' in self.save_dir for each object passed in as an argument to a
        matplotlib.pyplot call
        """
        for var_name, value in self.used_variables.items():
            write_pickle(self.save_dir / f'{var_name}.pkl', value)

    def save_regenerate_script(self) -> None:
        """
        Creates and saves a python script which reproduces all the plots produced using this SavePlotContext instance

        The template is filled with str.format using the fields 'load_variables_code' and 'do_plot_code', so any
        literal braces in the template must be doubled ('{{' / '}}').
        """
        # Explicit UTF-8: the output is Python source (UTF-8 by default, PEP 3120); the locale encoding could
        # write non-ASCII labels/titles in a form Python cannot read back.
        script_template = self.script_template_path.read_text(encoding='utf-8')
        load_variables_code = '\n'.join(
            pickle_load_template.format(var_name=var_name) for var_name in self.used_variables
        )
        do_plot_code = '\n'.join(call.render_recall() for call in self.calls)

        script_code = script_template.format(
            load_variables_code=load_variables_code,
            do_plot_code=do_plot_code,
        )
        (self.save_dir / self.regenerate_script_name).write_text(script_code, encoding='utf-8')

    def save(self) -> None:
        """
        Saves all matplotlib.pyplot arguments and creates the script needed to reproduce the plots produced using this
        SavePlotContext instance
        """
        self.save_used_variables()
        self.save_regenerate_script()

    def __enter__(self) -> 'SavePlotContext':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """
        Save everything recorded so far. This also runs when the with-block raised, so partial work is kept; any
        such exception is not suppressed (returns False)
        """
        self.save()
        return False


if __name__ == '__main__':
    # Demo: plot sin and cos through SavePlotContext, then persist the data and a regeneration script.
    x_data = np.linspace(0, 2 * np.pi, 1000)
    y_data = np.sin(x_data)
    # At module level locals() is the module namespace, so x_data / y_data are found by name and reused in the
    # generated script. np.cos(x_data) is a temporary with no name binding, so it gets a generated 'unnamed_arg' name.
    # overwrite=True lets the demo be rerun into the same directory.
    with SavePlotContext("./data-save-dir", locals(), overwrite=True) as spl:
        spl.plot(x_data, y_data, c='k', label='sin')
        spl.plot(x_data, np.cos(x_data), c='b', label='cos')
        spl.xlabel('xaxis')
        spl.ylabel('yaxis')
        spl.title('Title')
        spl.legend()

        # savefig before show: depending on the backend, show() may release the figure, leaving nothing to save.
        spl.savefig('plot.pdf')
        spl.show()
    # Leaving the with-block writes one pickle per named argument plus the default script 'make_plot.py'
    # into ./data-save-dir.
