#!/usr/bin/env python

import argparse
from multiprocessing import Process

import numpy as np

import trilearn.distributions.sequential_junction_tree_distributions as seqdist
from trilearn.pgibbs import trajectory_to_file

# Command-line interface. The options declared with nargs='+' accept one or
# more values; the script later runs one job for every combination of them.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-T', '--trajectory_length',
    type=int, required=True, nargs='+',
    help="Number of Gibbs samples"
)
parser.add_argument(
    '-N', '--n_particles',
    type=int, required=True, nargs='+',
    help="Number of SMC particles"
)
parser.add_argument(
    '-a', '--alpha', default=[0.5],
    type=float, required=False, nargs='+',
    help="Parameter for the junction tree expander"
)
parser.add_argument(
    '-b', '--beta', default=[0.5],
    type=float, required=False, nargs='+',
    help="Parameter for the junction tree expander"
)
parser.add_argument(
    # The default is a one-element list so that the option can be iterated
    # over like the other multi-valued options even when it is omitted.
    '-r', '--radius', default=[None],
    type=int, required=False, nargs='+',
    # The trailing space keeps the two concatenated fragments from running
    # together as "neighborhoodradius" in the rendered help.
    help="The search neighborhood " +
         "radius for the Gibbs sampler"
)
parser.add_argument(
    '-p', '--order',
    type=int, required=True,
    help="The order of the underlying decomposable graph"
)
parser.add_argument(
    '-s', '--seed', default=None,
    type=int, required=False,
    help="Random seed"
)
parser.add_argument(
    '-o', '--output_directory',
    required=False, default='.',
    help="Output directory"
)


args = parser.parse_args()

np.random.seed(args.seed)
p = args.order
filename_prefix = "uniform_jt_samples_p_"+str(p)
# Path prefix handed to every worker; identical for all parameter settings.
output_prefix = args.output_directory + "/" + filename_prefix

# NOTE(review): the global NumPy RNG is seeded once, here in the parent.
# With a fork-based start method each child inherits that same RNG state, so
# the workers may draw identical random streams -- confirm whether
# trajectory_to_file reseeds before relying on independent runs.

# Launch one worker process per combination of (N, T, radius, alpha, beta).
# Each started process is stored together with the settings it was launched
# with, so that every one of them can be joined and reported afterwards.
running = []
for N in args.n_particles:
    for T in args.trajectory_length:
        for radius in args.radius:
            for alpha in args.alpha:
                for beta in args.beta:
                    sd = seqdist.UniformJTDistribution(p)
                    worker_args = (N, alpha, beta, radius,
                                   T, sd, output_prefix)
                    # print(<single string>) emits the same text under
                    # Python 2's print statement and Python 3's function.
                    print("Starting: " + str(worker_args))

                    proc = Process(target=trajectory_to_file,
                                   args=worker_args)
                    proc.start()
                    running.append(((N, alpha, beta, radius,
                                     T, filename_prefix), proc))

# Wait for every worker in launch order. Joining each stored process (rather
# than only the last one started) ensures "Completed" is printed only after
# that particular run has actually finished.
for settings, proc in running:
    proc.join()
    print("Completed: " + str(settings))
