#!/usr/bin/env python3

import argparse
import sys
from os.path import abspath, expanduser
import os
import warnings
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
sys.path.insert(0, os.path.abspath('..'))
from specdal.collection import Collection, proximal_join

# Command-line interface of the pipeline.
#
# NOTE: single-valued options deliberately do not use ``nargs=1``: argparse
# would then store a one-element list when the option is given on the command
# line, while the default stays a scalar, so consumers would receive
# inconsistent types.
parser = argparse.ArgumentParser(description='SpecDAL Pipeline')
# io options
parser.add_argument('-i', '--input_dir', metavar='PATH', default='./',
                    action='store',
                    help='input directory containing input files')
parser.add_argument('-o', '--output_dir', metavar='PATH',
                    default='./specdal_output', action='store',
                    help='output directory to store the results')
parser.add_argument('-of', '--output_figures', action='store_true',
                    help='output figures')
parser.add_argument('-od', '--output_data', action='store_true',
                    help='output data')
parser.add_argument('-oi', '--output_individual', action='store_true',
                    help='output individual spectrum')
parser.add_argument('-oa', '--output_aggregates', action='store_true',
                    help='output group aggregates')
# optional arguments
parser.add_argument('-n', '--name', type=str, action='store',
                    default='dataset', help='name of the dataset')
# resampler
parser.add_argument('-r', '--resample', default=None,
                    choices=['slinear', 'cubic'],
                    help='interpolation method')
parser.add_argument('-rs', '--resample_spacing', metavar='SPC',
                    type=int, default=1,
                    help='spacing for resampling (in nm)')
# overlap stitcher
parser.add_argument('-s', '--stitch', default=None,
                    choices=['mean', 'median', 'min', 'max'],
                    help='overlap stitching method')
# jump corrector
parser.add_argument('-j', '--jump_correct', default=None,
                    choices=['additive'],
                    help='jump correction method')
parser.add_argument('-js', '--jump_correct_splices', metavar='WVL',
                    default=[1000, 1800], type=int, nargs='+',
                    help='wavelengths of jump locations')
parser.add_argument('-jr', '--jump_correct_reference', metavar='REF',
                    type=int, default=0,
                    help='position of the reference detector')
# groupby
parser.add_argument('-g', '--group_by', action='store_true',
                    help='create groups using filenames')
parser.add_argument('-gs', '--group_by_separator',
                    metavar='S', default='_',
                    help='separator sequence to split the file names')
parser.add_argument('-gi', '--group_by_indices', metavar='I', nargs='*', type=int,
                    help='indices of the split filenames to define a group')
# The three aggregate flags share dest='aggr'; the default set on the first
# one ([]) is the value seen when none of them is given.
parser.add_argument('--group_mean', dest='aggr', action='append_const',
                    default=[],
                    const='mean', help='calculate group means')
parser.add_argument('--group_median', dest='aggr', action='append_const',
                    const='median', help='calculate group median')
parser.add_argument('--group_std', dest='aggr', action='append_const',
                    const='std', help='calculate group standard deviation')
parser.add_argument('--group_append', dest='aggr_append', nargs='+', default=[],
                    help='append aggregates to group figures')
# debugging
parser.add_argument('-d', '--debug', action='store_true')
# misc
parser.add_argument('-q', '--quiet', default=False, action='store_true')
parser.add_argument('--proximal', default=None, metavar='PATH', action='store',
                    help='path containing base measurement spectral files')

args = parser.parse_args()
if args.debug:
    # Echo the parsed namespace so the effective options are visible.
    print('args = {}'.format(args))

################################################################################
# main
################################################################################
# Progress messages are emitted unless -q/--quiet was given.
VERBOSE = not args.quiet


def print_if_verbose(*values, **print_options):
    """Forward all arguments to print() unless VERBOSE is false."""
    if not VERBOSE:
        return
    print(*values, **print_options)
    
# Resolve the working paths once; '~' is expanded and relative paths are made
# absolute.
indir = abspath(expanduser(args.input_dir))
outdir = abspath(expanduser(args.output_dir))
datadir = os.path.join(outdir, 'data')
figdir = os.path.join(outdir, 'figures')

# Validate user input explicitly instead of with `assert`, which is stripped
# when Python runs with -O and only yields a bare AssertionError otherwise.
if not os.path.isdir(indir):
    sys.exit('error: input directory does not exist: ' + indir)
if args.debug:
    # Debug convenience: clear the default output location so repeated runs
    # do not fail on the "already exists" check below.
    # NOTE: this always targets ./specdal_output/ relative to the current
    # working directory, independent of --output_dir.
    import shutil
    if os.path.isdir('./specdal_output/'):
        shutil.rmtree('./specdal_output/')
# Never write into an existing directory, to avoid clobbering earlier results.
if os.path.exists(outdir):
    sys.exit('error: output directory already exists: ' + outdir)
# make output directories
for d in (outdir, datadir, figdir):
    os.makedirs(d)

# Load the target measurements.
c = Collection(name=args.name)
print_if_verbose('Reading target measurements from ' + indir)
c.read(directory=indir)

# Optionally load the base measurements used for the proximal join.
c_base = None
if args.proximal:
    # Normalize the path the same way as the input/output directories.
    proximal_dir = abspath(expanduser(args.proximal))
    print_if_verbose('Reading base measurements from ' + proximal_dir)
    c_base = Collection(name=args.name + '_base')
    c_base.read(directory=proximal_dir)

# Every processing step is applied to the target collection first and then,
# when present, to the base collection.
to_process = [c] if c_base is None else [c, c_base]

if args.resample:
    print_if_verbose('Resampling...')
    for coll in to_process:
        coll.resample(spacing=args.resample_spacing, method=args.resample)
if args.stitch:
    print_if_verbose('Stitching...')
    for coll in to_process:
        coll.stitch(method=args.stitch)
if args.jump_correct:
    print_if_verbose('Jump correcting...')
    for coll in to_process:
        coll.jump_correct(splices=args.jump_correct_splices,
                          reference=args.jump_correct_reference,
                          method=args.jump_correct)

if c_base is not None:
    print_if_verbose('Joining proximal data...')
    # From here on `c` refers to the joined collection.
    c = proximal_join(c_base, c, on='gps_time_tgt', direction='nearest')

# group by
# `groups` stays None when grouping was not requested; later stages test it
# for truthiness before iterating.
if args.group_by:
    print_if_verbose('Grouping...')
    groups = c.groupby(
        separator=args.group_by_separator,
        indices=args.group_by_indices,
    )
else:
    groups = None

# output individual spectra
if args.output_individual:
    print_if_verbose('Saving individual spectrum outputs...')
    # One CSV and one PNG per spectrum, each under its own 'indiv' folder.
    indiv_datadir = os.path.join(datadir, 'indiv')
    indiv_figdir = os.path.join(figdir, 'indiv')
    for new_dir in (indiv_datadir, indiv_figdir):
        os.mkdir(new_dir)
    for spectrum in c.spectra:
        csv_path = os.path.join(indiv_datadir, spectrum.name + '.csv')
        spectrum.to_csv(csv_path)
        spectrum.plot(legend=False)
        png_path = os.path.join(indiv_figdir, spectrum.name + '.png')
        plt.savefig(png_path, bbox_inches='tight')
        # Close the figure after saving so figures do not accumulate.
        plt.close()

# output whole and group data
if args.output_data:
    print_if_verbose('Saving data files...')
    whole_csv = os.path.join(datadir, c.name + ".csv")
    c.to_csv(whole_csv)
    # One additional CSV per group, named after the group id (skipped when
    # there are no groups).
    for group_id, group_coll in (groups or {}).items():
        group_csv = os.path.join(datadir, group_id + '.csv')
        group_coll.to_csv(group_csv)

# calculate group aggregates
if groups:
    if args.aggr:
        print_if_verbose('Calculating group aggregates...')
    for aggr in args.aggr:
        # Aggregates named in --group_append are requested with append=True.
        append = aggr in args.aggr_append
        aggr_name = c.name + '_' + aggr
        # Call the aggregate method named by `aggr` on each group; this runs
        # whether or not the result is written out below.
        aggr_spectra = []
        for member in groups.values():
            aggregate = getattr(member, aggr)
            aggr_spectra.append(aggregate(append=append))
        aggr_coll = Collection(name=aggr_name,
                               spectra=aggr_spectra,
                               measure_type=c.measure_type)
        # output
        if not args.output_aggregates:
            continue
        print_if_verbose('Saving group {} outputs...'.format(aggr))
        aggr_coll.to_csv(os.path.join(datadir, aggr_coll.name + '.csv'))
        aggr_coll.plot(legend=False)
        aggr_png = os.path.join(figdir, aggr_coll.name + '.png')
        plt.savefig(aggr_png, bbox_inches='tight')
        plt.close()

# output whole and group figures (possibly with aggregates appended)
if args.output_figures:
    print_if_verbose('Saving entire and grouped figure outputs...')
    # The whole collection is plotted first, followed by every group.
    figure_jobs = [(c.name, c)]
    if groups:
        figure_jobs.extend(groups.items())
    for figure_name, figure_coll in figure_jobs:
        figure_coll.plot(legend=False)
        figure_path = os.path.join(figdir, figure_name + ".png")
        plt.savefig(figure_path, bbox_inches="tight")
        plt.close()
