"""Module defining Real Node types, which can be run using our DeltaPySimulator."""
from __future__ import annotations
import logging
from queue import Queue
import sys
from threading import Event
from collections import OrderedDict

from typing import (TYPE_CHECKING,
                    Any,
                    Dict,
                    List,
                    NamedTuple,
                    Optional,
                    Tuple,
                    Type,
                    Union)

import dill

from deltalanguage.data_types import (BaseDeltaType,
                                      DOptional,
                                      ForkedReturn,
                                      NoMessage,
                                      as_delta_type,
                                      delta_type)
from deltalanguage.logging import MessageLog, make_logger
from deltalanguage._utils import (NamespacedName,
                                  QueueMessage)

from .abstract_node import AbstractNode, ForkedNode
from .node_bodies import (Latency,
                          PyConstBody,
                          PyFuncBody,
                          PyMethodBody,
                          PyMigenBody)
from .port_classes import InPort, OutPort

if TYPE_CHECKING:
    from deltalanguage.runtime import DeltaQueue, DeltaPySimulator
    from .. import DeltaGraph


class RealNode(AbstractNode):
    """Class to represent a non-abstract node that can form part of
    :py:class:`DeltaGraph`.

    Parameters
    ----------
    graph : DeltaGraph
        Graph this node is a member of.
    body : Body
        Body of code this node represents.
    return_type : BaseDeltaType
        The type of what we expect the body to output.
    name : Optional[str]
        The name of the node (an index is appended to the end).
    latency : Latency
        The estimated latency for running the body. When ``None``,
        ``Latency(clocks=1)`` is used.
    lvl : int
        Logging level for the node. By default logging.ERROR.
    is_autogenerated : bool
        True if this node is created automatically to provide an input to
        another node. For instance:

        .. code-block:: python

            with DeltaGraph() as graph:
                printer(42)

        has an autogenerated node that provides 42. For strict typing this
        node should send data in the same format as printer's input.
    """

    def __init__(self,
                 graph,
                 body,
                 return_type: Optional[BaseDeltaType] = None,
                 name: Optional[str] = None,
                 latency: Optional[Latency] = None,
                 lvl: int = logging.ERROR,
                 is_autogenerated: bool = False):
        self.graph = graph
        # Register with the parent graph. Only ``graph`` is set at this
        # point, which is why ``__getattr__`` must cope with a partially
        # initialised instance.
        self.graph.add_node(self)

        self.is_autogenerated = is_autogenerated

        self._body = body
        self.return_type = return_type

        idx = RealNode.get_next_index()
        if name is None:
            # set my name to the next unique available name
            self._name = f"node_{idx}"
        else:
            self._name = f"{name}_{idx}"

        if latency is None:
            self._latency = Latency(clocks=1)  # Default latency to 1 clock
        else:
            self._latency = latency

        # Ports in/out to this node
        self.in_ports: Dict[NamespacedName, InPort] = {}
        self.out_ports: List[OutPort] = []

        self.log = make_logger(lvl,
                               f"{self.__class__.__name__} {self._name}")

        # Names of the individual outputs when the node returns several
        # values (ForkedReturn); None for single-output nodes.
        if isinstance(self.return_type, ForkedReturn):
            self.fork_names = self.return_type.keys
        else:
            self.fork_names = None

        # See MessageLog for detail
        self._clock = 0

    def __str__(self) -> str:
        ret = f"{self.name}:"
        ret += ''.join([f"\n    in : {in_port}" for in_port in self.in_ports.values()])
        ret += ''.join([f"\n    out: {out_port}" for out_port in self.out_ports])
        ret += '\n'
        return ret

    def __repr__(self):
        return self.name

    def __getattr__(self, item):
        """Fallback attribute lookup giving access to forked outputs.

        Python only calls this when normal attribute lookup fails. If
        ``item`` is one of this node's fork names, a
        :py:class:`ForkedNode` for that output is returned. Otherwise
        ``AttributeError`` is raised.

        ``fork_names`` and ``_name`` are read from ``__dict__`` directly.
        Going through ``self.`` would re-enter this method whenever they are
        not set yet (e.g. during ``graph.add_node`` in ``__init__`` or on a
        partially restored instance) and recurse until ``RecursionError``.

        Raises
        ------
        AttributeError
            If ``item`` is not a fork name of this node.
        """
        state = self.__dict__
        fork_names = state.get('fork_names')
        node_name = self.name if '_name' in state else type(self).__name__

        if fork_names and item in fork_names:
            return ForkedNode(self, item)
        elif fork_names is None:
            raise AttributeError(
                f"Cannot fetch {item} from {node_name}, as we don't "
                f"have multiple outputs. Suggest using the node on it's own."
            )
        else:
            raise AttributeError(f"Cannot fetch {item} from {node_name}")

    def __getstate__(self):
        """Pickle support: the state is the instance ``__dict__``."""
        return self.__dict__

    def __setstate__(self, state):
        """Pickle support: restore the instance ``__dict__`` wholesale."""
        self.__dict__ = state

    def add_out_port(self, port_destination: InPort,
                     index: Optional[str] = None):
        """Creates an out-port and adds it to my out-port store.

        Parameters
        ----------
        port_destination : InPort
            The in-port that this out-port exports to.
        index : Optional[str]
            If the out-port is one of several for this node, index specifies
            what part of the output is sent via this port. If the node has
            only one output, then this is `None`.

        Raises
        ------
        ValueError
            If the node's return type is the ``NoMessage`` class (or a
            subclass of it), since such a node never produces output.
        """
        try:
            if issubclass(self.return_type, NoMessage):
                raise ValueError(
                    f"Cannot make an out-port on node {self.name} "
                    f"with return type \'NoMessage\'"
                )
        except TypeError:
            # return_type is usually a delta type instance, not a class,
            # in which case issubclass raises TypeError: nothing to check.
            pass

        if index is None:  # out-port type is whole node return type
            # due to the strict typing the out type should match the in type
            if self.is_autogenerated:
                type_out = port_destination.port_type
            else:
                type_out = as_delta_type(self.return_type)

            self.out_ports.append(
                OutPort(NamespacedName(self.name, None),
                        type_out,
                        port_destination, self)
            )
        else:  # out-port type is indexed node return type
            self.out_ports.append(
                OutPort(NamespacedName(self.name, index),
                        as_delta_type(self.return_type.elem_dict[index]),
                        port_destination, self)
            )

        # If this port is going into a port on a different graph,
        # flatten this graph into said graph
        into_graph = port_destination.node.graph
        if into_graph is not self.graph:
            into_graph.flatten(self.graph)

    def add_in_port(self, arg_name: str, in_type: Type, in_port_size: int = 0):
        """Creates an in-port and adds it to my in-port store.

        Parameters
        ----------
        arg_name : str
            Name of the argument this port supplies.
        in_type : Type
            The type that is expected to be supplied for this port.
        in_port_size: int
            Maximum size of the in ports.
            If 0 then size is unlimited.

        Returns
        -------
        InPort
            The created port.
        """
        my_port = InPort(NamespacedName(self.name, arg_name),
                         as_delta_type(in_type),
                         self,
                         in_port_size)
        self.in_ports[my_port.port_name] = my_port
        return my_port

    def _create_upstream_ports(self,
                               required_in_ports: Dict[str, Type],
                               given_nodes: Dict[str, AbstractNode],
                               in_port_size: int = 0):
        """Create the ports going into this node and their
        corresponding out-ports.

        Parameters
        ----------
        required_in_ports
            Dictionary that describes what in-ports we want coming in to
            this node.
        given_nodes
            The nodes that are expected to send results to our in-ports.
        in_port_size
            The maximum size of the node's in ports. If 0 then unlimited size.

        Raises
        ------
        KeyError
            If a required in-port has no corresponding entry in
            ``given_nodes``.
        """
        for arg_name, type_wanted in required_in_ports.items():
            in_port = self.add_in_port(arg_name, type_wanted, in_port_size)
            given_nodes[arg_name].add_out_port(in_port, None)

    def _ports_from_arguments(self,
                              required_in_params,
                              pos_in_nodes,
                              kw_in_nodes,
                              in_port_size=0):
        """Manages the creation of upstream ports by getting the given nodes
        into the correct data structure.

        Positional nodes are matched to ``required_in_params`` in order;
        surplus positional nodes are ignored. Keyword nodes take precedence
        over positional ones for the same parameter name.

        Parameters
        ----------
        required_in_params
            Dictionary that describes what in-ports we want coming in to
            this node.
        pos_in_nodes
            Nodes expected to send results to this node, specified positionally.
        kw_in_nodes
            Nodes expected to send results to this node, specified by keyword.
        in_port_size
            The maximum size of the node's in ports. If 0 then unlimited size.

        Raises
        ------
        TypeError
            If positional nodes are given but ``required_in_params`` is not
            an ``OrderedDict`` (so the positional order is undefined).
        """
        if not isinstance(required_in_params, OrderedDict) and pos_in_nodes:
            raise TypeError("Please specify input parameter types to blocks "
                            "via an OrderedDict if you want to specify "
                            "parameters by position")

        # wrap the given positional in_nodes up with the name of the param they
        # satisfy
        pos_in_nodes_dict = {
            param_name: given_node
            for (given_node, (param_name, _)) in zip(pos_in_nodes,
                                                     required_in_params.items())
        }

        self._create_upstream_ports(required_in_params,
                                    {**pos_in_nodes_dict, **kw_in_nodes},
                                    in_port_size=in_port_size)

    @property
    def name(self) -> str:
        """Unique name of the node (given name plus an index suffix)."""
        return self._name

    @property
    def body(self):
        """Body of code this node represents."""
        return self._body

    @property
    def latency(self) -> Latency:
        """Estimated latency for running the body."""
        return self._latency


class PythonNode(RealNode):
    """Parent Node type for all Python constructs.

    Attributes
    ----------
    in_queues : Dict[str, DeltaQueue]
        Queues providing input(s). ``None`` until
        :meth:`set_communications` is called.
    out_queues : Dict[str, DeltaQueue]
        Queues consuming output(s). ``None`` until
        :meth:`set_communications` is called.
    sig_stop : Event
        Communication channel through which the runtime signals `thread_worker`
        to stop.
    err : Queue
        Communication channel through which `thread_worker` sends error
        messages to the runtime. Not assigned by this class itself.
    node_key : Optional[str]
        Keyword argument used for providing the node to the block, included for
        debugging purposes.
    msg_log : MessageLog
        Log of received messages. Set via :meth:`set_msg_log`; it must be set
        before inputs are read.
    lvl : int
        Logging level for the node. By default logging.ERROR.
    """

    def __init__(self,
                 graph,
                 body,
                 return_type=None,
                 name: Optional[str] = None,
                 latency: Optional[Latency] = None,
                 node_key: Optional[str] = None,
                 lvl: int = logging.ERROR,
                 is_autogenerated: bool = False):
        super().__init__(graph,
                         body,
                         return_type=return_type,
                         name=name,
                         latency=latency,
                         lvl=lvl,
                         is_autogenerated=is_autogenerated)
        # Communication channels are provided later by the runtime,
        # see set_communications.
        self.in_queues: Dict[str, DeltaQueue] = None
        self.out_queues: Dict[str, DeltaQueue] = None
        self.sig_stop: Event = None
        self.node_key = node_key

    def set_communications(self, runtime: DeltaPySimulator):
        """Get the in and out queues relating to this node, as well as the
        utility events such as sig_stop from the runtime and save them in
        the instance.

        Parameters
        ----------
        runtime : DeltaPySimulator
            A runtime instance.
        """
        self.in_queues = runtime.in_queues[self.name]
        self.out_queues = runtime.out_queues[self.name]
        self.sig_stop = runtime.sig_stop

    def _get_input(self, *args: str) -> Dict[str, Any]:
        """Collect input from the in queues and check the stop event.

        Mandatory (non-optional) queues are read first with a blocking
        ``get``; optional queues are then polled with ``get_or_none``.
        Every received message is recorded in ``msg_log``, and the node's
        logical clock is raised to the largest ``clk`` seen.

        Parameters
        ----------
        args : str
            Optionally filter inputs. Only the specified ones will be received.

        Returns
        -------
        Dict[str, Any]
            The unwrapped message payloads, keyed by in-queue name.
        """
        # filter the in_queues based on the names given
        if args:
            queues = {name: in_q for name, in_q in self.in_queues.items()
                      if name in args}
        else:
            queues = self.in_queues
        values: Dict[str, Any] = {}

        # go through the mandatory in queues and block until input is present
        for name, in_q in queues.items():
            if not in_q.optional:
                values[name] = in_q.get(block=True)
                assert isinstance(values[name], QueueMessage)
                self.msg_log.add_message(self.name, name, values[name])

        # go through the optional ones and retrieve without blocking
        for name, in_q in queues.items():
            if in_q.optional:
                values[name] = in_q.get_or_none()
                assert isinstance(values[name], QueueMessage)
                self.msg_log.add_message(self.name, name, values[name])

        # update our internal clock to the most recent logical time seen
        self._clock = max([self._clock] + [v.clk for v in values.values()])
        # unpack the inner msg
        values = {k: v.msg for k, v in values.items()}

        self.check_stop()

        return values

    def check_stop(self):
        """Check the stop signal, which can be set by the simulator or other
        threads. If set, stop the current thread by raising ``SystemExit``
        via ``sys.exit()``.
        """
        if self.sig_stop.is_set():
            self.log.info(f"Stopped {repr(self)}.")
            sys.exit()

    def receive(self, *args: str):
        """Retrieve inputs from the input queues.

        If a compulsory input is not provided it blocks further execution.

        If a stop signal is received from the runtime, the thread will
        terminate with `sys.exit()`.
        Otherwise, the inputs are returned as a dict. If exactly one input
        was requested by name, its value is returned on its own.

        Parameters
        ----------
        args : str
            Optionally filter inputs. Only the specified ones will be received.
        """
        val = self._get_input(*args)

        # if there is just one value to return, unpack it from the dict
        if len(val) == 1 and args:
            val = list(val.values())[0]

        if val:
            self.log.info(f"<- {val}")
        return val

    def _send_output(self, ret):
        """Write output to all the `out_queues` of the node
        and check if we should stop.

        The logical clock is advanced by one and each out-queue receives its
        own ``QueueMessage`` stamped with the new clock value.
        """
        self.check_stop()

        self._clock += 1

        for out_q in self.out_queues.values():
            out_q.put(QueueMessage(ret, clk=self._clock))

    def send(self, ret: Union[object, NamedTuple]):
        """Sends out the node's output(s).

        Parameters
        ----------
        ret : Union[object, NamedTuple]
            The return value. It is implied by construction that if it is a
            single object then the node has only one out port, otherwise
            a named tuple is used, with the names of the fields matching
            the names of the out ports.
        """
        if ret:
            self.log.info(f"-> {ret}")
        self._send_output(ret)

    def thread_worker(self, runtime: DeltaPySimulator):
        """Run a regular Python node.

        Waits for input on all the mandatory inputs.
        Then, de-queues the optional inputs and executes its body.
        The output is written to the appropriate output queues.

        Parameters
        ----------
        runtime : DeltaPySimulator
            A runtime instance.
        """
        self.set_communications(runtime)

        while True:
            values = self._get_input()

            # If a node keyword has been specified for debugging then add
            # the node to the arguments.
            if self.node_key:
                # the self arg is effectively a const message, so from time 0
                values[self.node_key] = self

            self.log.debug("Running...")
            try:
                ret = self.body.eval(**values)
            except NoMessage:
                # NoMessage from the body: send nothing for this round.
                continue

            self._send_output(ret)

    def run_once(self, runtime: DeltaPySimulator):
        """Compute the value of the node and pass it to the output queues.

        The output queues are
        :py:class:`ConstQueue<deltalanguage.runtime.ConstQueue>` - they will save
        the value and keep returning its deepcopy to the caller, which reduces
        load on the runtime.

        Parameters
        ----------
        runtime : DeltaPySimulator
            A runtime instance.
        """
        # Queues are taken straight from the runtime: set_communications is
        # not called for this path, so self.out_queues may still be None.
        out_queues = runtime.out_queues[self.name]

        ret_msg = self.body.eval()

        # this part repeats self._send_output (without the stop check)
        # TODO merge them together
        self._clock += 1
        ret = QueueMessage(ret_msg, self._clock)
        for out_q in out_queues.values():
            out_q.put(ret)

    def get_serialised_body(self):
        """Returns serialised node's body.

        Returns
        -------
        bytes
            ``dill`` serialisation of ``self.body`` (with ``recurse=True``).
        """
        return dill.dumps(self.body, recurse=True)

    def capnp(self, capnp_node, capnp_bodies):
        """Generate ``capnp`` form of this node.

        Identical serialised Python bodies are shared: if ``capnp_bodies``
        already holds the same ``dillImpl`` its index is reused, otherwise a
        new entry is appended.

        Parameters
        ----------
        capnp_node
            The capnp object of this node.
        capnp_bodies
            List of bodies so we can check if a body is already serialised.
        """
        capnp_node.name = self.name

        # 1. save reference to the body
        if self.body is None:
            # TemplateNode case: should not point to an index in bodies
            capnp_node.body = -1
        else:
            dill_impl = self.get_serialised_body()

            for i, body in enumerate(capnp_bodies):
                if body.which() == 'python':
                    if body.python.dillImpl == dill_impl:
                        capnp_node.body = i
                        break

            else:
                body = capnp_bodies.add()
                body.init('python')
                body.python.dillImpl = dill_impl
                capnp_node.body = len(capnp_bodies)-1

        # 2. save I/O ports
        self.capnp_ports(capnp_node)

    def capnp_ports(self, capnp_node):
        """Helper method, generates capnp for in/out ports of the node.

        Parameters
        ----------
        capnp_node
            The node of the interest.
        """
        in_ports = capnp_node.init("inPorts", len(self.in_ports))
        for capnp_in_port, in_port in zip(in_ports, self.in_ports.values()):
            in_port.capnp(capnp_in_port)

        out_ports = capnp_node.init("outPorts", len(self.out_ports))
        for capnp_out_port, out_port in zip(out_ports, self.out_ports):
            out_port.capnp(capnp_out_port)

    def capnp_wiring(self, capnp_nodes, capnp_wiring):
        """Generate capnp form of this node's wires.

        Parameters
        ----------
        capnp_nodes
            List of nodes so indexes can be found.
        capnp_wiring
            List of wires so we can add our relevant wires.

        Raises
        ------
        ValueError
            If this node has out-ports but no entry of ``capnp_nodes`` carries
            its name (previously this surfaced as an ``UnboundLocalError``).
        """
        # Index of the first capnp node with this node's name, if any.
        capnp_node_index = next(
            (i for i, capnp_node in enumerate(capnp_nodes)
             if capnp_node.name == self.name),
            None
        )
        if capnp_node_index is None and self.out_ports:
            raise ValueError(
                f"Node {self.name} not found in the capnp node list"
            )

        for i, out_port in enumerate(self.out_ports):
            capnp_wire = capnp_wiring.add()
            capnp_wire.srcNode = capnp_node_index
            capnp_wire.srcOutPort = i
            out_port.capnp_wiring(capnp_nodes, capnp_wire)

    def set_msg_log(self, msg_log: MessageLog):
        """Sets the log for messages received.

        Parameters
        ----------
        msg_log : MessageLog
            Instance of the message log.
        """
        self.msg_log = msg_log


class PyFuncNode(PythonNode):
    """Node whose body is a plain Python function.

    Parameters
    ----------
    graph : DeltaGraph
        Graph the node is added to.
    my_func
        The function, wrapped into a :py:class:`PyFuncBody`.
    in_params
        Mapping from parameter name to expected input type. It must be an
        ``OrderedDict`` whenever ``pos_in_nodes`` is non-empty.
    out_type
        Return type of the node.
    pos_in_nodes
        Upstream nodes supplied positionally.
    kw_in_nodes
        Upstream nodes supplied by keyword.
    name, node_key, latency, lvl
        Forwarded unchanged to :py:class:`PythonNode`.
    in_port_size : int
        Maximum size of each in-port; 0 means unlimited.
    """

    def __init__(self,
                 graph,
                 my_func,
                 in_params,
                 out_type,
                 pos_in_nodes,
                 kw_in_nodes,
                 name: str = None,
                 node_key: Optional[str] = None,
                 latency: Latency = Latency(time=350),
                 in_port_size: int = 0,
                 lvl: int = logging.ERROR):
        func_body = PyFuncBody(my_func)
        super().__init__(graph,
                         func_body,
                         return_type=out_type,
                         name=name,
                         latency=latency,
                         node_key=node_key,
                         lvl=lvl)

        # Wire every declared parameter to the node that feeds it.
        self._ports_from_arguments(required_in_params=in_params,
                                   pos_in_nodes=pos_in_nodes,
                                   kw_in_nodes=kw_in_nodes,
                                   in_port_size=in_port_size)


class PyMethodNode(PythonNode):
    """Node whose body is a Python method bound to a class instance.

    Parameters
    ----------
    graph : DeltaGraph
        Graph the node is added to.
    my_func
        The method, wrapped together with ``my_instance`` into a
        :py:class:`PyMethodBody`.
    my_instance
        Instance the method acts on.
    in_params
        Mapping from parameter name to expected input type. It must be an
        ``OrderedDict`` whenever ``pos_in_nodes`` is non-empty.
    out_type
        Return type of the node.
    pos_in_nodes
        Upstream nodes supplied positionally.
    kw_in_nodes
        Upstream nodes supplied by keyword.
    name, node_key, latency, lvl
        Forwarded unchanged to :py:class:`PythonNode`.
    in_port_size : int
        Maximum size of each in-port; 0 means unlimited.
    """

    def __init__(self,
                 graph,
                 my_func,
                 my_instance,
                 in_params,
                 out_type,
                 pos_in_nodes,
                 kw_in_nodes,
                 name: str = None,
                 node_key: Optional[str] = None,
                 latency: Latency = Latency(time=350),
                 in_port_size: int = 0,
                 lvl: int = logging.ERROR):
        method_body = PyMethodBody(my_func, my_instance)
        super().__init__(graph,
                         method_body,
                         return_type=out_type,
                         name=name,
                         latency=latency,
                         node_key=node_key,
                         lvl=lvl)

        # Wire every declared parameter to the node that feeds it.
        self._ports_from_arguments(required_in_params=in_params,
                                   pos_in_nodes=pos_in_nodes,
                                   kw_in_nodes=kw_in_nodes,
                                   in_port_size=in_port_size)


class PyMigenNode(PythonNode):
    """Node whose body is Migen logic provided by a class instance.

    Parameters
    ----------
    graph : DeltaGraph
        Graph the node is added to.
    my_func
        The method, wrapped together with ``my_instance`` into a
        :py:class:`PyMigenBody`.
    my_instance
        The Migen object; it is also kept as ``self.instance`` so its
        serialised (Verilog) body can be fetched later.
    in_params
        Mapping from parameter name to expected input type. It must be an
        ``OrderedDict`` whenever ``pos_in_nodes`` is non-empty.
    out_type
        Return type of the node.
    pos_in_nodes
        Upstream nodes supplied positionally.
    kw_in_nodes
        Upstream nodes supplied by keyword.
    name, node_key, latency, lvl
        Forwarded unchanged to :py:class:`PythonNode`.
    in_port_size : int
        Maximum size of each in-port; 0 means unlimited.
    """

    def __init__(self,
                 graph,
                 my_func,
                 my_instance,
                 in_params,
                 out_type,
                 pos_in_nodes,
                 kw_in_nodes,
                 name: str = None,
                 node_key: Optional[str] = None,
                 latency: Latency = Latency(time=350),
                 in_port_size: int = 0,
                 lvl: int = logging.ERROR):
        migen_body = PyMigenBody(my_func, my_instance)
        super().__init__(graph,
                         migen_body,
                         return_type=out_type,
                         name=name,
                         latency=latency,
                         node_key=node_key,
                         lvl=lvl)

        self.instance = my_instance

        # Wire every declared parameter to the node that feeds it.
        self._ports_from_arguments(required_in_params=in_params,
                                   pos_in_nodes=pos_in_nodes,
                                   kw_in_nodes=kw_in_nodes,
                                   in_port_size=in_port_size)

    def get_verilog_body(self) -> str:
        """Return the Verilog produced by the Migen instance, as a string.

        Returns
        -------
        str
            ``str()`` of ``self.instance.get_serialised_body()``.


        .. todo::
            Check the following statement:
            If this method is called the node is elaborated and it can be done
            only once, i.e. you cannot run it multiple times.
        """
        serialised = self.instance.get_serialised_body()
        return str(serialised)

    def capnp(self, capnp_node, capnp_bodies):
        """Overwrites :meth:`PythonNode.capnp`.

        Stores the node name, points ``capnp_node.body`` at a ``migen`` body
        with identical Verilog (appending a new one if none exists yet) and
        then serialises the ports.
        """
        capnp_node.name = self.name

        # 1. save reference to the body, reusing an identical one if present
        verilog = self.get_verilog_body()
        body_index = next(
            (pos for pos, entry in enumerate(capnp_bodies)
             if entry.which() == 'migen' and entry.migen.verilog == verilog),
            None
        )
        if body_index is None:
            new_entry = capnp_bodies.add()
            new_entry.init('migen')
            new_entry.migen.verilog = verilog
            body_index = len(capnp_bodies) - 1
        capnp_node.body = body_index

        # 2. save I/O ports
        self.capnp_ports(capnp_node)


class PyConstNode(PythonNode):
    """Node class to represent blocks of python that are evaluated only once,
    then the result is cached and continuously sent to the output queue(s).

    .. warning::
        A common misconception about this node is that the output message is
        sent out only once, which is not true. Instead, the user should think
        about this node as a constant stream of identical messages. Thus if the
        receiving node B receives only one input from a constant node A, then
        B will be evaluated continuously as well.

    .. warning::
        Constant node cannot have inputs marked with ``DOptional``, as it
        can cause non-deterministic behaviour due to the asynchronicity of
        processes.

    Raises
    ------
    TypeError
        If any entry of ``in_params`` is a ``DOptional`` or is not a
        ``BaseDeltaType``. This check happens before the node is registered
        with ``graph``.
    """

    def __init__(self,
                 graph,
                 my_func,
                 in_params,
                 out_type,
                 pos_in_nodes,
                 kw_in_nodes,
                 name: str = None,
                 latency: Latency = Latency(time=100),
                 lvl: int = logging.ERROR,
                 is_autogenerated: bool = False):
        # Validate input types up front, before touching the graph.
        for param_type in in_params.values():
            if isinstance(param_type, DOptional):
                raise TypeError('Optional input is not allowed '
                                'for constant nodes')
            if not isinstance(param_type, BaseDeltaType):
                raise TypeError('Unsupported data type')

        const_body = PyConstBody(my_func, *pos_in_nodes, **kw_in_nodes)
        super().__init__(graph,
                         const_body,
                         return_type=out_type,
                         name=name,
                         latency=latency,
                         lvl=lvl,
                         is_autogenerated=is_autogenerated)

        self._ports_from_arguments(required_in_params=in_params,
                                   pos_in_nodes=pos_in_nodes,
                                   kw_in_nodes=kw_in_nodes)

    def thread_worker(self, runtime: DeltaPySimulator):
        """Not supported: a constant node is never run as a worker thread.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError


def as_node(potential_node: Union[AbstractNode, object],
            graph: DeltaGraph) -> AbstractNode:
    """Ensures argument is a node and if not makes it into a constant node.

    Parameters
    ----------
    potential_node : Union[AbstractNode, object]
        Node that could be a node or not.
    graph : DeltaGraph
        Graph the node would be in.

    Returns
    -------
    AbstractNode
        ``potential_node`` itself if it already is a node. Otherwise a new
        autogenerated :py:class:`PyConstNode` in ``graph`` whose body returns
        ``potential_node`` and whose type is ``delta_type(potential_node)``.
    """
    if isinstance(potential_node, AbstractNode):
        return potential_node

    # Wrap the plain value in a constant node with no inputs.
    return PyConstNode(graph,
                       lambda: potential_node,
                       {},
                       delta_type(potential_node),
                       [],
                       {},
                       is_autogenerated=True)
