# -*- coding: utf-8 -*-
"""
Py_BANSHEE
Authors: Paul Koot, Miguel Angel Mendoza-Lugo, Dominik Paprotny,
         Elisa Ragno, Oswaldo Morales-Nápoles, Daniël Worm

E-mail:  m.a.mendozalugo@tudelft.nl, paulkoot6@gmail.com & O.MoralesNapoles@tudelft.nl
"""

import networkx as nx
import graphviz as gv
import seaborn as sns
import pandas as pd
from matplotlib import pyplot as plt
from IPython.display import Image
import os


def bn_visualize(parent_cell, R, names, data=None, fig_name=''):
    """Visualize the structure of a defined Bayesian Network.

    bn_visualize creates and saves a directed graph presenting the
    structure of nodes and arcs of the Bayesian Network (BN), defined by
    parent_cell. Each arc is labelled with the corresponding entry of R,
    formatted with two decimals.

    Parameters
    ----------
    parent_cell : list
        A list containing the structure of the BN, the same as required
        in the bn_rankcorr function. ``parent_cell[i]`` holds the indices
        of the parents of node ``i``.
    R : numpy.ndarray
        Rank Correlation Matrix. ``R[j, i]`` is used as the label of the
        arc from node ``j`` to node ``i``.
    names : list
        A list containing names of the nodes for the plot. Should
        be in the same order as they appear in matrix R and parent_cell.
    data : pandas.core.frame.DataFrame, optional
        The same data that can be used as input in bn_rankcorr. When this
        argument is given, a histogram (with KDE) of every column is
        saved as ``histogram_<column>.png`` in the working directory and
        used as the image of the node with that column label.
    fig_name : string
        Name extension of the .png file with the Bayesian Network that
        is created: BN_visualize_'fig_name'.png.
        The file is saved in the working directory.

    Returns
    -------
    IPython.display.Image
        Image object referring to the rendered
        BN_visualize_'fig_name'.png file.
    """
    G = nx.DiGraph()
    if isinstance(data, pd.DataFrame):
        # NOTE(review): nodes are keyed by the DataFrame column labels
        # while the arcs below use `names`; these presumably have to
        # match for the histograms to end up on the connected nodes -
        # confirm against callers.
        for node in data:
            image_file = 'histogram_{}.png'.format(node)
            fig = plt.figure()
            h = sns.histplot(data[node], kde=True)
            h.set_xlabel('')
            h.set_title('{}'.format(node), fontsize=25)
            plt.savefig(image_file)
            G.add_node(node, image=image_file, fontsize=0)
            plt.show()
            # Release the figure: without this every call leaves one open
            # figure per column on backends where show() does not block.
            plt.close(fig)
    else:
        G.add_nodes_from(names, style='filled', fillcolor='red')

    # One arc per (parent -> child) pair, labelled with the matrix entry.
    for i, child in enumerate(names):
        for j in parent_cell[i]:
            G.add_edge(names[j], child, label=("%.2f") % R[j, i],
                       fontsize=18)

    base_name = 'BN_visualize_{}'.format(fig_name)

    nx.drawing.nx_pydot.write_dot(G, base_name)
    # Convert dot file to png file (written next to it as base_name.png)
    gv.render('dot', 'png', base_name)

    # Remove the intermediate dot file, but never a directory or symlink
    # that happens to carry the same name.
    if (os.path.exists(base_name) and not os.path.isdir(base_name)
            and not os.path.islink(base_name)):
        os.remove(base_name)

    return Image(filename=base_name + '.png')
