#!/usr/bin/env python3

from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from __future__ import unicode_literals

import numpy
from os import remove
import os.path
try:
    from time import clock
except ImportError:
    # time.clock() was removed in Python 3.8; bind the name to the
    # closest available replacement so the module still imports.
    from time import perf_counter as clock
import sys

from argparse import ArgumentParser

from mangadap.drpfits import DRPFits

from matplotlib import pyplot
from astropy.io import fits

#-----------------------------------------------------------------------------

def calculate_sigma_inst_cube(plate, ifudesign, ofile, wave=None, directory_path=None,
                              dispersion_factor=None):
    """
    Write the instrumental dispersion of a DRP RSS observation to a FITS file.

    The DRP RSS file for the given plate and IFU design is opened with
    :class:`DRPFits`, and its ``instrumental_dispersion_plane`` method is
    used to build either a single plane (when *wave* is given) or one
    plane per wavelength channel (when *wave* is None).  Pixels whose
    dispersion is not strictly positive (including NaN) are flagged with
    1 in the accompanying ``MASK`` extension; all other pixels are 0.

    Output layout:

        - *wave* provided: primary header with ``SRC``, ``CHAN`` and
          ``WAVE`` keywords, followed by ``DISP`` and ``MASK`` image
          extensions for the nearest wavelength channel.
        - *wave* is None: primary header with ``SRC``, followed by
          ``WAVE``, ``DISP``, ``MASK`` and ``SPECRES`` extensions, where
          ``DISP`` and ``MASK`` have one plane per wavelength channel.

    Any existing file at *ofile* is overwritten.

    Args:
        plate (int): Plate ID passed to :class:`DRPFits`.
        ifudesign (int): IFU design passed to :class:`DRPFits`.
        ofile (str): Path of the FITS file to write.
        wave (float, optional): Wavelength at which to compute a single
            dispersion plane; the channel whose ``WAVE`` value is
            closest is used.  The progress message reports it in
            angstroms.  If None, all channels are computed.
        directory_path (str, optional): Passed to :class:`DRPFits` as
            the directory holding the RSS file.
        dispersion_factor (float, optional): Passed unchanged to
            ``DRPFits.instrumental_dispersion_plane``.
    """
    print('     PLATE: {0}'.format(plate))
    print(' IFUDESIGN: {0}'.format(ifudesign))

    # Access the DRP RSS file
    print('Attempting to open RSS file:')
    drpf = DRPFits(plate, ifudesign, 'RSS', read=True, directory_path=directory_path)
    print('     FOUND: {0}'.format(drpf.file_path()))

    wavelengths = drpf['WAVE'].data

    # Need to be more thoughtful with the header stuff here
    hdr = fits.Header()
    hdr['SRC'] = (drpf.file_path(), 'Source DRP file')

    if wave is not None:
        # Only the single closest channel is needed, so a full sort is
        # unnecessary.
        channel = int(numpy.argmin(numpy.absolute(wavelengths - wave)))
        print('Nearest wavelength channel has wavelength {0:.1f} ang.'.format(
                                                                    wavelengths[channel]))
        sig_inst = numpy.transpose(drpf.instrumental_dispersion_plane(channel,
                                   dispersion_factor=dispersion_factor))
        # Flag anything that is not strictly positive (NaN included).
        mask = numpy.logical_not(sig_inst > 0).astype(numpy.uint8)

        hdr['CHAN'] = (channel, 'Index of wavelength channel used')
        hdr['WAVE'] = (wavelengths[channel], 'Wavelength at this channel')

        hdus = [ fits.PrimaryHDU(header=hdr),
                 fits.ImageHDU(data=sig_inst, name='DISP'),
                 fits.ImageHDU(data=mask, name='MASK') ]
    else:
        nw = wavelengths.size                           # Number of wavelength channels
        # NOTE(review): private DRPFits method; presumably it sets the
        # drpf.nx and drpf.ny attributes read just below -- confirm
        # against the DRPFits implementation.
        drpf._cube_dimensions()
        siginst_cube = numpy.zeros((nw,drpf.ny,drpf.nx), dtype=numpy.float64)

        for i in range(nw):
            print('Constructing wavelength channel: {0}/{1}'.format(i+1,nw), end='\r')
            siginst_cube[i,:,:] = numpy.transpose(
                    drpf.instrumental_dispersion_plane(i, dispersion_factor=dispersion_factor,
                                                       quiet=True))
        print('Constructing wavelength channel: DONE                 ')

        # Build the mask for the whole cube at once instead of
        # plane-by-plane inside the loop; the result is identical.
        siginst_mask = numpy.logical_not(siginst_cube > 0).astype(numpy.uint8)

        hdus = [ fits.PrimaryHDU(header=hdr),
                 fits.ImageHDU(data=wavelengths, name='WAVE'),
                 fits.ImageHDU(data=siginst_cube, name='DISP'),
                 fits.ImageHDU(data=siginst_mask, name='MASK'),
                 fits.ImageHDU(data=drpf['SPECRES'].data, name='SPECRES') ]

    # Write the file.  ``overwrite`` is the supported spelling of the
    # deprecated/removed ``clobber`` keyword in astropy.io.fits.
    fits.HDUList(hdus).writeto(ofile, overwrite=True)


#-----------------------------------------------------------------------------

if __name__ == '__main__':
    # Command-line driver: parse the arguments, build the instrumental
    # dispersion file, and report the elapsed time.

    # time.clock() was removed in Python 3.8; timeit.default_timer is a
    # wall-clock timer available on both Python 2 and Python 3.
    from timeit import default_timer

    start = default_timer()

    parser = ArgumentParser()

    parser.add_argument('plate', type=int, help='plate ID to process')
    parser.add_argument('ifudesign', type=int, help='IFU design to process')
    parser.add_argument('output_file', type=str, help='Name for output file')

    parser.add_argument('-w', '--wavelength', type=float,
                        help='Wavelength for single instrumental dispersion plane', default=None)

    parser.add_argument('-d', '--directory_path', type=str,
                        help='Directory with the DRP produced RSS file; default uses environmental '
                             'variables to define the default MaNGA DRP redux path', default=None)

    parser.add_argument('-f', '--dispersion_factor', type=float,
                        help='Artificially *increase* the instrumental dispersion (a *decrease* '
                             'in the spectral resolution) by this factor before reconstructing '
                             'the dispersion cube.', default=None)

    args = parser.parse_args()

    # The output is always overwritten; warn so that it is not a surprise.
    if os.path.isfile(args.output_file):
        print('WARNING: Overwriting existing file {0}!'.format(args.output_file))

    calculate_sigma_inst_cube(args.plate, args.ifudesign, args.output_file,
                              wave=args.wavelength,
                              directory_path=args.directory_path,
                              dispersion_factor=args.dispersion_factor)

    print('Elapsed time: {0} seconds'.format(default_timer() - start))



