#!/usr/bin/env python
import matplotlib
import os
if os.name == 'posix' and "DISPLAY" not in os.environ:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

import argparse
import datetime
import json
import logging
from copy import deepcopy

import numpy as np

from tart.operation import settings

from tart_tools import api_handler
from tart_tools import api_imaging
from tart.imaging import elaz

from disko import DiSkO, get_source_list, HealpixSphere, HealpixSubSphere, AdaptiveMeshSphere


def handle_image(args, img, title, time_repr, src_list=None):
    """Title the current matplotlib figure, then save and/or show it.

    The base name of every output is ``'<title>_<time_repr>'`` and files are
    written into ``args.dir``.

    Args:
        args: Parsed command-line namespace. ``PNG``, ``PDF``, ``display`` and
            ``dir`` are read directly; ``fits`` is optional and treated as
            False when the namespace does not define it.
        img: Image data handed to ``api_imaging.save_fits_image`` when FITS
            output is requested.
        title: Prefix used for the plot title and the output file names.
        time_repr: Timestamp string appended to the title; also passed as the
            ``timestamp`` argument when saving FITS.
        src_list: Not used by this function; kept so existing callers that
            pass it keep working.
    """
    # Fetch the root logger here instead of relying on a module-level
    # ``logger`` that only exists when this file is executed as a script.
    log = logging.getLogger()

    image_title = '{}_{}'.format(title, time_repr)
    plt.title(image_title)

    # The parser in this script has no ``--fits`` option, so tolerate
    # namespaces that lack the attribute.
    if getattr(args, 'fits', False):
        fname = '{}.fits'.format(image_title)
        api_imaging.save_fits_image(img, fname=fname, out_dir=args.dir, timestamp=time_repr)
        log.info("Generating %s", fname)

    # Raster/vector snapshots of the current figure: (enabled, extension, dpi).
    for enabled, extension, dpi in ((args.PNG, 'png', 300), (args.PDF, 'pdf', 600)):
        if not enabled:
            continue
        fname = '{}.{}'.format(image_title, extension)
        plt.savefig(os.path.join(args.dir, fname), dpi=dpi)
        log.info("Generating %s", fname)

    if args.display:
        plt.show()

def build_parser():
    """Build the command-line parser for the DiSkO imaging script.

    Returns:
        argparse.ArgumentParser: parser whose options select the data source
        (``--file``, ``--ms`` or the web API), the sky discretisation, the
        solver and the output formats.
    """
    parser = argparse.ArgumentParser(description='DiSkO: Generate an Discrete Sky Operator Image using the web api of a TART radio telescope.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--api', required=False, default='https://tart.elec.ac.nz/signal', help="Telescope API server URL.")
    parser.add_argument('--catalog', required=False, default='https://tart.elec.ac.nz/catalog', help="Catalog API URL.")

    parser.add_argument('--file', required=False, default=None, help="Snapshot observation saved JSON file (visiblities, positions and more).")
    parser.add_argument('--ms', required=False, default=None, help="visibility file")
    parser.add_argument('--vis', required=False, default=None, help="Use a local JSON file containing the visibilities to create the image.")
    parser.add_argument('--dir', required=False, default='.', help="Output directory.")
    parser.add_argument('--alpha', type=float, default=0.001, help="Regularization parameter. (this is divided by the sqrt of the number of pixels.")
    parser.add_argument('--nside', type=int, default=None, help="Healpix nside parameter for display purposes only.")
    parser.add_argument('--nvis', type=int, default=1000, help="Number of visibilities to use.")
    parser.add_argument('--arcmin', type=float, default=None, help="Highest allowed res of the sky in arc minutes.")
    parser.add_argument('--arcmax', type=float, default=None, help="Starting Resolution of the sky in arc minutes.")
    parser.add_argument('--fov', type=float, default=180.0, help="Field of view in degrees")

    parser.add_argument('--elevation', type=float, default=20.0, help="Elevation limit for displaying sources (degrees).")
    parser.add_argument('--display', action="store_true", help="Display Image to the user.")
    parser.add_argument('--adaptive', type=int, default=0, help="Use adaptive mesh.")
    parser.add_argument('--channel', type=int, default=0, help="Use this frequency channel.")
    parser.add_argument('--PNG', action="store_true", help="Generate a PNG format image.")
    parser.add_argument('--PDF', action="store_true", help="Generate a PDF format image.")
    parser.add_argument('--SVG', action="store_true", help="Generate a SVG format image.")

    parser.add_argument('--cv', action="store_true", help="Use Cross Validation")
    parser.add_argument('--dask', action="store_true", help="Use dask")
    parser.add_argument('--lsqr', action="store_true", help="Use lsqr in matrix-free")
    parser.add_argument('--lsmr', action="store_true", help="Use lsmr in matrix-free")
    parser.add_argument('--fista', action="store_true", help="Use FISTA in matrix-free")
    parser.add_argument('--lasso', action="store_true", help="Use L1 regularization.")
    parser.add_argument('--tikhonov', action="store_true", help="Use L2 regularization.")
    parser.add_argument('--matrix-free', action="store_true", help="Use matrix-free regularization.")
    parser.add_argument('--show-sources', action="store_true", help="Show known sources on images (only works on PNG).")

    parser.add_argument('--title', required=False, default="disko", help="Prefix the output files.")
    return parser


if __name__ == '__main__':
    parser = build_parser()
    ARGS = parser.parse_args()

    # ------------------------------------------------------------------
    # Logging: the root logger feeds both the console and 'disko.log'.
    # ------------------------------------------------------------------
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    fh = logging.FileHandler('disko.log')
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)

    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)

    logger.addHandler(ch)
    logger.addHandler(fh)

    # Create the output directory up front so a missing path does not abort
    # the run only after the (potentially long) imaging step has finished.
    if not os.path.isdir(ARGS.dir):
        os.makedirs(ARGS.dir)

    # ------------------------------------------------------------------
    # Data acquisition: saved snapshot file, measurement set, or web API.
    # ------------------------------------------------------------------
    source_json = None

    if ARGS.file:
        logger.info("Getting Data from file: {}".format(ARGS.file))
        with open(ARGS.file, 'r') as json_file:
            calib_info = json.load(json_file)

        info = calib_info['info']
        ant_pos = calib_info['ant_pos']
        config = settings.from_api_json(info['info'], ant_pos)

        flag_list = []  # nothing flagged (original author's example: [4, 5, 14, 22])

        gains_json = calib_info['gains']
        gains = np.asarray(gains_json['gain'])
        phase_offsets = np.asarray(gains_json['phase_offset'])

        # Every entry is calibrated, but only the last one is imaged; its
        # source catalogue is the one kept in source_json.
        cv = None
        for vis_json, source_json in calib_info['data']:
            cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains, phase_offsets, flag_list)
        if cv is None:
            raise ValueError("No measurements found in file: {}".format(ARGS.file))
        disko = DiSkO.from_cal_vis(cv)
    elif ARGS.ms:
        logger.info("Getting Data from MS file: {}".format(ARGS.ms))
        disko = DiSkO.from_ms(ARGS.ms, ARGS.nvis, res_arcmin=ARGS.arcmin, channel=ARGS.channel)
        # The observation time is taken from the DiSkO object itself.
        timestamp = disko.timestamp
    else:
        logger.info("Getting Data from API: {}".format(ARGS.api))

        api = api_handler.APIhandler(ARGS.api)
        config = api_handler.get_config(api)

        gains = api.get('calibration/gain')

        if ARGS.vis is None:
            vis_json = api.get('imaging/vis')
        else:
            with open(ARGS.vis, 'r') as json_file:
                vis_json = json.load(json_file)

        ts = api_imaging.vis_json_timestamp(vis_json)
        if ARGS.show_sources:
            source_json = api.get_url(api.catalog_url(config, datestr=ts.isoformat()))

        logger.info("Data Download Complete")

        cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains['gain'], gains['phase_offset'], flag_list=[])
        disko = DiSkO.from_cal_vis(cv)

    time_repr = "{:%Y_%m_%d_%H_%M_%S_%Z}".format(timestamp)

    # Known sources are only looked up (and later drawn) on request.
    if ARGS.show_sources:
        src_list = get_source_list(source_json, el_limit=ARGS.elevation, jy_limit=1e4)
    else:
        src_list = None

    # ------------------------------------------------------------------
    # Sky discretisation.
    # ------------------------------------------------------------------
    nside = ARGS.nside
    radius = np.radians(ARGS.fov / 2.0)
    if ARGS.adaptive > 0:
        sphere = AdaptiveMeshSphere.from_resolution(res_arcmin=ARGS.arcmin, res_arcmax=ARGS.arcmax, theta=np.radians(0.0), phi=0.0, radius=radius)
    elif nside is None:
        sphere = HealpixSubSphere.from_resolution(resolution=ARGS.arcmin, theta=np.radians(0.0), phi=0.0, radius=radius)
    else:
        sphere = HealpixSubSphere.from_resolution(nside=nside, theta=0.0, phi=0.0, radius=radius)

    # ------------------------------------------------------------------
    # Imaging: the first matching solver flag wins.
    # ------------------------------------------------------------------
    if ARGS.lasso:
        logger.info("L1 regularization alpha=%f" %ARGS.alpha)
        sky = disko.image_lasso(disko.vis_arr, sphere, alpha=ARGS.alpha, scale=True, use_cv=ARGS.cv)
    elif ARGS.matrix_free:
        logger.info("Matrix Free alpha={}".format(ARGS.alpha))
        # The matrix-free solver is given the visibilities as an (n_v, 1, 1) array.
        data = np.zeros((disko.n_v, 1, 1), dtype=np.complex128)
        data[:, 0, 0] = disko.vis_arr
        sky = disko.solve_matrix_free(data, sphere, alpha=ARGS.alpha, scale=True, lsqr=ARGS.lsqr, fista=ARGS.fista, lsmr=ARGS.lsmr)
    elif ARGS.tikhonov:
        logger.info("L2 regularization alpha={}".format(ARGS.alpha))
        sky = disko.image_tikhonov(disko.vis_arr, sphere, alpha=ARGS.alpha, scale=True, usedask=ARGS.dask)
        # Adaptive refinement rounds: dump the current mesh, refine, re-solve.
        for i in range(ARGS.adaptive):
            # Intermediate meshes go to the output directory like every
            # other product instead of the current working directory.
            sphere.write_mesh(os.path.join(ARGS.dir, "round_{}.vtk".format(i)))

            sphere.refine()
            sky = disko.image_tikhonov(disko.vis_arr, sphere, alpha=ARGS.alpha, scale=True, usedask=ARGS.dask)
            sphere.pixels = sphere.pixels / sphere.pixel_areas
    else:
        sky = disko.solve_vis(disko.vis_arr, sphere)

    # ------------------------------------------------------------------
    # Outputs.
    # ------------------------------------------------------------------
    image_title = '{}_{}'.format(ARGS.title, time_repr)

    # Only an adaptive mesh sphere is built for adaptive > 0, so use the same
    # test here (a negative value must not trigger a mesh dump).
    if ARGS.adaptive > 0:
        sphere.write_mesh(os.path.join(ARGS.dir, "{}.vtk".format(image_title)))

    # A FITS file is always written.
    fname = '{}.fits'.format(image_title)
    fpath = os.path.join(ARGS.dir, fname)
    sphere.to_fits(fname=fpath, fov=ARGS.fov, info=disko.info)

    if ARGS.SVG:
        fname = '{}.svg'.format(image_title)
        fpath = os.path.join(ARGS.dir, fname)
        sphere.to_svg(fname=fpath, show_grid=True, src_list=src_list, fov=ARGS.fov, title=image_title)
        logger.info("Generating {}".format(fname))

    # PNG and PDF share the same plotting path and differ only in extension
    # and resolution: (enabled, extension, dpi).
    for enabled, extension, dpi in ((ARGS.PNG, 'png', 300), (ARGS.PDF, 'pdf', 600)):
        if not enabled:
            continue
        sphere.plot(plt, src_list)
        plt.title(image_title)
        fname = '{}.{}'.format(image_title, extension)
        fpath = os.path.join(ARGS.dir, fname)
        plt.savefig(fpath, dpi=dpi)
        plt.close()
        logger.info("Generating {}".format(fname))

    if ARGS.display:
        sphere.plot(plt, src_list)
        plt.title(image_title)
        plt.show()
