"""
DAG module

A DAG is a collection of tasks that makes sure they are executed in
the right order
"""
import traceback
from copy import copy, deepcopy
from pathlib import Path
import warnings
import logging
import collections
import tempfile

try:
    import importlib.resources as importlib_resources
except ImportError:
    # backported
    import importlib_resources


import mistune
import pygments
import networkx as nx
from tqdm.auto import tqdm
from jinja2 import Template

from ploomber.Table import Table, TaskReport, BuildReport
from ploomber.products import MetaProduct
from ploomber.util import (image_bytes2html, isiterable, path2fig, requires,
                           markup)
from ploomber.CodeDiffer import CodeDiffer
from ploomber import resources
from ploomber import executors
from ploomber.constants import TaskStatus, DAGStatus
from ploomber.exceptions import DAGBuildError, DAGRenderError
from ploomber.ExceptionCollector import ExceptionCollector
from ploomber.util.util import callback_check
from ploomber.dag.DAGConfiguration import DAGConfiguration


class DAG(collections.abc.Mapping):
    """A collection of tasks with dependencies

    Parameters
    ----------
    name : str, optional
        A name to identify this DAG, defaults to 'No name'
    clients : dict, optional
        A dictionary with classes as keys and clients as values, can be
        later modified using dag.clients[SomeClass] = client
    differ : CodeDiffer, optional
        An object to determine whether two pieces of code are the same and
        to output a diff, defaults to CodeDiffer() (default parameters)
    executor : str or ploomber.executors instance, optional
        The executor to use (ploomber.executors.Serial and
        ploomber.executors.Parallel), if a string is passed ('serial'
        or 'parallel') the corresponding executor is initialized with default
        parameters. Defaults to 'serial'

    Raises
    ------
    TypeError
        If executor is not 'serial', 'parallel' or an Executor instance

    Notes
    -----
    Iterating over a DAG yields task names in topological order, and
    dag[name] returns the corresponding Task object
    """
    def __init__(self, name=None, clients=None, differ=None,
                 executor='serial'):
        # nodes are task names, the Task object lives in the 'task' attribute
        self._G = nx.DiGraph()

        self.name = name or 'No name'
        self.differ = differ or CodeDiffer()
        self._logger = logging.getLogger(__name__)

        self._clients = clients or {}
        # backing field for the _exec_status property; set directly here to
        # skip the setter's task-status validation
        self.__exec_status = DAGStatus.WaitingRender

        if executor == 'serial':
            self._executor = executors.Serial()
        elif executor == 'parallel':
            self._executor = executors.Parallel()
        elif isinstance(executor, executors.Executor.Executor):
            self._executor = executor
        else:
            raise TypeError('executor must be "serial", "parallel" or '
                            'an instance of executors.Executor, got type {}'
                            .format(type(executor)))

        # reset to False by _add_edge, read by _render_current
        self._did_render = False

        # optional user callbacks executed by build()
        self.on_finish = None
        self.on_failure = None
        self._available_callback_kwargs = {'dag': self}

        self._cfg = DAGConfiguration.default()

    @property
    def _exec_status(self):
        return self.__exec_status

    @_exec_status.setter
    def _exec_status(self, value):
        self._logger.debug('Setting %s status to %s', self, value)

        # The Task class is responsible for updating their status
        # (except for Executed and Errored, those are updated by the executor)
        # DAG should not set task status but only verify that after an attempt
        # to change DAGStatus, all Tasks have allowed states, otherwise there
        # is an error in either the Task or the Executor. We cannot raise an
        # exception here, since setting _exec_status might happen right
        # before catching an exception, but we still have to warn the
        # user that the DAG entered an inconsistent state. We only raise
        # an exception when trying to set an invalid value
        # NOTE: in some exec_status, it is ok to raise an exception, maybe we
        # should do it?

        if value == DAGStatus.WaitingRender:
            self.check_tasks_have_allowed_status({TaskStatus.WaitingRender},
                                                 value)

        # render errored
        elif value == DAGStatus.ErroredRender:
            allowed = {TaskStatus.WaitingExecution, TaskStatus.WaitingUpstream,
                       TaskStatus.ErroredRender, TaskStatus.AbortedRender,
                       TaskStatus.Skipped}
            self.check_tasks_have_allowed_status(allowed, value)

        # rendering ok, waiting execution
        elif value == DAGStatus.WaitingExecution:
            allowed = {TaskStatus.WaitingExecution,
                       TaskStatus.WaitingUpstream,
                       TaskStatus.Skipped}
            self.check_tasks_have_allowed_status(allowed, value)

        # execution finished without the executor raising
        elif value == DAGStatus.Executed:
            exec_values = set(task.exec_status for task in self.values())
            # check len(self) to prevent this from failing on an empty DAG
            if not exec_values <= {TaskStatus.Executed,
                                   TaskStatus.Skipped} and len(self):
                warnings.warn('The DAG "{}" entered in an inconsistent '
                              'state: trying to set DAG status to '
                              'DAGStatus.Executed but executor '
                              'returned tasks whose status is not '
                              'TaskStatus.Executed nor '
                              'TaskStatus.Skipped, returned '
                              'status: {}'.format(self.name, exec_values))
        elif value == DAGStatus.Errored:
            # no value validation since this state is also set when the
            # DAG executor ends abruptly
            pass
        else:
            raise RuntimeError('Unknown DAGStatus value: {}'
                               .format(value))

        self.__exec_status = value

    def check_tasks_have_allowed_status(self, allowed, new_status):
        """Warn if any task's exec_status is not in ``allowed``

        Parameters
        ----------
        allowed : set
            Set of TaskStatus values considered consistent
        new_status : DAGStatus
            The DAG status being set, only used in the warning message
        """
        exec_values = set(task.exec_status for task in self.values())
        if not exec_values <= allowed:
            warnings.warn('The DAG "{}" entered in an inconsistent state: '
                          'trying to set DAG status to '
                          '{} but executor '
                          'returned tasks whose status is not in a '
                          'subset of {}. Returned '
                          'status: {}'.format(self.name, new_status, allowed,
                                              exec_values))

    @property
    def product(self):
        """MetaProduct wrapping the products of all tasks in this DAG"""
        # We have to rebuild it since tasks might have been added
        return MetaProduct([t.product for t in self.values()])

    @property
    def clients(self):
        """Dictionary mapping classes to clients (mutable)"""
        return self._clients

    def pop(self, name):
        """Remove a task from the dag and return it

        Parameters
        ----------
        name : str
            Name of the task to remove
        """
        t = self._G.nodes[name]['task']
        self._G.remove_node(name)
        return t

    def render(self, force=False, show_progress=True):
        """Render the graph

        Tasks in other DAGs that this DAG depends on (through upstream
        dependencies) are rendered first, then tasks in this DAG.

        Returns
        -------
        DAG
            self, to allow chaining
        """
        g = self._to_graph()

        # Deduplicate by identity, keeping first-seen order. DAG inherits
        # Mapping.__eq__ (content comparison) and is therefore unhashable,
        # so comparing with `in` on a list is both slow and could merge two
        # distinct DAGs that happen to have equal items
        dags = list({id(t.dag): t.dag for t in g}.values())

        # first render any other dags involved (this happens when some
        # upstream parameters come from other dags)
        # NOTE: for large compose dags it might be wasteful to render over
        # and over
        for dag in dags:
            if dag is not self:
                dag._render_current(force=force, show_progress=show_progress)

        # then, render this dag
        self._render_current(force=force, show_progress=show_progress)

        return self

    def build(self, force=False, show_progress=True):
        """
        Runs the DAG in order so that all upstream dependencies are run for
        every task

        Parameters
        ----------
        force: bool, optional
            If True, it will run all tasks regardless of status, defaults to
            False
        show_progress: bool, optional
            Whether to show progress bars while rendering and executing,
            defaults to True

        Notes
        -----
        All dag-level clients are closed after calling this function

        Returns
        -------
        BuildReport
            A dict-like object with tasks as keys and dicts with task
            status as values

        Raises
        ------
        DAGBuildError
            If the DAG previously failed to render, or if the executor raises
            (the original exception is chained as the cause)
        """
        if self._exec_status == DAGStatus.ErroredRender:
            raise DAGBuildError('Cannot build dag that failed to render, '
                                'fix rendering errors then build again. '
                                'To see the full traceback again, run '
                                'dag.render(force=True)')

        # at this point the DAG can only be:
        # DAGStatus.WaitingExecution, DAGStatus.Executed or
        # DAGStatus.Errored, DAGStatus.WaitingRender
        # calling render will update status to DAGStatus.WaitingExecution
        self.render(force=force, show_progress=show_progress)

        self._logger.info('Building DAG %s', self)

        try:
            # within_dag flags when we execute a task in isolation
            # vs as part of a dag execution
            task_reports = self._executor(dag=self,
                                          show_progress=show_progress,
                                          task_kwargs=dict(within_dag=True))
        except Exception as e:
            self._exec_status = DAGStatus.Errored
            e_new = DAGBuildError('Failed to build DAG {}'.format(self))

            if self.on_failure:
                self._logger.debug('Executing on_failure hook '
                                   'for dag "%s"', self.name)
                kwargs_available = copy(self._available_callback_kwargs)
                kwargs_available['traceback'] = traceback.format_exc()

                kwargs = callback_check(self.on_failure, kwargs_available)
                self.on_failure(**kwargs)
            else:
                self._logger.debug('No on_failure hook for dag '
                                   '"%s", skipping', self.name)

            raise e_new from e
        else:
            self._exec_status = DAGStatus.Executed
        finally:
            # always clear out status
            self._clear_cached_status()

        # add reports from skipped tasks
        empty = [TaskReport.empty_with_name(t.name)
                 for t in self.values()
                 if t.exec_status == TaskStatus.Skipped]

        build_report = BuildReport(task_reports + empty)
        self._logger.info(' DAG report:\n%s', build_report)

        if self.on_finish:
            self._logger.debug('Executing on_finish hook '
                               'for dag "%s"', self.name)
            kwargs_available = copy(self._available_callback_kwargs)
            kwargs_available['report'] = build_report
            kwargs = callback_check(self.on_finish, kwargs_available)
            self.on_finish(**kwargs)
        else:
            self._logger.debug('No on_finish hook for dag '
                               '"%s", skipping', self.name)

        return build_report

    def build_partially(self, target, force=False, show_progress=True):
        """Partially build a dag until certain task

        Builds a deep copy of this DAG that only keeps ``target`` and the
        tasks in its lineage; the original DAG is not modified.
        """
        lineage = self[target]._lineage
        dag = deepcopy(self)

        to_pop = set(dag) - {target}

        if lineage:
            to_pop = to_pop - lineage

        for task in to_pop:
            dag.pop(task)

        return dag.build(force=force, show_progress=show_progress)

    def status(self, **kwargs):
        """Returns a table with tasks status

        Renders the DAG first; ``kwargs`` are passed to each Task.status
        """
        # FIXME: delete this, make dag.render() return this
        self.render()

        return Table([self._G.nodes[name]['task'].status(**kwargs)
                      for name in self._G])

    def to_dict(self, include_plot=False):
        """Returns a dict representation of the dag's Tasks,
        only includes a few attributes.

        Parameters
        ----------
        include_plot: bool, optional
            If True, the path to a PNG file with the plot in "_plot"
        """
        d = {name: self._G.nodes[name]['task'].to_dict()
             for name in self._G}

        if include_plot:
            # plot() takes no open_image argument; with the default
            # output='tmp' it returns the path to a temporary PNG file
            d['_plot'] = self.plot()

        return d

    def to_markup(self, path=None, fmt='html'):
        """Returns a str (md or html) with the pipeline's description

        Parameters
        ----------
        path : str or pathlib.Path, optional
            If not None, the output is also written to this path
        fmt : str, optional
            'html' or 'md', defaults to 'html'

        Raises
        ------
        ValueError
            If fmt is not 'html' or 'md'
        """
        if fmt not in ['html', 'md']:
            raise ValueError('fmt must be html or md, got {}'.format(fmt))

        status = self.status().to_format('html')
        path_to_plot = Path(self.plot())
        plot = image_bytes2html(path_to_plot.read_bytes())

        template_md = importlib_resources.read_text(resources, 'dag.md')
        out = Template(template_md).render(plot=plot, status=status, dag=self)

        if fmt == 'html':
            if not mistune or not pygments:
                raise ImportError('mistune and pygments are '
                                  'required to export to HTML')

            renderer = markup.HighlightRenderer()
            out = mistune.markdown(out, escape=False, renderer=renderer)

            # add css
            html = importlib_resources.read_text(resources,
                                                 'github-markdown.html')
            out = Template(html).render(content=out)

        if path is not None:
            Path(path).write_text(out)

        return out

    @requires(['pygraphviz'])
    def plot(self, output='tmp'):
        """Plot the DAG

        Parameters
        ----------
        output : str, optional
            'tmp' (default) writes a PNG to a temporary file and returns its
            path, 'matplotlib' writes a temporary PNG and returns
            path2fig(path), any other value is used as the output path
        """
        if output in {'tmp', 'matplotlib'}:
            # tempfile.mktemp is deprecated and race-prone; create the file
            # securely and let graphviz overwrite it
            with tempfile.NamedTemporaryFile(suffix='.png',
                                             delete=False) as f:
                path = f.name
        else:
            path = output

        # attributes docs:
        # https://graphviz.gitlab.io/_pages/doc/info/attrs.html

        # FIXME: add tests for this
        self.render()

        G = self._to_graph()

        for n, data in G.nodes(data=True):
            data['color'] = 'red' if n.product._is_outdated() else 'green'
            data['label'] = n._short_repr()

        # https://networkx.github.io/documentation/networkx-1.10/reference/drawing.html
        # http://graphviz.org/doc/info/attrs.html
        # NOTE: requires pygraphviz
        G_ = nx.nx_agraph.to_agraph(G)
        G_.draw(path, prog='dot', args='-Grankdir=LR')

        if output == 'matplotlib':
            return path2fig(path)
        else:
            return path

    def _render_current(self, force, show_progress):
        """
        Render tasks in this DAG and update exec_status

        Raises
        ------
        DAGRenderError
            If at least one task fails to render (the DAG status is set to
            DAGStatus.ErroredRender first)
        """
        # NOTE(review): nothing visible here sets self._did_render to True,
        # so cache_rendered_status currently never skips rendering - confirm
        # whether that is intended before relying on the cache
        if not self._cfg.cache_rendered_status or not self._did_render:
            self._logger.info('Rendering DAG %s', self)

            # always bind `tasks`, previously it was only defined when
            # show_progress was True (NameError otherwise)
            tasks = self.values()

            if show_progress:
                tasks = tqdm(tasks, total=len(self))
                tasks.set_description('Rendering DAG "{}"'
                                      .format(self.name))

            exceptions = ExceptionCollector()
            # (task repr, [warning messages]) for every task that warned,
            # collected so no task's warnings are lost
            task_warnings = []

            for t in tasks:
                # no need to process task with AbortedRender
                if t.exec_status == TaskStatus.AbortedRender:
                    continue

                with warnings.catch_warnings(record=True) as recorded:
                    try:
                        t.render(force=force,
                                 outdated_by_code=self._cfg.outdated_by_code)
                    except Exception:
                        tr = traceback.format_exc()
                        exceptions.append(traceback_str=tr, task_str=repr(t))

                if recorded:
                    task_warnings.append(
                        (repr(t), [str(w.message) for w in recorded]))

            if exceptions:
                self._exec_status = DAGStatus.ErroredRender
                raise DAGRenderError('DAG render failed, the following '
                                     'tasks could not render '
                                     '(corresponding tasks aborted '
                                     'rendering):\n{}'
                                     .format(str(exceptions)))

            # TODO: also include warnings in the exception message
            # maybe raise one by one to keep the warning type
            for task_repr, messages in task_warnings:
                warning = ('Task "{}" had the following warnings:\n\n{}'
                           .format(task_repr, '\n'.join(messages)))
                warnings.warn(warning)

        self._exec_status = DAGStatus.WaitingExecution

    def _add_task(self, task):
        """Adds a task to the DAG

        Raises
        ------
        ValueError
            If the task has no name or a task with the same name exists
        """
        if task.name is None:
            raise ValueError('Tasks must have a name, got None')

        if task.name in self._G:
            raise ValueError('DAGs cannot have Tasks with repeated names, '
                             'there is a Task with name "{}" '
                             'already'.format(task.name))

        self._G.add_node(task.name, task=task)

    def _to_graph(self, only_current_dag=False):
        """
        Converts the DAG to a Networkx DiGraph object whose nodes are Task
        objects. Since upstream dependencies are not required to come from
        the same DAG, this object might include tasks that are not included
        in the current object (unless only_current_dag is True)
        """
        # NOTE: delete this, use existing DiGraph object
        G = nx.DiGraph()

        for task in self.values():
            G.add_node(task)

            if only_current_dag:
                G.add_edges_from([(up, task) for up
                                  in task.upstream.values() if up.dag is self])
            else:
                G.add_edges_from([(up, task) for up in task.upstream.values()])

        return G

    def _add_edge(self, task_from, task_to):
        """Add an edge between two tasks

        task_from may be a single task (or a DAG, treated as a single task)
        or an iterable of tasks, in which case one edge per element is added
        """
        # if a new task is added, rendering is required again
        self._did_render = False

        if isiterable(task_from) and not isinstance(task_from, DAG):
            tasks_from = task_from
        else:
            # DAGs are treated like a single task
            tasks_from = [task_from]

        for a_task_from in tasks_from:
            # this happens when the task was originally declared in
            # another dag...
            if a_task_from.name not in self._G:
                self._G.add_node(a_task_from.name, task=a_task_from)

            self._G.add_edge(a_task_from.name, task_to.name)

    def _get_upstream(self, task_name):
        """Get upstream tasks given a task name (returns Task objects)
        """
        upstream = self._G.predecessors(task_name)
        return {u: self._G.nodes[u]['task'] for u in upstream}

    def _clear_cached_status(self):
        """Clear the cached product status of every task in this DAG"""
        self._logger.debug('Clearing product status')
        # clearing out this way is only useful after building, but not
        # if the metadata changed since it wont be reloaded
        for task in self.values():
            task.product._clear_cached_status()

    def __getitem__(self, key):
        return self._G.nodes[key]['task']

    def __iter__(self):
        """Iterate task names in topological order
        """
        # TODO: raise a warning if any of this dag tasks have tasks
        # from other dags as dependencies (they won't show up here)
        yield from nx.algorithms.topological_sort(self._G)

    def __len__(self):
        return len(self._G)

    def __repr__(self):
        return '{}("{}")'.format(type(self).__name__, self.name)

    def _short_repr(self):
        return repr(self)

    # IPython integration
    # https://ipython.readthedocs.io/en/stable/config/integrating.html

    def _ipython_key_completions_(self):
        return list(self)

    # __getstate__ and __setstate__ are needed to make this picklable

    def __getstate__(self):
        state = self.__dict__.copy()
        # _logger is not pickable, so we remove them and build
        # them again in __setstate__
        del state['_logger']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # same logger name as __init__ so unpickled DAGs log consistently
        self._logger = logging.getLogger(__name__)
