#!/usr/bin/env python3

"""
fringez-generate :

Generates the fringe model for all science images (and fringe images) in
the directory. The model is saved to disk.
"""
import argparse
from fringez.model import generate_models
from fringez.model import test_models
from fringez.fringe import gather_flat_fringe_maps


def _build_parser():
    """Assemble the command-line parser used by ``main``.

    The parser exposes the model options (--n-components, --n-samples,
    --fringe-model-name), the plotting switches (--plots / --plots-off and
    --plot-idx) and the execution-mode switches (--single / --parallel).
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Options controlling the model itself.
    model_opts = cli.add_argument_group('arguments')
    model_opts.add_argument(
        '--n-components',
        type=int,
        default=6,
        help='Number of components in the PCA model.',
    )
    n_samples_help = (
        'Number of samples for PCA training. '
        'If set to None the number of samples is equal to the '
        'number of images. If set to a number less than the '
        'number of images, images are collected into median '
        'stacks of approximately equal size and photometric depth '
        'prior to training.'
    )
    model_opts.add_argument(
        '--n-samples',
        type=int,
        default=None,
        help=n_samples_help,
    )
    model_opts.add_argument(
        '--fringe-model-name',
        type=str,
        default=None,
        help=('If selected, forces the generated fringe '
              'models to include this name'),
    )

    # --plots and --plots-off write to the same destination, so they are
    # declared mutually exclusive.
    plot_section = cli.add_argument_group('plotting')
    plot_switch = plot_section.add_mutually_exclusive_group()
    plot_switch.add_argument(
        '--plots',
        dest='plotFlag',
        action='store_true',
        help='Turn ON plotting for debugging.',
    )
    plot_switch.add_argument(
        '--plots-off',
        dest='plotFlag',
        action='store_false',
        help='Turn OFF plotting for debugging. DEFAULT.',
    )
    cli.set_defaults(plotFlag=False)

    # Extra option that is only consulted when plotting is enabled.
    plot_extras = cli.add_argument_group('arguments for --plots')
    plot_extras.add_argument(
        '--plot-idx',
        type=int,
        default=None,
        help=('If selected, forces the example debug plot '
              'to a specific image in the folder'),
    )

    # --single and --parallel likewise share one destination.
    mode_section = cli.add_argument_group('parallelization')
    mode_switch = mode_section.add_mutually_exclusive_group()
    mode_switch.add_argument(
        '--single',
        dest='parallelFlag',
        action='store_false',
        help='Run in single mode. DEFAULT.',
    )
    mode_switch.add_argument(
        '--parallel',
        dest='parallelFlag',
        action='store_true',
        help=('Run in parallel mode. This is '
              'recommended when selecting '
              '--n-samples less than the number of images. '
              'Requires mpi4py.'),
    )
    cli.set_defaults(parallelFlag=False)

    return cli


def main():
    """Command-line entry point: generate the fringe model.

    Parses the command-line options, gathers the flattened fringe maps
    (across MPI ranks when --parallel is given) and, on rank 0 only,
    generates the models. When plotting is enabled the models are also
    applied to an example image for debugging.
    """
    args = _build_parser().parse_args()

    # mpi4py is imported only for parallel runs, so single-process runs do
    # not need it installed.
    rank = 0
    if args.parallelFlag:
        from mpi4py import MPI
        rank = MPI.COMM_WORLD.Get_rank()

    is_root = rank == 0
    if is_root:
        print('Generating fringez model')

    # Every rank takes part in gathering the fringe maps.
    gathered = gather_flat_fringe_maps(args.n_samples, args.parallelFlag)
    fname_arr, fringe_maps_flattened, image_shape, rcid = gathered

    # Only rank 0 goes on to build (and optionally test) the models.
    if not is_root:
        return

    # Keyword options shared by model generation and the debug test.
    shared_opts = {
        'fringe_model_name': args.fringe_model_name,
        'n_components': args.n_components,
    }

    generate_models(fname_arr,
                    fringe_maps_flattened,
                    image_shape,
                    rcid,
                    **shared_opts,
                    plotFlag=args.plotFlag)

    if not args.plotFlag:
        return

    # Apply the fringe model to an example image for debugging; --plot-idx
    # selects a specific image when given.
    test_models(fringe_maps_flattened,
                image_shape,
                rcid,
                **shared_opts,
                idx=args.plot_idx)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
