#!/usr/bin/env python

import warnings
import sys
import traceback
import os
from itertools import groupby
from fanc.commands import fancplot_command_parsers as fancplot_parsers
import logging
# Configure the root logger once at load time: records of level INFO and above
# are formatted as "<timestamp> <level> <message>".
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)

# Logger used for the progress messages emitted by this script.
logger = logging.getLogger(__name__)


class FancPlot(object):
    """
    Command-line driver of the ``fancplot`` script.

    Instantiating the class runs the whole program: ``sys.argv`` is split into
    global arguments and one argument list per ``--plot``/``-p`` section, an
    optional script file is merged in, one plot object is created per section,
    and a figure is drawn for every requested region. Figures are either
    written to file (when an output path was given) or shown interactively.
    """

    def __init__(self):
        """
        Parse the command line and produce all requested figures.

        Raises:
            SystemExit: after printing the version, or via the argument
                parser's ``error()`` for invalid input (no plot given,
                unknown plot type, plotting modules not importable).
            ValueError: if a region file does not provide any regions.
            RuntimeError: if no region to plot was provided at all.
        """
        #
        # define parsers for fancplot and subplots
        #

        # global parser
        self.global_parser = fancplot_parsers.fancplot_parser()
        # parser for the leading arguments (plot type, data) of each plot section
        self.subplot_parser = fancplot_parsers.type_parser()

        argv = sys.argv[1:]
        if not argv:
            # No arguments supplied, print help
            argv = ["-h"]
        # sort and parse arguments
        global_args, subplot_data = self._parse_argument_vector(argv)

        if global_args.print_version:
            import fanc
            print(fanc.__version__)
            # sys.exit() instead of the builtin exit(), which is only injected
            # by the "site" module and is not guaranteed to exist
            sys.exit()

        if global_args.script is not None:
            logger.info("Parsing script file")
            script_global_args, script_subplots = self._parse_script(global_args.script)
            d_global = vars(global_args)
            for key, value in vars(script_global_args).items():
                # global parameters can only be overridden if not explicitly specified
                if d_global[key] == self.global_parser.get_default(key):
                    d_global[key] = value
            # plots defined in the script come before plots from the command line
            subplot_data = script_subplots + subplot_data

        if not subplot_data:
            self.global_parser.error("Need to provide at least one plot")

        # configure output
        output_file = None
        if global_args.output is not None:
            logger.info("Using non-interactive backend")
            import matplotlib
            matplotlib.use("Agg")
            output_file = os.path.expanduser(global_args.output)

        try:
            import fanc.plotting as kplot
            from fanc.commands import fancplot_commands as commands
        except ModuleNotFoundError:
            self.global_parser.error("Your Python installation does not seem to support "
                                     "interactive plotting. Please use fancplot -o ... to plot "
                                     "to file or install a backend that supports interactive "
                                     "plots, such as Tkinter!")
            raise

        # NOTE(review): fanc_config is not used below; the import is kept in
        # case loading the configuration module has side effects - confirm.
        from fanc import config as fanc_config  # noqa: F401
        if global_args.pdf_text_as_font:
            # matplotlib must be imported here as well: the import further up
            # only runs when an output file was requested, which previously
            # made this branch fail with an UnboundLocalError otherwise
            import matplotlib
            logger.debug("Using text for PDF export instead of paths")
            matplotlib.rcParams['pdf.fonttype'] = 42

        plots = []
        all_subplot_args = []
        for plot_method_name, args in subplot_data:
            plot_method = None
            try:
                plot_method = getattr(commands, plot_method_name)
            except AttributeError:
                self.global_parser.error("Unrecognised plot type {}".format(plot_method_name))
            # create plot
            plot, plot_args = plot_method(args)
            plots.append(plot)

            if isinstance(plot, kplot.base_plotter.BasePlotter):
                # register aspect ratio
                if plot_args.aspect_ratio is not None:
                    plot.aspect = plot_args.aspect_ratio

                plot.fix_chromosome = plot_args.fix_chromosome

                # set plot title
                plot.title = plot_args.title
            all_subplot_args.append(plot_args)

        # parse regions
        regions = self._load_regions(global_args.regions)
        if len(regions) == 0:
            raise RuntimeError("Must provide at least one region for plotting!")

        logger.info("Found {} regions".format(len(regions)))

        # with several regions the output path is a folder, one file per region
        single_region = len(regions) == 1
        if not single_region and output_file is not None:
            from fanc.tools.general import mkdir
            mkdir(output_file)

        figure_width = global_args.width

        from fanc.tools.general import str_to_int

        # the window size does not depend on the region, convert it only once
        half_window = None
        if global_args.window_size is not None:
            import fanc
            half_window = int(str_to_int(global_args.window_size) / 2)

        for i, region in enumerate(regions):
            if half_window is not None:
                # fixed-size window around the region center, clipped at position 1
                plot_region = fanc.GenomicRegion(chromosome=region.chromosome,
                                                 start=max(1, region.center - half_window),
                                                 end=region.center + half_window)
            else:
                plot_region = region

            with kplot.GenomicFigure(plots, width=figure_width, invert_x=global_args.invert_x) as gf:
                try:
                    fig, axes = gf.plot(plot_region)

                    # converted once per figure rather than once per plot; kept
                    # inside the try block so that a conversion failure is
                    # still reported as a warning for this region
                    tick_locations = None
                    if global_args.tick_locations is not None:
                        tick_locations = [str_to_int(x) for x in global_args.tick_locations]

                    for plot_ix, (plot, plot_args) in enumerate(zip(plots, all_subplot_args)):
                        if tick_locations is not None:
                            axes[plot_ix].set_xticks(tick_locations)

                        plot.remove_genome_ticks(minor=not plot_args.show_minor_ticks,
                                                 major=not plot_args.show_major_ticks)
                        if not plot_args.show_tick_legend:
                            plot.remove_tick_legend()
                        if plot_args.hide_x:
                            plot.remove_genome_labels()
                except Exception as error:
                    # deliberate best effort: report the failing region and
                    # carry on with the remaining ones
                    tb = traceback.format_exc()
                    warnings.warn("There was an error with plot {}, region {}:{}-{} ({})".format(i, region.chromosome,
                                                                                                 region.start,
                                                                                                 region.end,
                                                                                                 tb))
                    logger.warning(error)
                    continue

                if output_file is not None:
                    if single_region:
                        fig.savefig(output_file)
                    else:
                        fig.savefig(self._figure_path(output_file, i, region, global_args.name))
                else:
                    kplot.plt.show()
                kplot.plt.close(fig)

    @staticmethod
    def _load_regions(region_specs):
        """
        Collect the regions to plot from files and/or region strings.

        :param region_specs: iterable of strings (or None); an entry that is
                             the path of an existing file is loaded with
                             ``genomic_regions.load``, any other entry is
                             converted with ``genomic_regions.as_region``
        :return: list of regions in the order they were specified
        :raises ValueError: if a loaded file has no ``regions`` attribute
        """
        regions = []
        if region_specs is None:
            return regions

        import genomic_regions as gr
        for spec in region_specs:
            path = os.path.expanduser(spec)
            if os.path.isfile(path):
                path_regions = gr.load(path)
                if not hasattr(path_regions, 'regions'):
                    raise ValueError("Provided file does not contain any regions ({})".format(path))
                regions += list(path_regions.regions)
            else:
                regions.append(gr.as_region(spec))
        return regions

    @staticmethod
    def _figure_path(output_folder, index, region, name):
        """
        Build the PDF path of one region's figure inside the output folder.

        :param output_folder: folder receiving the figures
        :param index: position of the region in the list of regions
        :param region: plotted region (chromosome, start, end are used)
        :param name: optional name inserted after the index; skipped if ''
        :return: path of the form <folder>/<index>[_<name>]_<chrom>_<start>-<end>.pdf
        """
        if name == '':
            return output_folder + '/{}_{}_{}-{}.pdf'.format(index, region.chromosome,
                                                             region.start, region.end)
        return output_folder + '/{}_{}_{}_{}-{}.pdf'.format(index, name,
                                                            region.chromosome,
                                                            region.start, region.end)

    def _parse_argument_vector(self, argv):
        """
        Split an argument vector at every ``--plot``/``-p`` and parse the parts.

        Everything before the first plot switch is parsed by the global
        parser. Each following section is parsed by the subplot parser; its
        unrecognised arguments are passed on untouched for the plot command.

        :param argv: list of argument strings (without the program name)
        :return: tuple (global_args, subplot_data) where subplot_data is a
                 list of (plot type, argument list) tuples, one per section
        """
        subplot_data = []
        global_args = None
        groups = groupby(argv, lambda x: x in ("--plot", "-p"))
        for group_counter, (is_switch, group) in enumerate(groups):
            if group_counter == 0:
                # a vector starting with a plot switch has no global arguments
                leading = [] if is_switch else list(group)
                global_args = self.global_parser.parse_args(leading)
            elif not is_switch:
                sub_args, plot_args = self.subplot_parser.parse_known_args(list(group))
                subplot_data.append((sub_args.type, sub_args.data + plot_args))

        if global_args is None:
            # empty argument vector (e.g. a script file without any settings):
            # fall back to the defaults instead of returning None
            global_args = self.global_parser.parse_args([])

        return global_args, subplot_data

    def _parse_script(self, script):
        """
        Read a fancplot script file and parse it like a command line.

        Lines are either ``key = value ...`` or a bare ``key``; blank lines
        and lines whose first non-blank character is ``#`` are ignored. Every
        key is prefixed with ``--``; the keys ``regions``, ``plot`` and
        ``data`` only contribute their values. Values are split with
        ``shlex.split``.

        :param script: path of the script file (``~`` is expanded)
        :return: tuple (global_args, subplot_data), see
                 :meth:`_parse_argument_vector`
        :raises SystemExit: via the global parser's ``error()`` if the path
                            is not a file or a line has more than one
                            equal sign
        """
        script = os.path.expanduser(script)

        if not os.path.isfile(script):
            # abort here; previously the message was only printed and the
            # failure surfaced later when opening the file
            self.global_parser.error('This is not a script file: {}'.format(script))

        import re
        import shlex
        re_equal_sign = re.compile(r"\s*=\s*")
        value_only = ('--regions', '--plot', '--data')
        parameters = []
        with open(script, 'r') as s:
            for line_number, line in enumerate(s, start=1):
                # strip both ends so that indented comments and flags are recognised
                line = line.strip()
                if not line or line.startswith("#"):
                    continue

                cmd_val = re_equal_sign.split(line)
                if len(cmd_val) > 2:
                    # abort instead of silently dropping the line
                    self.global_parser.error(
                        'Too many equal-signs on line {}'.format(line_number))

                if len(cmd_val) == 1:
                    parameters.append('--' + cmd_val[0])
                else:
                    cmd, val = cmd_val
                    cmd = '--' + cmd.strip()
                    if cmd not in value_only:
                        parameters.append(cmd)
                    parameters += shlex.split(val)

        return self._parse_argument_vector(parameters)


if __name__ == "__main__":
    # Constructing the object parses sys.argv and runs the requested plotting.
    FancPlot()
