#!/usr/bin/env python3

from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
from __future__ import unicode_literals

import os.path
import sys
import warnings
from argparse import ArgumentParser
from os import remove

# time.clock was removed in Python 3.8; fall back to perf_counter so the
# timing code at the bottom of this script keeps working.
try:
    from time import clock
except ImportError:
    from time import perf_counter as clock

import astropy.constants
import numpy
from astropy.io import fits
from matplotlib import pyplot

from mangadap.dapfile import dapfile
from mangadap.drpfits import DRPFits
from mangadap.util.constants import constants
from mangadap.util.defaults import default_redux_path, default_analysis_path

#-----------------------------------------------------------------------------

def place_values_in_map(x, y, z, xs, dx, ys, dy, channel):
    """
    Place values onto a regular two-dimensional pixel grid, in place.

    The grid has the bottom corner of its first pixel at (xs, ys) and
    pixels of size dx by dy.  Each value in z is written to the pixel
    of channel that contains the matching (x, y) coordinate.  The x and
    y coordinates are expected to be sampled at the same rate as the
    grid!

    Args:
        x: Array of coordinates along the first axis of channel.
        y: Array of coordinates along the second axis of channel.
        z: Values to place; iterated in lockstep with x and y.
        xs: Coordinate of the lower edge of the first pixel along the
            first axis.
        dx: Pixel size along the first axis.
        ys: Coordinate of the lower edge of the first pixel along the
            second axis.
        dy: Pixel size along the second axis.
        channel: Two-dimensional array that is modified in place; it
            must already have the correct size.

    Returns:
        None.  Coordinates that fall outside the grid (or are not
        finite) are skipped and a message is printed for each of them.
    """
    nx = channel.shape[0]
    ny = channel.shape[1]

    # Fractional pixel coordinates rounded down to the containing pixel.
    # numpy.floor returns floats, which cannot be used as array indices,
    # so they are converted to int only after the range check below.
    i = numpy.floor((numpy.asarray(x) - xs) / dx)
    j = numpy.floor((numpy.asarray(y) - ys) / dy)

    for ii, jj, zz in zip(i, j, z):
        # The chained comparisons are False for NaN, so non-finite
        # coordinates are rejected here as well.
        if not (0 <= ii < nx and 0 <= jj < ny):
            print('Could not place value in map!')
            continue
        channel[int(ii), int(jj)] = zz


def binid_map(dapf, pixelscale=0.5):
    """
    Build a square map of bin IDs from the 'DRPS' extension of a DAP file.

    The side length of the map is the integer square root of the
    NAXIS2 keyword of the 'DRPS' header.  The XPOS/YPOS columns are
    shifted by the offset between the IFU and object coordinates read
    from the primary header (scaled by 3600; presumably degrees to
    arcseconds -- TODO confirm) before the BINID values are placed on
    the grid.

    Args:
        dapf: Object indexable by extension (0 and 'DRPS') whose items
            expose `header` and `data`.
        pixelscale: Size of a map pixel, in the units of XPOS/YPOS.

    Returns:
        The transposed (npix, npix) integer array of bin IDs; pixels
        that received no value hold -1.
    """
    primary_hdr = dapf[0].header
    drps = dapf['DRPS']

    side = int(numpy.sqrt(drps.header['NAXIS2']))

    # Offset of the IFU centre from the object centre.
    obj_dec = primary_hdr['OBJDEC']
    ra_shift = (primary_hdr['IFURA'] - primary_hdr['OBJRA']) \
                    * numpy.cos(numpy.radians(obj_dec)) * 3600.
    dec_shift = (primary_hdr['IFUDEC'] - obj_dec) * 3600.

    # Start with every pixel flagged as "no bin" and fill in the rest.
    grid = numpy.full((side, side), -1, dtype=int)
    origin = -pixelscale * (side + 1) / 2.

    table = drps.data
    sky_x = -table['XPOS'] + ra_shift
    sky_y = table['YPOS'] - dec_shift
    place_values_in_map(sky_x, sky_y, table['BINID'], origin, pixelscale, origin, pixelscale,
                        grid)

    return grid.transpose()


# Only works with the unbinned data
def dispersion_corrections(drpf_old, dapf_old, disp_hdu, avg=False, dispersion_factor=None):
    """
    Compute the change in the squared instrumental dispersion between
    the old DRP cube and a new estimate of the instrumental dispersion.

    The new instrumental dispersion is either the old one scaled by
    `dispersion_factor` or, when no factor is given, derived from the
    'DISP' extension of `disp_hdu`.  The returned quantities are the
    quadrature differences (new**2 - old**2); no square root is taken.

    Args:
        drpf_old: Old DRP cube object.  Must provide
            `_cube_dimensions()`, the `nx` attribute, and the 'SPECRES'
            and 'WAVE' extensions.
        dapf_old: Old DAP file object.  Must provide the 'SMSK',
            'ELOPAR' and 'ELOFIT' extensions, plus whatever
            `binid_map` reads.
        disp_hdu: HDU list with 'DISP' and 'MASK' extensions, or None
            when `dispersion_factor` is provided.
        avg: If True, return the mean over all spaxels with a valid bin
            ID instead of one value per spaxel.
        dispersion_factor: Multiplicative factor applied to the old
            instrumental dispersion.  If None, the dispersion cube is
            used instead.

    Returns:
        A two-element tuple: the correction at the fitted line centroid
        (emission line with index 9; presumably H-alpha -- TODO
        confirm), and the correction averaged over the unmasked
        wavelength channels of the stellar-continuum fit.  With
        avg=False the values are ordered by increasing bin ID.

    Raises:
        ValueError: If both `dispersion_factor` and `disp_hdu` are None.
    """
    if dispersion_factor is None and disp_hdu is None:
        raise ValueError('Must provide dispersion cube if not applying dispersion factor!')

    # Get the locations of the "binned" spectra
    old_binid = binid_map(dapf_old)
    indx = old_binid > -1

    # Get the size of the on-sky image ...
    drpf_old._cube_dimensions()
    nx = drpf_old.nx
    # ... and the number of wavelength channels
    nw = drpf_old['SPECRES'].data.size

    # Force the dispersion cube to have the same on-sky dimensions as
    # the old cube by trimming or zero-padding it about its centre.
    # Explicit end indices are used because a slice of the form nn:-nn
    # is empty when nn is 0.
    if dispersion_factor is None:
        nd = disp_hdu['DISP'].data.shape[1]
        if nd > nx:
            nn = (nd-nx)//2
            _disp_data = disp_hdu['DISP'].data[:,nn:nn+nx,nn:nn+nx].copy()
            _mask_data = disp_hdu['MASK'].data[:,nn:nn+nx,nn:nn+nx].copy()
        elif nd < nx:
            _disp_data = numpy.zeros((nw,nx,nx), dtype=float)
            _mask_data = numpy.zeros((nw,nx,nx), dtype=numpy.uint8)
            nn = (nx-nd)//2
            _disp_data[:,nn:nn+nd,nn:nn+nd] = disp_hdu['DISP'].data
            _mask_data[:,nn:nn+nd,nn:nn+nd] = disp_hdu['MASK'].data
        else:
            _disp_data = disp_hdu['DISP'].data.copy()
            _mask_data = disp_hdu['MASK'].data.copy()

    # Get the old instrumental dispersion
    cnst = constants()
    c_kms = astropy.constants.c.to('km/s').value
    old_sigma_inst = c_kms/drpf_old['SPECRES'].data/cnst.sig2fwhm

    # Replicate the per-channel vectors over the full (nw,nx,nx) cube
    old_sigma_inst = numpy.repeat(numpy.array([ old_sigma_inst ]), nx*nx,
                                  axis=0).reshape(nx,nx,nw).transpose(2,0,1)
    wave_vector = drpf_old['WAVE'].data
    wave = numpy.repeat(numpy.array([ wave_vector ]), nx*nx,
                        axis=0).reshape(nx,nx,nw).transpose(2,0,1)
    # Nothing is masked when only the nominal factor is applied
    cube_mask = numpy.zeros((nw,nx,nx), dtype=numpy.uint8) \
                    if dispersion_factor is not None else _mask_data.copy()

    # Get the new instrumental dispersion
    new_sigma_inst = old_sigma_inst*dispersion_factor if dispersion_factor is not None else \
                     c_kms * _disp_data / wave

    # Stars ------------------------------------------------------------
    # Match the dispersion mask to the stellar-continuum fitting mask
    i, j = numpy.meshgrid(numpy.arange(nx), numpy.arange(nx))
    mask = cube_mask.copy()
    mask[:,i,j] = mask[:,i,j] & (dapf_old['SMSK'].data[:,old_binid[i,j]].astype(numpy.uint8))

    # Determine the quadrature difference between the old and new
    # instrumental dispersion ...
    diff_sigma_inst = numpy.ma.masked_array( numpy.square(new_sigma_inst) \
                                             - numpy.square(old_sigma_inst), mask=mask)
    # ... and get the mean of the unmasked pixels
    stellar_diff_sigma_inst = numpy.ma.mean(diff_sigma_inst, axis=0)
    stellar_diff_sigma_inst.mask = numpy.invert(indx)

    # H-alpha ----------------------------------------------------------
    # Reset the mask back to what it was for the full dataset
    diff_sigma_inst.mask = cube_mask.copy()
    # Initialize the array and set the mask
    halpha_diff_sigma_inst = numpy.ma.zeros(stellar_diff_sigma_inst.shape, dtype=float)
    halpha_diff_sigma_inst.mask = numpy.invert(indx)
    # Get the wavelengths of the fitted line centroids
    line_centroids = dapf_old['ELOPAR'].data['RESTWAVE'][9]*(1. \
                + dapf_old['ELOFIT'].data['IKIN_EW'][:,0,9]/c_kms)

    # Don't interpolate for now, just grab the instrumental dispersion
    # in the nearest pixel; from here on line_centroids holds
    # wavelength-channel indices instead of wavelengths
    line_centroids = numpy.array([ numpy.argsort(numpy.fabs(wave_vector-l))[0] \
                                    for l in line_centroids ])

    halpha_diff_sigma_inst[j[indx],i[indx]] = diff_sigma_inst[
                                line_centroids[old_binid[j[indx],i[indx]]], j[indx], i[indx] ]

    # Return the average instrumental dispersion correction for all
    # binned spectra
    if avg:
        return numpy.mean(halpha_diff_sigma_inst[indx]), numpy.mean(stellar_diff_sigma_inst[indx])

    # Return the bin-by-bin instrumental dispersion correction for all
    # binned spectra, sorted by bin ID
    srt = numpy.argsort(old_binid[indx].ravel())
    return halpha_diff_sigma_inst[indx].ravel()[srt], stellar_diff_sigma_inst[indx].ravel()[srt]


#-----------------------------------------------------------------------------

def calculate_inst_disp_table(plate, ifudesign, planid, ofile, dispersion_factor, redux_path,
                              analysis_path, new_redux_path, disp_cube_name):
    """
    Compute the instrumental-dispersion corrections for one
    plate/ifudesign and write them to a FITS binary table.

    The output file has an empty primary HDU carrying the PLATE, IFU
    and PLANID keywords and a 'DISPCORR' table with six columns: the
    spaxel-by-spaxel ("full") corrections (SPXL_STR, SPXL_HAL), the
    nominal corrections from `dispersion_factor` (NOM_STR, NOM_HAL),
    and the mean of the full corrections (AVG_STR, AVG_HAL).  If the
    dispersion cube cannot be found, a warning is issued and the full
    and average columns are filled with zeros.

    Args:
        plate: Plate ID.
        ifudesign: IFU design.
        planid: DAP plan id.
        ofile: Name of the output FITS file; an existing file is
            overwritten.
        dispersion_factor: Factor applied to the old instrumental
            dispersion for the nominal correction.
        redux_path: Directory passed to DRPFits to find the old DRP
            CUBE file.
        analysis_path: Directory passed to dapfile to find the old DAP
            file.
        new_redux_path: Root directory that holds
            `<plate>/stack/manga-<plate>-<ifudesign>-LOGCUBE_<id>.fits.gz`.
        disp_cube_name: The `<id>` used in the dispersion cube file
            name.
    """
    print('     PLATE: {0}'.format(plate))
    print(' IFUDESIGN: {0}'.format(ifudesign))

    # Access the old DRP CUBE file
    print('Attempting to open old DRP CUBE file:')
    drpf_old = DRPFits(plate, ifudesign, 'CUBE', redux_path=redux_path, read=True)
    print('     FOUND: {0}'.format(drpf_old.file_path()))

    print('\nAttempting to open old DAP CUBE-NONE file:')
    dapf_old = dapfile(plate, ifudesign, 'CUBE', 'NONE', planid, analysis_path=analysis_path,
                       read=True)
    print('     FOUND: {0}'.format(dapf_old.file_path()))

    print('\nAttempting to open CUBE_DISP file:')
    disp_file = os.path.join(new_redux_path, str(plate), 'stack',
                             'manga-{0}-{1}-LOGCUBE_{2}.fits.gz'.format(plate, ifudesign,
                                                                        disp_cube_name))

    if not os.path.isfile(disp_file):
        warnings.warn('Cannot open {0}; providing only nominal corrections!'.format(disp_file))
        disp_hdu = None
    else:
        disp_hdu = fits.open(disp_file)
        print('     FOUND: {0}'.format(disp_file))

    try:
        print('Calculating nominal correction ...', end='\r')
        halpha_disp_corr_nom, stellar_disp_corr_nom \
                = dispersion_corrections(drpf_old, dapf_old, disp_hdu,
                                         dispersion_factor=dispersion_factor)
        print('Calculating nominal correction ... DONE')

        if disp_hdu is not None:
            print('Calculating full correction ...', end='\r')
            halpha_disp_corr_full, stellar_disp_corr_full \
                = dispersion_corrections(drpf_old, dapf_old, disp_hdu)
            print('Calculating full correction ... DONE')

            # The "average" columns repeat the mean full correction for
            # every row of the table
            halpha_disp_corr_avg = numpy.full(halpha_disp_corr_full.shape,
                                              numpy.mean(halpha_disp_corr_full), dtype=float)
            stellar_disp_corr_avg = numpy.full(stellar_disp_corr_full.shape,
                                               numpy.mean(stellar_disp_corr_full), dtype=float)
        else:
            # No dispersion cube: fill the full and average columns
            # with zeros
            halpha_disp_corr_full = numpy.zeros(halpha_disp_corr_nom.shape, dtype=float)
            stellar_disp_corr_full = numpy.zeros(stellar_disp_corr_nom.shape, dtype=float)

            halpha_disp_corr_avg = numpy.zeros(halpha_disp_corr_nom.shape, dtype=float)
            stellar_disp_corr_avg = numpy.zeros(stellar_disp_corr_nom.shape, dtype=float)
    finally:
        # Release the dispersion cube file handle once the corrections
        # (which hold their own copies of the data) are computed
        if disp_hdu is not None:
            disp_hdu.close()

    # Need to be more thoughtful with the header stuff here
    hdr = fits.Header()
    hdr['PLATE'] = (plate, 'Plate ID')
    hdr['IFU'] = (ifudesign, 'IFU designation')
    hdr['PLANID'] = (planid, 'DAP plan id')

    tblhdu = fits.BinTableHDU.from_columns([
                fits.Column(name='SPXL_STR', format='E', array=stellar_disp_corr_full),
                fits.Column(name='SPXL_HAL', format='E', array=halpha_disp_corr_full),
                fits.Column(name='NOM_STR', format='E', array=stellar_disp_corr_nom),
                fits.Column(name='NOM_HAL', format='E', array=halpha_disp_corr_nom),
                fits.Column(name='AVG_STR', format='E', array=stellar_disp_corr_avg),
                fits.Column(name='AVG_HAL', format='E', array=halpha_disp_corr_avg)
                                        ], name='DISPCORR')
    # Write the file; `overwrite` replaces the removed `clobber` keyword
    fits.HDUList([ fits.PrimaryHDU(header=hdr), tblhdu ]).writeto(ofile, overwrite=True)



#-----------------------------------------------------------------------------

if __name__ == '__main__':
    # Command-line entry point: parse the arguments, resolve and check
    # the input directories, build the correction table, and report
    # the elapsed time.
    t = clock()

    parser = ArgumentParser()

    parser.add_argument('plate', type=int, help='plate ID to process')
    parser.add_argument('ifudesign', type=int, help='IFU design to process')
    parser.add_argument('planid', type=int, help='DAP plan id')
    parser.add_argument('output_file', type=str, help='Name for output file')

    parser.add_argument('-f', '--dispersion_factor', type=float,
                        help='Factor to use for the nominal correction to the old dispersion '
                             'measurements; default is 1.1', default=1.1)

    parser.add_argument('-d', '--disp_cube_name', type=str,
                        help='Identifier of dispersion cube to use in file name - '
                             'manga-{plt}-{ifu}-LOGCUBE_{id}.fits.gz; default is {id}=DISP',
                        default='DISP')

    parser.add_argument('-r', '--redux_path', type=str,
                        help='Directory with the old DRP produced CUBE file; default uses '
                             'environmental variables to define the default MaNGA DRP redux path',
                        default=None)

    parser.add_argument('-a', '--analysis_path', type=str,
                        help='Directory with the DAP produced fits to the old DRP CUBE file; '
                             'default uses environmental variables to define the default MaNGA '
                             'DAP analysis path', default=None)

    parser.add_argument('-n', '--new_redux_path', type=str,
                        help='Directory with the new DRP produced RSS files and the DISP cube '
                             'file; default is $HOME/data/MaNGA/instdisp/v1_5_4_rss',
                        default=None)

    arg = parser.parse_args()

    # Fall back to the package defaults for any path not given, and
    # fail early if a directory does not exist
    redux_path = default_redux_path() if arg.redux_path is None else arg.redux_path
    if not os.path.isdir(redux_path):
        raise FileNotFoundError('No directory: {0}'.format(redux_path))

    analysis_path = default_analysis_path() if arg.analysis_path is None else arg.analysis_path
    if not os.path.isdir(analysis_path):
        raise FileNotFoundError('No directory: {0}'.format(analysis_path))

    # `environ` is not imported on its own in this file; reach it
    # through the already-imported os module
    new_redux_path = os.path.join(os.environ['HOME'], 'data', 'MaNGA','instdisp','v1_5_4_rss') \
                     if arg.new_redux_path is None else arg.new_redux_path
    if not os.path.isdir(new_redux_path):
        raise FileNotFoundError('No directory: {0}'.format(new_redux_path))

    if os.path.isfile(arg.output_file):
        print('WARNING: Overwriting existing file {0}!'.format(arg.output_file))

    calculate_inst_disp_table(arg.plate, arg.ifudesign, arg.planid, arg.output_file,
                              arg.dispersion_factor, redux_path, analysis_path, new_redux_path,
                              arg.disp_cube_name)

    print('Elapsed time: {0} seconds'.format(clock() - t))



