#!/usr/bin/env python

import json
import argparse

import numpy as np
import networkx as nx
from networkx.readwrite import json_graph

import trilearn.auxiliary_functions
import trilearn.graph.graph as libg
import trilearn.distributions.g_intra_class as gic
from trilearn.graph.decomposable import gen_AR2_graph

# Display NumPy floats with one digit after the decimal point when arrays are
# printed. This is a display setting only; np.savetxt formats values with its
# own ``fmt`` argument, so the files written below are not affected.
np.set_printoptions(precision=1)
# Data generator for the graph intra-class (G-Intra class) model on an AR(2)
# graph.


def main(s2, rho, n_samples, n_dim, graph_dir, data_dir, precmat_dir, **args):
    """Sample from a graph intra-class model on an AR(2) graph and save the results.

    Five files are written, all with names starting with
    ``intraclass_p_<n_dim>``:

    * ``<graph_dir>/..._heatmap.eps``  - heatmap of the graph's adjacency matrix
    * ``<graph_dir>/..._graph.eps``    - drawing of the graph
    * ``<data_dir>/..._sigma2_<s2>_rho_<rho>_n_<n_samples>.csv`` - the samples
    * ``<precmat_dir>/..._sigma2_<s2>_rho_<rho>_omega.txt`` - inverse of the
      covariance matrix (the precision matrix)
    * ``<graph_dir>/....json``         - the graph in node-link JSON format

    The list of generated files is printed to stdout afterwards.

    Args:
        s2: Variance parameter passed to ``gic.sample`` / ``gic.cov_matrix``.
        rho: Correlation coefficient passed to the same functions.
        n_samples: Number of samples to draw.
        n_dim: Number of dimensions, i.e. number of nodes in the AR(2) graph.
        graph_dir: Directory receiving the two plots and the JSON graph.
        data_dir: Directory receiving the CSV file with the samples.
        precmat_dir: Directory receiving the precision matrix file.
        **args: Extra keyword arguments are accepted and ignored, so that a
            whole argparse namespace can be splatted into this function.

    The directories are used as given; they are not created here.
    """
    graph = gen_AR2_graph(n_dim)
    # Transposed before saving; presumably gic.sample returns one sample per
    # column and the CSV should hold one sample per row -- TODO confirm.
    X = gic.sample(graph, rho, s2, n_samples).T
    c = gic.cov_matrix(graph, rho, s2)

    # Build every output path exactly once so that the files written and the
    # file names reported at the end cannot drift apart.
    stem = "/intraclass_p_" + str(n_dim)
    params = "_sigma2_" + str(s2) + "_rho_" + str(rho)
    heatmap_base = graph_dir + stem + "_heatmap"
    graph_plot_file = graph_dir + stem + "_graph.eps"
    data_file = data_dir + stem + params + "_n_" + str(n_samples) + ".csv"
    precmat_file = precmat_dir + stem + params + "_omega.txt"
    graph_json_file = graph_dir + stem + ".json"

    # NOTE(review): nx.to_numpy_matrix was removed in NetworkX 3.0; kept here
    # to stay compatible with the NetworkX version the rest of the project
    # targets.
    hm_true = np.array(nx.to_numpy_matrix(graph, nodelist=range(n_dim)))
    # plot_matrix receives the base name and the format separately; the
    # reported file name below assumes it appends ".<format>" -- TODO confirm.
    trilearn.auxiliary_functions.plot_matrix(hm_true, heatmap_base,
                                             "eps", "True graph")
    libg.plot_graph(graph, graph_plot_file)

    np.savetxt(data_file, X, delimiter=',')
    # ``.I`` is the matrix inverse; this assumes gic.cov_matrix returns an
    # np.matrix (a plain ndarray has no ``.I`` attribute).
    np.savetxt(precmat_file, c.I, delimiter=',')

    with open(graph_json_file, 'w') as outfile:
        json.dump(json_graph.node_link_data(graph), outfile)

    # Single-argument print(...) calls produce identical output under both
    # Python 2 and Python 3, unlike the Python 2-only print statement.
    print("Generated files:")
    for path in (heatmap_base + ".eps", graph_plot_file, data_file,
                 precmat_file, graph_json_file):
        print(path)

    if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Generates samples from a graph "
                                                 "intra-class model and saves the precision matrix.")
    parser.add_argument(
        '--sigma2',
        type=float, required=True,
        description='Variance'

    )
    parser.add_argument(
        '--rho',
        type=float, required=True,
        description="Correlation coefficient"

    )
    parser.add_argument(
        '-n', '--n_samples',
        type=int, required=True,
        description="Number of normal samples"
    )
    parser.add_argument(
        '-p', '--n_dim',
        type=int, required=True,
        descritption="Number of dimensions"
    )
    parser.add_argument(
        '--graph_dir',
        required=False, default=".",
        description="Directory where to save the graph file"
    )
    parser.add_argument(
        '--data_dir',
        required=False, default=".",
        description="Directory where to save the data file"
    )
    parser.add_argument(
        '--precmat_dir',
        required=False, default=".",
        description="Directory where to save the"
    )

    args = parser.parse_args()
    main(**__dict__)

