from __future__ import absolute_import, print_function, unicode_literals
from builtins import dict, str
import json
from networkx import MultiDiGraph, Graph, cycle_basis
from pygraphviz import AGraph


def im_json_to_graph(im_json):
    """Return networkx graph from Kappy's influence map JSON.

    Every node of the influence map becomes a graph node keyed by its label.
    Every positive ("wake-up") and negative ("inhibition") influence becomes
    a directed edge between two such nodes.

    Parameters
    ----------
    im_json : dict or str
        A JSON dict which contains an influence map generated by Kappy. As of
        kappy 4.1.2 this is delivered as a JSON string, which is decoded here.

    Returns
    -------
    graph : networkx.MultiDiGraph
        A graph representing the influence map. Nodes carry a ``node_type``
        attribute plus graphviz styling. Edges carry ``sign`` (1 or -1),
        ``color`` and ``arrowhead``.
    """
    # kappy >= 4.1.2 returns the influence map as a JSON string; earlier
    # versions returned an already-decoded object.
    if isinstance(im_json, str):
        im_json = json.loads(im_json)
    imap_data = im_json['influence map']['map']

    graph = MultiDiGraph()

    # Edges refer to nodes by type and id rather than by label, so remember
    # which label each "<type><id>" key (e.g. "rule0") belongs to.
    label_for_key = {}

    def ref_key(ref):
        # A node reference is a single-entry dict such as {"rule": 0}.
        return '%s%s' % list(ref.items())[0]

    for entry in imap_data['nodes']:
        # Each entry has a single item: the node type (e.g. "rule") maps to
        # the node's data.
        node_type, node = list(entry.items())[0]
        is_rule = node_type == 'rule'
        graph.add_node(node['label'], node_type=node_type,
                       fillcolor='#b7d2ff' if is_rule else '#cdffc9',
                       shape='box' if is_rule else 'oval',
                       style='filled')
        label_for_key['%s%s' % (node_type, node['id'])] = node['label']

    def add_edges(link_list, edge_sign):
        # Positive influences get green normal arrowheads, negative ones red
        # tee arrowheads.
        is_positive = edge_sign == 1
        edge_attrs = {'sign': edge_sign,
                      'color': 'green' if is_positive else 'red',
                      'arrowhead': 'normal' if is_positive else 'tee'}
        for link in link_list:
            source_ref = link['source']
            for target_entry in link['target map']:
                target_ref = target_entry['target']
                source_key = ref_key(source_ref)
                target_key = ref_key(target_ref)
                graph.add_edge(label_for_key[source_key],
                               label_for_key[target_key], **edge_attrs)

    for map_name, sign in (('wake-up map', 1), ('inhibition map', -1)):
        add_edges(imap_data[map_name], sign)

    return graph


def cm_json_to_networkx(cm_json):
    """Return a networkx graph from Kappy's contact map JSON.

    The networkx Graph's structure is as follows. Each monomer is represented
    as a node of type "agent", and each site is represented as a separate
    node of type "site". Edges that have type "link" connect site nodes
    whereas edges with type "part" connect monomers with their sites.

    Agent nodes are keyed by the agent's position in the contact map. Site
    nodes are keyed by an ``(agent_index, site_index)`` tuple. Every node
    has a ``label`` attribute holding the agent or site name.

    Parameters
    ----------
    cm_json : dict or str
        A JSON dict which contains a contact map generated by Kappy. As of
        kappy 4.1.2 this is a JSON string, which is decoded here.

    Returns
    -------
    graph : networkx.Graph
        An undirected graph representing the contact map.
    """
    cmap_data = get_cmap_data_from_json(cm_json)
    graph = Graph()
    nodes = []
    edges = []
    for node_idx, node in enumerate(cmap_data):
        nodes.append((node_idx, {'label': node['node_type'], 'type': 'agent'}))
        for site_idx, site in enumerate(node['node_sites']):
            # Sites are keyed by (agent index, site index). Port links are
            # normalized below to the same pair form so they resolve to
            # these site nodes.
            site_key = (node_idx, site_idx)
            nodes.append((site_key, {'label': site['site_name'],
                                     'type': 'site'}))
            site_type = site['site_type']
            if site_type and site_type[0] == 'port':
                # Each port link is an edge from the current site to the
                # specified site. As of kappy 4.1.2 the format of port links
                # changed from [[1, 0]] (old) to [[[0, 1], 0]] (new); taking
                # the second item of any nested list maps both to (1, 0).
                for port_link in site_type[1]['port_links']:
                    target_key = tuple(link[1] if isinstance(link, list)
                                       else link for link in port_link)
                    edges.append((site_key, target_key, {'type': 'link'}))
            # Attach every site to its agent, not only port sites, so that
            # no site node is left isolated. This is appended after the
            # site's link edges to keep the original edge insertion order.
            edges.append((node_idx, site_key, {'type': 'part'}))
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edges)
    return graph


def get_cm_cycles(cm_graph):
    """Return cycles from a model's Kappa contact map graph representation.

    The cycle basis of the graph is computed, and each basis cycle is
    translated into the binding sites along it. Cycles with two consecutive
    "link" edges are dropped. Such a cycle passes through one site that is
    linked to two different partners, which represents competitive binding
    rather than a loop of bonds. Cycles containing a link edge from a site
    to itself are also dropped.

    Parameters
    ----------
    cm_graph : networkx.Graph
        A networkx graph produced by cm_json_to_networkx.

    Returns
    -------
    list
        A list of base cycles found in the contact map graph. Each cycle
        is represented as a list of strings of the form Monomer(site). Each
        link edge of the cycle contributes one consecutive pair of entries.
    """
    def site_label(site_node):
        # Site nodes are keyed (agent_index, site_index); the owning agent
        # node is keyed by agent_index alone.
        agent_name = cm_graph.nodes[site_node[0]]['label']
        return '%s(%s)' % (agent_name, cm_graph.nodes[site_node]['label'])

    processed_cycles = []
    for cycle in cycle_basis(cm_graph):
        # Close the cycle by pairing the last node back with the first
        cycle_edges = list(zip(cycle, cycle[1:] + [cycle[0]]))
        edge_types = [cm_graph.edges[edge]['type'] for edge in cycle_edges]
        # Compare each edge's type with the next one's, cyclically
        neighbour_types = zip(edge_types, edge_types[1:] + [edge_types[0]])
        if ('link', 'link') in neighbour_types:
            continue
        links = [edge for edge, edge_type in zip(cycle_edges, edge_types)
                 if edge_type == 'link']
        # A site linked to itself invalidates the whole cycle
        if any(n1 == n2 for n1, n2 in links):
            continue
        labels = []
        for n1, n2 in links:
            labels.extend([site_label(n1), site_label(n2)])
        processed_cycles.append(labels)
    return processed_cycles


def cm_json_to_graph(cm_json):
    """Return pygraphviz Agraph from Kappy's contact map JSON.

    Every site becomes a node labelled with its name. The sites of each agent
    are grouped into a ``cluster_<agent>`` subgraph labelled with the agent
    name. Every port link becomes an edge between two site nodes.

    Parameters
    ----------
    cm_json : dict or str
        A JSON dict which contains a contact map generated by Kappy. As of
        kappy 4.1.2 this is a JSON string, which is decoded here.

    Returns
    -------
    graph : pygraphviz.Agraph
        A graph representing the contact map.
    """
    cmap_data = get_cmap_data_from_json(cm_json)

    # Initialize the graph
    graph = AGraph()

    # In this loop we add sites as nodes and a cluster around each agent's
    # sites to the graph. We also collect edges to be added between sites
    # after all site nodes have been added.
    edges = []
    for node_idx, node in enumerate(cmap_data):
        sites_in_node = []
        for site_idx, site in enumerate(node['node_sites']):
            # Sites are keyed by (agent index, site index). Port links are
            # normalized below to the same pair form so they match these keys.
            site_key = (node_idx, site_idx)
            sites_in_node.append(site_key)
            graph.add_node(site_key, label=site['site_name'], style='filled',
                           shape='ellipse')
            # Only port sites carry links to other sites
            site_type = site['site_type']
            if not site_type or site_type[0] != 'port':
                continue
            # As of kappy 4.1.2 the format of port links changed from
            # [[1, 0]] (old) to [[[0, 1], 0]] (new); taking the second item
            # of any nested list maps both to (1, 0).
            for port_link in site_type[1]['port_links']:
                target_key = tuple(link[1] if isinstance(link, list)
                                   else link for link in port_link)
                edges.append((site_key, target_key))
        graph.add_subgraph(sites_in_node,
                           name='cluster_%s' % node['node_type'],
                           label=node['node_type'])

    # Finally we add the edges between the sites
    for source, target in edges:
        graph.add_edge(source, target)

    return graph


def get_cmap_data_from_json(cm_json):
    """Return the list of agent entries from Kappy's contact map JSON.

    Parameters
    ----------
    cm_json : dict or str
        Contact map output from Kappy. Kappy 4.1.2 and later return it as a
        JSON string, and older versions return an already-decoded dict.

    Returns
    -------
    list
        The value under ``['contact map']['map']``. If that value is a
        single-element list wrapping another list (the kappy 4.1.2+ layout),
        the inner list is returned instead.
    """
    decoded = json.loads(cm_json) if isinstance(cm_json, str) else cm_json
    agents = decoded['contact map']['map']
    # kappy 4.1.2+ wraps the agent list in one extra single-element list
    is_wrapped = len(agents) == 1 and isinstance(agents[0], list)
    return agents[0] if is_wrapped else agents
