from abc import ABC
from functools import reduce
from typing import List, Tuple, Type, Optional, TypeVar, Union

import networkx as nx

from leaf.infrastructure import Node, Link
from leaf.power import PowerAware, PowerMeasurement


class Task(PowerAware, ABC):
    def __init__(self, mips: float):
        """Task within an :class:`Application`, i.e. a node in the application DAG.

        Args:
            mips: Million instructions per second required to execute the task.
        """
        self.id: Optional[int] = None  # assigned when the task is added to an Application
        self.mips = mips
        self.node: Optional[Node] = None  # node the task is currently placed on, if any

    def __repr__(self):
        return f"{self.__class__.__name__}(id={self.id}, mips={self.mips})"

    def allocate(self, node: Node):
        """Place the task on a certain node and allocate resources.

        Raises:
            ValueError: If the task is already placed on a node.
        """
        if self.node is not None:
            raise ValueError(f"Cannot place {self} on {node}: It was already placed on {self.node}.")
        self.node = node
        node.add_task(self)

    def deallocate(self):
        """Detach the task from the node it is currently placed on and deallocate resources.

        Raises:
            ValueError: If the task is not placed on any node.
        """
        if self.node is None:
            raise ValueError(f"{self} is not placed on any node.")
        self.node.remove_task(self)
        self.node = None

    def measure_power(self) -> PowerMeasurement:
        """Return the part of the host node's power attributable to this task.

        The node's power measurement is scaled by ``self.mips / node.used_mips``,
        i.e. split proportionally to the MIPS of the tasks on the node. If the
        node reports zero used MIPS, a zero measurement is returned.
        """
        node_power = self.node.measure_power()
        try:
            share = self.mips / self.node.used_mips
        except ZeroDivisionError:
            # Only the share computation is guarded: errors raised inside the
            # node's own power measurement are not silently converted to zero.
            return PowerMeasurement(0, 0)
        return node_power.multiply(share)


class SourceTask(Task):
    def __init__(self, mips: float = 0, bound_node: Optional[Node] = None):
        """Source task of an application that is bound to a certain node, e.g. a sensor generating data.

        Source tasks never have incoming and always have outgoing data flows.

        Args:
            mips: Million instructions per second required to execute the task.
            bound_node: The node which the task is bound to. Cannot be None.

        Raises:
            ValueError: If `bound_node` is None.
        """
        super().__init__(mips)
        if bound_node is None:
            raise ValueError("bound_node for SourceTask cannot be None")
        self.bound_node = bound_node


class ProcessingTask(Task):
    def __init__(self, mips: float = 0):
        """Processing task of an application; unlike source and sink tasks it is not bound to a node.

        It may be placed anywhere on the infrastructure. Processing tasks always
        have both incoming and outgoing data flows.

        Args:
            mips: Million instructions per second required to execute the task.
        """
        super(ProcessingTask, self).__init__(mips)


class SinkTask(Task):
    def __init__(self, mips: float = 0, bound_node: Optional[Node] = None):
        """Sink task of an application that is bound to a certain node, e.g. a cloud server for storage.

        Sink tasks always have incoming and never have outgoing data flows.

        Args:
            mips: Million instructions per second required to execute the task.
            bound_node: The node which the task is bound to. Cannot be None.

        Raises:
            ValueError: If `bound_node` is None.
        """
        super().__init__(mips)
        if bound_node is None:
            raise ValueError("bound_node for SinkTask cannot be None")
        self.bound_node = bound_node


class DataFlow(PowerAware):
    def __init__(self, bit_rate: float):
        """Data flow between two tasks of an application.

        Args:
            bit_rate: The bit rate of the data flow in bit/s
        """
        self.bit_rate = bit_rate
        self.links: Optional[List[Link]] = None  # path the flow is currently placed on, if any

    def __repr__(self):
        return f"{self.__class__.__name__}(bit_rate={self.bit_rate})"

    def allocate(self, links: List[Link]):
        """Place the data flow on a path of links and allocate bandwidth.

        Raises:
            ValueError: If the data flow is already placed on a path.
        """
        if self.links is not None:
            raise ValueError(f"Cannot place {self} on {links}: It was already placed on path {self.links}.")
        self.links = links
        for link in self.links:
            link.add_data_flow(self)

    def deallocate(self):
        """Remove the data flow from the infrastructure and deallocate bandwidth.

        Raises:
            ValueError: If the data flow is not placed on any link.
        """
        if self.links is None:
            raise ValueError(f"{self} is not placed on any link.")
        for link in self.links:
            link.remove_data_flow(self)
        self.links = None

    def measure_power(self) -> PowerMeasurement:
        """Return the summed power of all links on the path attributable to this data flow.

        Each link's power measurement is scaled by ``self.bit_rate / link.used_bandwidth``.
        """
        return PowerMeasurement.sum(self._link_power_share(link) for link in self.links)

    def _link_power_share(self, link: Link) -> PowerMeasurement:
        """Return the part of `link`'s power attributable to this flow (zero if the link reports no used bandwidth)."""
        link_power = link.measure_power()
        try:
            share = self.bit_rate / link.used_bandwidth
        except ZeroDivisionError:
            # Mirrors Task.measure_power: nothing to attribute when the link carries no bandwidth.
            return PowerMeasurement(0, 0)
        return link_power.multiply(share)


class Application(PowerAware):
    """Application consisting of one or more tasks forming a directed acyclic graph (DAG).

    Tasks are stored as nodes of a :class:`networkx.DiGraph` keyed by their ``id``; data flows are
    stored as edges between them. In both cases the object itself is kept under the ``"data"`` attribute.
    """
    TTask = TypeVar("TTask", bound=Task)  # Generics
    TDataFlow = TypeVar("TDataFlow", bound=DataFlow)  # Generics
    TaskTypeFilter = Union[Type[TTask], Tuple[Type[TTask], ...]]
    DataFlowTypeFilter = Union[Type[TDataFlow], Tuple[Type[TDataFlow], ...]]

    def __init__(self):
        self.graph = nx.DiGraph()

    def __repr__(self):
        return f"{self.__class__.__name__}(tasks={len(self.tasks())})"

    def add_task(self, task: Task, incoming_data_flows: Optional[List[Tuple[Task, float]]] = None):
        """Add a task to the application graph.

        The task gets the next id (the number of tasks already in the application), and one
        :class:`DataFlow` edge is created for every entry in `incoming_data_flows`.

        Args:
            task: The task to add
            incoming_data_flows: List of tuples (`src_task`, `bit_rate`) where every `src_task` is the source of a
                :class:`DataFlow` with a certain `bit_rate` to the added `task`. Must be empty or None for a
                :class:`SourceTask` and non-empty for :class:`ProcessingTask` and :class:`SinkTask`.

        Raises:
            ValueError: If `task` is not a SourceTask, ProcessingTask or SinkTask.
            AssertionError: If the incoming data flows break the rules above, originate from a
                :class:`SinkTask`, or the resulting graph is not a DAG.
        """
        if incoming_data_flows is None:
            incoming_data_flows = []

        # Validate everything before touching `task` or the graph, so a rejected task is left unmodified.
        if isinstance(task, SourceTask):
            assert not incoming_data_flows, f"Source task '{task}' cannot have incoming_data_flows"
        elif isinstance(task, ProcessingTask):
            assert incoming_data_flows, f"Processing task '{task}' has no incoming_data_flows"
        elif isinstance(task, SinkTask):
            assert incoming_data_flows, f"Sink task '{task}' has no incoming_data_flows"
        else:
            raise ValueError(f"Unknown task type '{type(task)}'")
        for src_task, _ in incoming_data_flows:
            assert not isinstance(src_task, SinkTask), f"Sink task '{src_task}' cannot have outgoing data flows"

        task.id = len(self.tasks())
        self.graph.add_node(task.id, data=task)
        for src_task, bit_rate in incoming_data_flows:
            self.graph.add_edge(src_task.id, task.id, data=DataFlow(bit_rate))
        if incoming_data_flows:
            # New edges only point into the new node, so this mainly guards against degenerate
            # input such as a task listed as its own source (which would add a self-loop).
            assert nx.is_directed_acyclic_graph(self.graph), f"Application '{self}' is no DAG"

    def tasks(self, type_filter: Optional[TaskTypeFilter] = None) -> List[TTask]:
        """Return all tasks in the application, optionally filtered by class."""
        task_iter = (task for _, task in self.graph.nodes.data("data"))
        if type_filter:
            task_iter = (task for task in task_iter if isinstance(task, type_filter))
        return list(task_iter)

    def data_flows(self, type_filter: Optional[DataFlowTypeFilter] = None) -> List[TDataFlow]:
        """Return all data flows in the application, optionally filtered by class."""
        df_iter = (df for _, _, df in self.graph.edges.data("data"))
        if type_filter:
            df_iter = (df for df in df_iter if isinstance(df, type_filter))
        return list(df_iter)

    def deallocate(self):
        """Detach/Unmap/Release an application from the infrastructure it is currently placed on.

        Raises:
            ValueError: If any task or data flow of the application is not currently placed.
        """
        for task in self.tasks():
            task.deallocate()
        for data_flow in self.data_flows():
            data_flow.deallocate()

    def measure_power(self) -> PowerMeasurement:
        """Return the summed power measurements of all tasks and data flows of the application."""
        measurements = [t.measure_power() for t in self.tasks()] + [df.measure_power() for df in self.data_flows()]
        return PowerMeasurement.sum(measurements)
