#!/usr/bin/env python
import matplotlib
import os
if os.name == 'posix' and "DISPLAY" not in os.environ:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

import argparse
import datetime
import json
import logging
from copy import deepcopy

import numpy as np
import scipy.special

from tart.operation import settings

from tart_tools import api_handler
from tart_tools import api_imaging
from tart.imaging import elaz


from disko import DiSkO, get_source_list, TelescopeOperator, HealpixSubSphere, HealpixSphere, vis_to_real, MultivariateGaussian

def handle_image(args, img, title, time_repr, src_list=None):
    """Write the current image to every requested output format.

    The current matplotlib figure is titled ``<title>_<time_repr>`` and the
    same string is used as the base name of each file written to
    ``args.dir``.

    Args:
        args: Parsed command-line namespace. ``args.dir``, ``args.PNG`` and
            ``args.PDF`` are read directly. The FITS switch is read from
            ``args.FITS`` (the attribute argparse creates for ``--FITS``),
            falling back to a lower-case ``args.fits``. ``args.display`` is
            optional. A missing optional switch is treated as off.
        img: Image data handed to ``api_imaging.save_fits_image``.
        title: Prefix of the image title and of the output file names.
        time_repr: Timestamp string appended to the title and passed on as
            the FITS ``timestamp``.
        src_list: Accepted for interface compatibility; not used here.
    """
    # Use the root logger directly: the script only binds a module-level
    # ``logger`` (also the root logger) when it is run as ``__main__``.
    log = logging.getLogger()

    image_title = '{}_{}'.format(title, time_repr)
    plt.title(image_title)

    # argparse stores ``--FITS`` as ``args.FITS``; ``args.fits`` is still
    # honoured for callers that build their own namespace.
    want_fits = getattr(args, 'FITS', False) or getattr(args, 'fits', False)
    if want_fits:
        fname = '{}.fits'.format(image_title)
        api_imaging.save_fits_image(img, fname=fname, out_dir=args.dir, timestamp=time_repr)
        log.info("Generating {}".format(fname))

    # PNG and PDF differ only in file extension and resolution.
    for enabled, ending, dpi in ((args.PNG, 'png', 300), (args.PDF, 'pdf', 600)):
        if not enabled:
            continue
        fname = '{}.{}'.format(image_title, ending)
        plt.savefig(os.path.join(args.dir, fname), dpi=dpi)
        log.info("Generating {}".format(fname))

    # The parser in this script defines no ``--display`` option, so the
    # attribute may be absent.
    if getattr(args, 'display', False):
        plt.show()

if __name__ == '__main__':
    # Seed numpy's global random number generator with a fixed value
    # (presumably so that the posterior samples are repeatable - confirm).
    np.random.seed(42)

    parser = argparse.ArgumentParser(description='DiSkO: Bayesian inference of a posterior sky', 
                                    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Input selection.
    parser.add_argument('--ms', required=False, default=None, help="visibility file")
    parser.add_argument('--file', required=False, default=None, help="Snapshot observation saved JSON file (visiblities, positions and more).")
    parser.add_argument('--channel', type=int, default=0, help="Use this frequency channel.")
    parser.add_argument('--field', type=int, default=0, help="Use this FIELD_ID from the measurement set.")

    # Output location and sky discretisation.
    parser.add_argument('--dir', required=False, default='.', help="Output directory.")
    parser.add_argument('--nside', type=int, default=None, help="Healpix nside parameter for display purposes only.")
    parser.add_argument('--nvis', type=int, default=1000, help="Number of visibilities to use.")
    parser.add_argument('--arcmin', type=float, default=None, help="Highest allowed res of the sky in arc minutes.")
    parser.add_argument('--fov', type=float, default=180.0, help="Field of view in degrees")

    parser.add_argument('--sigma-v', type=float, default=None, help="Diagonal components of the visibility covariance. If not supplied use measurement set values")

    # Output formats.
    parser.add_argument('--PNG', action="store_true", help="Generate a PNG format image.")
    parser.add_argument('--PDF', action="store_true", help="Generate a PDF format image.")
    parser.add_argument('--SVG', action="store_true", help="Generate a SVG format image.")
    parser.add_argument('--FITS', action="store_true", help="Generate a FITS format image.")

    parser.add_argument('--prior', type=str, default=None, help="Load the from an HDF5 file.")
    parser.add_argument('--posterior', type=str, default=None, help="Store the posterior in HDF5 format file.")

    parser.add_argument('--nsamples', type=int, default=3, help="Number of samples to save from the posterior.")

    parser.add_argument('--title', required=False, default="disko", help="Prefix the output files.")

    ARGS = parser.parse_args()

    # Fail early with a usage message instead of handing ``None`` to the
    # measurement-set loader when no input was named.
    if not ARGS.file and not ARGS.ms:
        parser.error("one of --ms or --file is required")

    # Root logger at INFO level, written to 'disko.log' in the current
    # working directory.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    fh = logging.FileHandler('disko.log')
    fh.setLevel(logging.INFO)

    # A console handler (level INFO) is configured as well, but it is not
    # attached: the addHandler call for it below is commented out.
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    fh.setFormatter(formatter)

    #logger.addHandler(ch)
    logger.addHandler(fh)

    if ARGS.file:
        # --file takes precedence over --ms when both are supplied.
        logger.info("Getting Data from file: {}".format(ARGS.file))
        with open(ARGS.file, 'r') as json_file:
            calib_info = json.load(json_file)

        info = calib_info['info']
        ant_pos = calib_info['ant_pos']
        config = settings.from_api_json(info['info'], ant_pos)

        flag_list = [] # [4, 5, 14, 22]

        gains_json = calib_info['gains']
        gains = np.asarray(gains_json['gain'])
        phase_offsets = np.asarray(gains_json['phase_offset'])

        observations = calib_info['data']
        if not observations:
            # Without at least one entry there is nothing to build the
            # DiSkO object, the timestamp or the source list from.
            raise ValueError("No observations found in {}".format(ARGS.file))

        # Every entry is calibrated, but each pass overwrites the previous
        # result, so only the last entry feeds the imaging below.
        for vis_json, source_json in observations:
            cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains, phase_offsets, flag_list)
            src_list = elaz.from_json(source_json, 0.0)
        disko = DiSkO.from_cal_vis(cv)
    else:
        logger.info("Getting Data from MS file: {}".format(ARGS.ms))
        disko = DiSkO.from_ms(ARGS.ms, ARGS.nvis, res_arcmin=ARGS.arcmin, channel=ARGS.channel, field_id=ARGS.field)
        timestamp = disko.timestamp
        # No source list is available on this path.
        src_list = None

    time_repr = "{:%Y_%m_%d_%H_%M_%S_%Z}".format(timestamp)

    # Processing
    #
    # The full-sphere path (the only user of --nside) is switched off; it is
    # kept as a manual toggle, as in the original script.
    use_full_sphere = False
    if use_full_sphere:
        sphere = HealpixSphere(ARGS.nside)
        real_vis = disko.vis_arr # vis_to_real(disko.vis_arr)
    else:
        # Sub-sphere centred on theta=0, phi=0 whose radius is half the
        # requested field of view.
        # NOTE(review): --arcmin defaults to None and is passed straight
        # through as the resolution - confirm that is supported.
        radius = np.radians(ARGS.fov / 2.0)
        sphere = HealpixSubSphere.from_resolution(resolution=ARGS.arcmin, theta=np.radians(0.0), phi=0.0, radius=radius)
        real_vis = vis_to_real(disko.vis_arr)

    ##
    #
    # Do the inference, get the SVD.
    #
    ##
    to = TelescopeOperator(disko, sphere)

    if ARGS.prior is not None:
        prior = MultivariateGaussian.from_hdf5(ARGS.prior)
    else:
        prior = to.get_prior() # in the image space.

    natural_prior = prior.linear_transform(to.Vh)

    n_v = real_vis.shape[0]

    # TODO create a proper covariance that ensures the real and imaginary components are linked.
    if ARGS.sigma_v is None:
        # 2x2 block matrix built from diag(rms**2), with half of it in the
        # off-diagonal blocks (presumably to match real_vis holding the real
        # and imaginary parts one after the other - confirm).
        diag = np.diagflat(disko.rms**2)

        sigma_vis = np.block([[diag, 0.5*diag],[0.5*diag, diag]])
        logger.info("Using measurement set sigma {}".format(np.percentile(disko.rms, [5,50,95])))
    else:
        # Same variance, sigma_v squared, for every entry of real_vis.
        sigma_vis = np.identity(n_v)*(ARGS.sigma_v)**2

    posterior = to.sequential_inference(natural_prior, real_vis, sigma_vis)

    to.plot_uv(ARGS.title)
    # Now save the files.
    if ARGS.posterior is not None:
        posterior.to_hdf5(ARGS.posterior)

    def path(ending, image_title):
        """Return the output file path for an image.

        The file is named ``<image_title>.<ending>`` and placed in the
        output directory given by ``ARGS.dir``.
        """
        return os.path.join(ARGS.dir, '{}.{}'.format(image_title, ending))

    def save_images(image_title):
        """Write the sphere's current pixels in every requested format.

        The formats are selected by the ``--FITS``, ``--SVG``, ``--PNG`` and
        ``--PDF`` switches. Each file is named ``<image_title>.<ext>`` inside
        the output directory, and one log line is written per file.

        Args:
            image_title: Base name of the files; also used as the SVG and
                plot title.
        """
        if ARGS.FITS:
            fname = path('fits', image_title)
            sphere.to_fits(fname=fname, fov=ARGS.fov, info=disko.info)
            # Logged like every other format so the log lists all outputs.
            logger.info("Generating {}".format(fname))

        if ARGS.SVG:
            fname = path('svg', image_title)
            sphere.to_svg(fname=fname, show_grid=True, src_list=src_list, fov=ARGS.fov, title=image_title)
            logger.info("Generating {}".format(fname))

        # PNG and PDF share the same matplotlib code and differ only in file
        # extension and resolution. The plot is drawn afresh for each file
        # because the figure is closed after saving.
        for enabled, ending, dpi in ((ARGS.PNG, 'png', 300), (ARGS.PDF, 'pdf', 600)):
            if not enabled:
                continue
            fname = path(ending, image_title)
            sphere.plot(plt, src_list)
            plt.title(image_title)
            plt.savefig(fname, dpi=dpi)
            plt.close()
            logger.info("Generating {}".format(fname))

    # Images are only produced when at least one output format was requested.
    if any((ARGS.PDF, ARGS.PNG, ARGS.SVG, ARGS.FITS)):

        def _emit(pixels, image_title):
            """Load ``pixels`` into the sphere unscaled and save the images."""
            sphere.set_visible_pixels(pixels, scale=False)
            save_images(image_title)

        # Posterior mean, with negative values clipped to zero.
        _emit(np.clip(posterior.mu, 0, None),
              '{}_{}_mu'.format(ARGS.title, time_repr))

        # Posterior variance.
        _emit(posterior.variance(),
              '{}_{}_var'.format(ARGS.title, time_repr))

        # First row of posterior.sigma().
        _emit(posterior.sigma()[0,:],
              '{}_{}_pcf'.format(ARGS.title, time_repr))

        # One image per posterior sample, numbered with five digits.
        for sample_index in range(ARGS.nsamples):
            _emit(posterior.sample(),
                  '{}_{}_s{:0>5}'.format(ARGS.title, time_repr, sample_index))
