#!/usr/bin/env python
import matplotlib
import os
if os.name == 'posix' and "DISPLAY" not in os.environ:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

import argparse
import datetime
import json
import logging
import pkg_resources  # part of setuptools
import sys


import numpy as np


from copy import deepcopy
from tart.operation import settings

from tart_tools import api_imaging
from tart.imaging import elaz

from dask.distributed import Client

from disko import DiSkO, get_source_list, AdaptiveMeshSphere, create_fov, Resolution


def _build_parser():
    """Return the argument parser for the DiSkO imaging command line.

    The parser combines mesh options (``--mesh``, ``--adaptive``,
    ``--res-min``) and HealPix options (``--healpix``, ``--nside``) from two
    parent parsers with the data-source, solver and output options.
    """
    parser_mesh = argparse.ArgumentParser(add_help=False)
    parser_mesh.add_argument('--mesh', action="store_true", help="Use a non-structured mesh in the image space")
    parser_mesh.add_argument('--adaptive', type=int, default=0, help="Use N cycles of adaptive meshing")
    parser_mesh.add_argument('--res-min', type=str, default=None, help="Highest allowed res of the sky. E.g. 1.3deg, 12\", 11', 8uas, 6mas.")

    parser_sphere = argparse.ArgumentParser(add_help=False)
    parser_sphere.add_argument('--healpix', action="store_true", help="Use HealPix tiling")
    parser_sphere.add_argument('--nside', type=int, default=None, help="Healpix nside parameter for display purposes only.")

    parser = argparse.ArgumentParser(description='DiSkO: Generate a Discrete Sky Operator Image using the web api of a TART radio telescope.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter, parents=[parser_mesh, parser_sphere])

    # Read by the API data source, which is used when neither --file nor --ms is given.
    parser.add_argument('--api', required=False, default='https://tart.elec.ac.nz/signal', help="Telescope API server URL.")

    parser.add_argument('--fov', type=str, default="180deg", help="Field of view. E.g. 1.3deg, 12\", 11', 8uas, 6mas...")
    parser.add_argument('--res', type=str, default="2deg", help="Maximum Resolution of the sky. E.g. 1.3deg, 12\", 11', 8uas, 6mas.")

    data_group = parser.add_mutually_exclusive_group()
    data_group.add_argument('--file', required=False, default=None, help="Snapshot observation saved JSON file (visibilities, positions and more).")
    data_group.add_argument('--ms', required=False, default=None, help="visibility file")

    parser.add_argument('--nvis', type=int, default=1000, help="Number of visibilities to use.")
    parser.add_argument('--vis', required=False, default=None, help="Use a local JSON file containing the visibilities to create the image.")
    parser.add_argument('--channel', type=int, default=0, help="Use this frequency channel.")
    parser.add_argument('--field', type=int, default=0, help="Use this FIELD_ID from the measurement set.")
    parser.add_argument('--ddid', type=int, default=0, help="Use this DDID from the measurement set.")

    algo_group = parser.add_mutually_exclusive_group()
    algo_group.add_argument('--lsqr', action="store_true", help="Use lsqr in matrix-free")
    algo_group.add_argument('--lsmr', action="store_true", help="Use lsmr in matrix-free")
    algo_group.add_argument('--fista', action="store_true", help="Use FISTA in matrix-free")
    algo_group.add_argument('--lasso', action="store_true", help="Use L1 regularization.")
    algo_group.add_argument('--tikhonov', action="store_true", help="Use L2 regularization.")

    parser.add_argument('--matrix-free', action="store_true", help="Use matrix-free regularization.")
    parser.add_argument('--niter', type=int, default=100, help="Number of iterations for iterative solutions.")

    parser.add_argument('--dir', required=False, default='.', help="Output directory.")
    parser.add_argument('--alpha', type=float, default=None, help="Regularization parameter.")
    parser.add_argument('--l1-ratio', type=float, default=0.02, help="Regularization parameter, ratio of l1 to l2 (1.0 means l1 only).")

    parser.add_argument('--show-sources', action="store_true", help="Show known sources on images (only works on PNG & SVG).")
    parser.add_argument('--title', required=False, default="disko", help="Prefix the output files.")
    parser.add_argument('--elevation', type=float, default=20.0, help="Elevation limit for displaying sources (degrees).")
    parser.add_argument('--display', action="store_true", help="Display Image to the user.")
    parser.add_argument('--PNG', action="store_true", help="Generate a PNG format image.")
    parser.add_argument('--PDF', action="store_true", help="Generate a PDF format image.")
    parser.add_argument('--SVG', action="store_true", help="Generate a SVG format image.")
    parser.add_argument('--VTK', action="store_true", help="Generate a VTK mesh format image.")
    parser.add_argument('--FITS', action="store_true", help="Generate a FITS format image.")

    parser.add_argument('--cv', action="store_true", help="Use Cross Validation")
    parser.add_argument('--dask', action="store_true", help="Use dask")

    parser.add_argument('--version', action="store_true", help="Display the current version")
    return parser


def _setup_logging(log_file='disko.log', console=False):
    """Configure the root logger to record INFO-level messages.

    Args:
        log_file: Path of the log file that receives every INFO+ record.
        console: When true, also echo records to stderr. Off by default.

    Returns:
        The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    if console:
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    return logger


def _create_sphere(args):
    """Build the sky model selected by exactly one of --mesh / --healpix.

    Raises:
        RuntimeError: if neither or both of --mesh and --healpix are given.
    """
    # Validate the flags first, so a bad combination is reported before any
    # resolution strings are parsed.
    if args.mesh and args.healpix:
        raise RuntimeError("--mesh and --healpix are mutually exclusive")
    if not (args.mesh or args.healpix):
        raise RuntimeError("Either --mesh or --healpix must be specified")

    fov = Resolution.from_string(args.fov)
    res = Resolution.from_string(args.res)

    if args.healpix:
        return create_fov(args.nside, fov=fov, res=res)

    # --res-min defaults to --res, which gives a mesh of uniform resolution.
    res_min = res if args.res_min is None else Resolution.from_string(args.res_min)
    return AdaptiveMeshSphere.from_resolution(res_min=res_min, res_max=res, theta=np.radians(0.0), phi=0.0, fov=fov)


def _load_from_file(args, logger):
    """Load a saved calibration/observation JSON file given by --file.

    Returns:
        ``(disko, timestamp, source_json)`` for the last entry in the
        file's ``data`` list.

    Raises:
        RuntimeError: if the file contains no observations.
    """
    logger.info("Getting Data from file: {}".format(args.file))
    with open(args.file, 'r') as json_file:
        calib_info = json.load(json_file)

    observations = calib_info['data']
    if not observations:
        raise RuntimeError(f"No observations found in {args.file}")

    info = calib_info['info']
    config = settings.from_api_json(info['info'], calib_info['ant_pos'])

    gains_json = calib_info['gains']
    gains = np.asarray(gains_json['gain'])
    phase_offsets = np.asarray(gains_json['phase_offset'])

    flag_list = []  # e.g. [4, 5, 14, 22] to flag antennas

    # NOTE(review): every entry is calibrated but only the last one is imaged,
    # because each iteration overwrites cv/timestamp/source_json. Confirm
    # whether the earlier observations should be combined instead.
    for vis_json, source_json in observations:
        cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains, phase_offsets, flag_list)

    return DiSkO.from_cal_vis(cv), timestamp, source_json


def _load_from_ms(args, sphere, logger):
    """Load visibilities from the measurement set given by --ms.

    Returns:
        ``(disko, timestamp, None)``. No source catalogue is available here.

    Raises:
        RuntimeError: if the measurement set path does not exist.
    """
    logger.info(f"Getting Data from MS file: {args.ms} to {sphere}")

    if not os.path.exists(args.ms):
        raise RuntimeError(f"Measurement set {args.ms} not found")

    min_res = sphere.min_res()
    logger.info(f"Min Res {min_res}")
    disko = DiSkO.from_ms(args.ms, args.nvis, res=min_res, channel=args.channel, field_id=args.field, ddid=args.ddid)
    return disko, disko.timestamp, None


def _load_from_api(args, logger):
    """Download configuration, gains and visibilities from the telescope API.

    Returns:
        ``(disko, timestamp, source_json)``. ``source_json`` is only fetched
        when --show-sources is given; otherwise it is None.
    """
    # Imported here because only this data source needs it.
    from tart_tools import api_handler

    logger.info("Getting Data from API: {}".format(args.api))

    api = api_handler.APIhandler(args.api)
    config = api_handler.get_config(api)

    gains = api.get('calibration/gain')

    if args.vis is None:
        vis_json = api.get('imaging/vis')
    else:
        with open(args.vis, 'r') as json_file:
            vis_json = json.load(json_file)

    source_json = None
    ts = api_imaging.vis_json_timestamp(vis_json)
    if args.show_sources:
        source_json = api.get_url(api.catalog_url(config, datestr=ts.isoformat()))

    logger.info("Data Download Complete")

    cv, timestamp = api_imaging.vis_calibrated(vis_json, config, gains['gain'], gains['phase_offset'], flag_list=[])
    return DiSkO.from_cal_vis(cv), timestamp, source_json


def _load_data(args, sphere, logger):
    """Dispatch to the data source selected on the command line.

    Uses --file, then --ms, and otherwise falls back to the telescope API.

    Returns:
        ``(disko, timestamp, source_json)``.
    """
    if args.file:
        return _load_from_file(args, logger)
    if args.ms:
        return _load_from_ms(args, sphere, logger)
    return _load_from_api(args, logger)


def _solve(args, disko, sphere, logger):
    """Image the sky onto *sphere* using the solver selected by the flags.

    Precedence is --lasso, then --matrix-free, then --tikhonov. With none of
    these, ``disko.solve_vis`` is used.

    Returns:
        Whatever the chosen DiSkO solver returns.
    """
    if args.lasso:
        # '{}' rather than '%f': alpha defaults to None, which '%f' rejects.
        logger.info("L1 regularization alpha={}".format(args.alpha))
        return disko.image_lasso(disko.vis_arr, sphere, alpha=args.alpha, l1_ratio=args.l1_ratio, scale=False, use_cv=args.cv)

    if args.matrix_free:
        logger.info("Matrix Free alpha={}".format(args.alpha))
        data = disko.vis_to_data()
        return disko.solve_matrix_free(data, sphere, alpha=args.alpha, scale=False, lsqr=args.lsqr, fista=args.fista, lsmr=args.lsmr, niter=args.niter)

    if args.tikhonov:
        logger.info("L2 regularization alpha={}".format(args.alpha))
        sky = disko.image_tikhonov(disko.vis_arr, sphere, alpha=args.alpha, scale=False, usedask=args.dask)

        if args.mesh:
            for i in range(args.adaptive):
                # Write this round's mesh into the output directory before
                # refining it and re-solving on the refined mesh.
                os.makedirs(args.dir, exist_ok=True)
                sphere.write_mesh(os.path.join(args.dir, f"{args.title}_round_{i}.vtk"))

                sphere.refine()
                sky = disko.image_tikhonov(disko.vis_arr, sphere, alpha=args.alpha, scale=False, usedask=args.dask)
                sphere.pixels = sphere.pixels / sphere.pixel_areas
        return sky

    return disko.solve_vis(disko.vis_arr, sphere)


if __name__ == '__main__':
    parser = _build_parser()
    ARGS = parser.parse_args()

    # Handle --version before logging is configured, so that asking for the
    # version does not create disko.log as a side effect.
    if ARGS.version:
        version = pkg_resources.require("disko")[0].version
        print(f"DiSkO: Version {version}")
        print("       (c) 2022 Tim Molteno")
        sys.exit(0)

    logger = _setup_logging()

    sphere = _create_sphere(ARGS)

    disko, timestamp, source_json = _load_data(ARGS, sphere, logger)

    time_repr = "{:%Y_%m_%d_%H_%M_%S_%Z}".format(timestamp)

    # CASAcore UVW is conjugated, so to make things consistent with data
    # streaming off telescope we need the vis flipped about
    if ARGS.ms:
        disko.vis_arr = disko.vis_arr.conjugate()

    src_list = None
    if ARGS.show_sources:
        src_list = get_source_list(source_json, el_limit=ARGS.elevation, jy_limit=1e4)

    sky = _solve(ARGS, disko, sphere, logger)

    image_title = f"{ARGS.title}_{time_repr}"

    def path(ending, image_title):
        """Return the output file path ``<--dir>/<image_title>.<ending>``.

        The output directory is created first if it does not already exist.
        """
        out_dir = ARGS.dir
        os.makedirs(out_dir, exist_ok=True)
        return os.path.join(out_dir, f"{image_title}.{ending}")

    if ARGS.mesh:
        # With an unstructured mesh, the imaged sky is always written out in
        # VTK format, whether or not --VTK was given.
        vtk_path = path('vtk', image_title)
        sphere.write_mesh(vtk_path)


    def save_images(image_title, source_list):
        """Write or show the imaged sky in every format requested by the flags.

        Each of --VTK, --FITS, --SVG, --PNG, --PDF and --display is checked
        independently. Output files are named ``<image_title>.<ext>`` inside
        the --dir directory.

        Args:
            image_title: Base file name and plot title.
            source_list: Known sources to overlay on the SVG/PNG/PDF/display
                images, or None.
        """

        def _plot():
            # Draw the sky (with sources) and title on the current pyplot figure.
            sphere.plot(plt, source_list)
            plt.title(image_title)

        if ARGS.VTK:
            sphere.write_mesh(path('vtk', image_title))

        if ARGS.FITS:
            # Save as a FITS file
            sphere.to_fits(fname=path('fits', image_title), info=disko.info)

        if ARGS.SVG:
            fname = path('svg', image_title)
            sphere.to_svg(fname=fname, show_grid=True, src_list=source_list, title=image_title)
            logger.info("Generating {}".format(fname))

        # Raster/vector matplotlib outputs share one code path; only the
        # extension and dpi differ.
        for enabled, ending, dpi in ((ARGS.PNG, 'png', 300), (ARGS.PDF, 'pdf', 600)):
            if not enabled:
                continue
            fname = path(ending, image_title)
            _plot()
            plt.tight_layout()
            plt.savefig(fname, dpi=dpi)
            plt.close()
            logger.info("Generating {}".format(fname))

        if ARGS.display:
            # Use the source_list argument (not the global src_list) so the
            # displayed image matches the saved ones.
            _plot()
            plt.show()

    # save_images() checks every output flag itself, so call it
    # unconditionally. That way --VTK and --display also take effect when
    # they are the only output options given, instead of being ignored.
    save_images(image_title, source_list=src_list)