#!/usr/bin/env python

import numpy as np
import pandas as pd

import trilearn.pgibbs


def main(trajectory_length, n_particles, alphas, betas, radii, seed, parallel,
         data_filename, output_directory, reset_cache, reps, **args):
    """Read a dataset from CSV and run a particle Gibbs trajectory sampler on it.

    The numpy global random generator is seeded when ``seed`` is not None.
    The file ``data_filename`` is parsed as comma-separated values without a
    header row, and the resulting data frame is passed to one of two
    ``trilearn.pgibbs`` entry points:

    * ``sample_trajectories_ggm_parallel`` when ``parallel`` is exactly
      ``True``;
    * ``sample_trajectories_ggm_to_file`` otherwise.

    ``n_particles``, ``trajectory_length``, ``alphas``, ``betas``, ``radii``,
    ``reps``, ``reset_cache`` and ``output_directory`` are forwarded unchanged
    to the selected sampler; their exact meaning is defined there.
    Any additional keyword arguments collected in ``args`` are ignored.
    """
    if seed is not None:
        np.random.seed(seed)

    data = pd.read_csv(data_filename, sep=',', header=None)

    # Identity comparison on purpose: only the literal ``True`` selects the
    # parallel sampler, every other value falls through to the file sampler.
    if parallel is True:
        sampler = trilearn.pgibbs.sample_trajectories_ggm_parallel
    else:
        sampler = trilearn.pgibbs.sample_trajectories_ggm_to_file

    # Both samplers are invoked with the same argument layout.
    sampler_options = {
        "alphas": alphas,
        "betas": betas,
        "radii": radii,
        "reps": reps,
        "reset_cache": reset_cache,
        "output_directory": output_directory,
    }
    sampler(data, n_particles, trajectory_length, **sampler_options)

def build_parser():
    """Build the command-line parser for this script.

    Returns an ``argparse.ArgumentParser`` whose parsed namespace has exactly
    the attributes ``main`` takes as keyword arguments.

    Notes:
        * The descriptive sentence is passed as ``description``. Passing it
          positionally would set ``prog`` and garble the usage line.
        * ``--reset_cache`` is kept for backward compatibility, but since the
          value defaults to True it is ``--no_reset_cache`` that actually
          changes behaviour.
    """
    # Imported locally so that merely importing this module does not require
    # argparse to be loaded.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate particle Gibbs trajectories of decomposable graphs."
    )

    parser.add_argument(
        '-M', '--trajectory_length',
        type=int, required=True, nargs='+',
        help="Number of Gibbs samples"
    )
    parser.add_argument(
        '-f', '--data_filename',
        required=True,
        help="Filename of dataset stored as row vectors of floats. "
    )
    parser.add_argument(
        '-N', '--n_particles',
        type=int, required=True, nargs='+',
        help="Number of SMC particles"
    )
    parser.add_argument(
        '-a', '--alphas',
        type=float, required=False, default=[0.5], nargs='+',
        help="Parameter for the Christmas tree algorithm"
    )
    parser.add_argument(
        '-b', '--betas',
        type=float, required=False, default=[0.5], nargs='+',
        help="Parameter for the Christmas tree algorithm"
    )
    parser.add_argument(
        '-r', '--radii',
        type=int, required=False, default=[None], nargs='+',
        help="The search neighborhood radius for the Gibbs sampler"
    )
    parser.add_argument(
        '-s', '--seed',
        type=int, required=False, default=None
    )
    parser.add_argument(
        '--reps',
        type=int, required=False, default=1,
        help='Number of trajectories to sample for each parameter setting'
    )
    parser.add_argument(
        '--parallel',
        required=False, action="store_true", default=False
    )
    parser.add_argument(
        '-o', '--output_directory',
        required=False, default=".",
        help="Output directory"
    )
    # The cache is reset by default; --reset_cache is therefore a no-op that
    # is only retained so existing command lines keep working.
    parser.add_argument(
        '--reset_cache',
        dest="reset_cache", required=False, default=True, action="store_true",
        help="Reset the cache in each iteration (this is the default)"
    )
    # Writes False to the same destination, making the option switchable.
    parser.add_argument(
        '--no_reset_cache',
        dest="reset_cache", required=False, default=True, action="store_false",
        help="Do not reset the cache in each iteration"
    )

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(**vars(args))