#!/usr/bin/env python3

import os
import time
import numpy

from astropy.io import fits

from mangadap.util.filter import interpolate_masked_vector
from mangadap.util.constants import DAPConstants
from mangadap.util.fileio import compress_file

#-----------------------------------------------------------------------------

def compile_test_spectra(ofile):
    """
    Collect a fixed set of MaNGA spaxel spectra into a single FITS file.

    For each hard-coded plate-ifu, the redshift is read from the
    ``NSA_Z`` column of the DRPall file and the spectra at the
    hard-coded spaxel coordinates are extracted from the ``FLUX``,
    ``IVAR``, ``MASK``, and ``PREDISP`` extensions of the LOGCUBE file.
    The result is written with extensions ``WAVE``, ``FLUX``, ``IVAR``,
    ``MASK``, ``SRES``, and ``Z``, with one row per spectrum in the 2D
    extensions.

    The input files are located using the ``MANGA_SPECTRO_REDUX`` and
    ``MANGADRP_VER`` environment variables.

    Args:
        ofile (str): Name of the uncompressed FITS file to write.  After
            writing, the file is passed to
            :func:`mangadap.util.fileio.compress_file` and the
            uncompressed file is removed.

    Raises:
        KeyError: Raised if either environment variable is not defined.
        IndexError: Raised if a plate-ifu is not found in the DRPall
            file.
    """
    pltifu = ['8131-6102', '8131-3702', '8258-6102', '8329-1901', '8138-12704',
              '7815-6101', '9494-3703']
    # The coordinate lists are ragged (one cube contributes two spaxels), so
    # they are kept as plain nested lists: numpy.array() on ragged input is
    # deprecated and raises a ValueError for numpy>=1.24.
    # NOTE(review): there are 8 coordinate entries but only 7 plate-ifus; the
    # last entry is never used (as was true when iterating over pltifu by
    # index).  Confirm whether a plate-ifu is missing from the list above.
    x = [[27], [22], [27], [19], [37], [37, 37], [26], [22]]
    y = [[27], [22], [27], [14], [37], [46, 37], [26], [22]]

    drpver = os.environ['MANGADRP_VER']
    root = os.path.join(os.environ['MANGA_SPECTRO_REDUX'], drpver)
    drpall_file = os.path.join(root, 'drpall-{0}.fits'.format(drpver))

    # Per-cube pieces; concatenated once after the loop instead of
    # re-copying the accumulated arrays on every iteration.
    flux = []
    ivar = []
    mask = []
    sres = []
    z = []
    wave = None

    # Context managers ensure the file handles are closed; all data needed
    # from each file are copied out (fancy indexing / tile) before it closes.
    with fits.open(drpall_file) as drpall_hdu:
        drpall = drpall_hdu[1].data

        for pi, xi, yi in zip(pltifu, x, y):
            print(pi)

            # Get the redshift
            indx = drpall['PLATEIFU'] == pi
            if not numpy.any(indx):
                # Same exception type as the bare [0] lookup would raise,
                # but with a message that identifies the problem.
                raise IndexError('{0} not found in {1}'.format(pi, drpall_file))
            _z = drpall['NSA_Z'][indx][0]

            # Get the relevant spectra
            cube_file = os.path.join(root, pi.split('-')[0], 'stack',
                                     'manga-{0}-LOGCUBE.fits.gz'.format(pi))
            with fits.open(cube_file) as hdu:
                # Indexing with the coordinate lists yields arrays with
                # shape (number of wavelength channels, number of spaxels)
                _flux = hdu['FLUX'].data[:,yi,xi]
                _ivar = hdu['IVAR'].data[:,yi,xi]
                _mask = hdu['MASK'].data[:,yi,xi]
                nspec = _flux.shape[1]
                # Repeat the wavelength vector so it matches the shape of
                # the extracted spectra
                _wave = numpy.tile(hdu['wave'].data, (nspec,1)).T
                # Inverse of (sig2fwhm * PREDISP / wave); the masked power
                # masks invalid results (e.g., where PREDISP is zero)
                _sres = numpy.ma.power(DAPConstants.sig2fwhm * hdu['PREDISP'].data[:,yi,xi]
                                        / _wave, -1)
            # Fill the masked values, one spectrum (column) at a time
            _sres = numpy.apply_along_axis(interpolate_masked_vector, 0, _sres)

            if wave is None:
                # Only the wavelength vector of the first cube is written
                wave = _wave[:,0].copy()
            flux.append(_flux)
            ivar.append(_ivar)
            mask.append(_mask)
            sres.append(_sres)
            z.append(numpy.full(nspec, _z))

    # Stack the spectra from all cubes along the spectrum axis
    flux = numpy.concatenate(flux, axis=1)
    ivar = numpy.concatenate(ivar, axis=1)
    mask = numpy.concatenate(mask, axis=1)
    sres = numpy.concatenate(sres, axis=1)
    z = numpy.concatenate(z)

    hdr = fits.Header()
    hdr['DRPVER'] = (drpver, 'DRP Version')
    fits.HDUList([ fits.PrimaryHDU(header=hdr), fits.ImageHDU(wave, name='WAVE'),
                   fits.ImageHDU(flux.T, name='FLUX'), fits.ImageHDU(ivar.T, name='IVAR'),
                   fits.ImageHDU(mask.T, name='MASK'), fits.ImageHDU(sres.T, name='SRES'),
                   fits.ImageHDU(z, name='Z')
                 ]).writeto(ofile, overwrite=True)
    # Only the compressed version is kept
    compress_file(ofile, clobber=True)
    os.remove(ofile)


#-----------------------------------------------------------------------------

if __name__ == '__main__':
    # When executed as a script, build the test-spectra file under its
    # fixed name in the current working directory.
    output_file = 'MaNGA_test_spectra.fits'
    compile_test_spectra(output_file)


