#!/usr/bin/env python3

import matplotlib.pyplot as plt
import json
import argparse
import numpy as np
import surfex
from datetime import datetime
import sys


def parse_args_plot_field(args_in):
    """Parse the command line arguments of the field plotting script.

    Args:
        args_in (list): Command line arguments without the program name
            (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace: The parsed arguments. Grib options keep their
            camelCase names (e.g. ``levelType``, ``indicatorOfParameter``),
            Surfex options are prefixed with ``sfx_`` and the observation
            option is ``obs_type``.

    Note:
        If ``args_in`` is empty, the help text is printed and the process is
        terminated through ``sys.exit()``.
    """
    parser = argparse.ArgumentParser("Plot field")
    parser.add_argument('-g', '--geo', dest="geo", type=str, help="Domain/points json geometry definition file",
                        default=None, required=True)
    parser.add_argument('-v', '--variable', dest="variable", type=str, help="Variable name", required=True)
    parser.add_argument('-i', '--inputfile', dest="inputfile", type=str, help="Input file", default=None,
                        required=False)
    parser.add_argument('-it', '--inputtype', dest="inputtype", type=str, help="Filetype", default="surfex",
                        required=False, choices=["netcdf", "grib1", "grib2", "surfex", "obs"])
    parser.add_argument('-t', '--validtime', dest="validtime", type=str, help="Valid time", default=None,
                        required=False)
    parser.add_argument('-o', '--output', dest="output", type=str, help="Output file", default=None,
                        required=False)

    grib = parser.add_argument_group('grib', 'Grib1/2 settings (-it grib1 or -it grib2)')
    grib.add_argument('--indicatorOfParameter', type=int, help="Indicator of parameter [grib1]", default=None)
    grib.add_argument('--timeRangeIndicator', type=int, help="Time range indicator [grib1]", default=0)
    # The option must be spelled --levelType so that the namespace attribute is "levelType".
    # The former misspelling "--levelTyp" is still accepted on the command line because argparse
    # resolves unambiguous prefixes of long options.
    grib.add_argument('--levelType', type=str, help="Level type [grib1/grib2]", default="sfc")
    grib.add_argument('--level', type=int, help="Level [grib1/grib2]", default=0)
    grib.add_argument('--discipline', type=int, help="Discipline [grib2]", default=None)
    grib.add_argument('--parameterCategory', type=int, help="Parameter category [grib2]", default=None)
    grib.add_argument('--parameterNumber', type=int, help="ParameterNumber [grib2]", default=None)
    grib.add_argument('--typeOfStatisticalProcessing', type=int, help="TypeOfStatisticalProcessing [grib2]",
                      default=-1)

    sfx = parser.add_argument_group('Surfex', 'Surfex settings (-it surfex)')
    sfx.add_argument('--sfx_type', type=str,  help="Surfex file type", default=None,
                     choices=[None, "forcing", "ascii", "nc", "netcdf", "texte"])

    sfx.add_argument('--sfx_patches', type=int, help="Patches [ascii/texte]", default=-1)
    sfx.add_argument('--sfx_layers', type=int, help="Layers [ascii/texte]", default=-1)
    sfx.add_argument('--sfx_datatype', type=str, help="Datatype [ascii]", choices=["string", "float", "integer"],
                     default="float")
    sfx.add_argument('--sfx_interval', type=str, help="Interval [texte]", default=None)
    sfx.add_argument('--sfx_basetime', type=str, help="Basetime [texte]", default=None)
    sfx.add_argument('--sfx_geo_input', type=str, default=None,
                     help="JSON file with domain defintion [forcing/netcdf/texte]")

    obs = parser.add_argument_group('Observations', 'Observation settings (scatter plot)')
    obs.add_argument('--obs_type', type=str, help="Observation source type (-it obs)",
                     choices=[None, "json", "bufr", "frost", "netatmo"], default=None)

    # Without any arguments the required options can not be satisfied; show the usage instead.
    if len(args_in) == 0:
        parser.print_help()
        sys.exit()

    return parser.parse_args(args_in)


def _format_validtime(validtime):
    """Return ``validtime`` formatted as YYYYMMDDHH, raising if no valid time was given."""
    if validtime is None:
        raise Exception("You must provide a validtime")
    return validtime.strftime("%Y%m%d%H")


def _get_level_type(args):
    """Return the grib level type from ``args``, accepting both attribute spellings.

    ``levelType`` is the intended attribute name. A parser that registered the option as
    ``--levelTyp`` exposes it as ``levelTyp`` instead, so that spelling is used as a fallback.
    """
    for name in ("levelType", "levelTyp"):
        if hasattr(args, name):
            return getattr(args, name)
    raise AttributeError("args has no attribute 'levelType'")


def plot_field(args):
    """Read a single field and plot it on screen or to a file.

    The variable definition is built from ``args`` depending on ``args.inputtype``, the field is
    read through a surfex converter with the identity ("none") conversion and then drawn either
    as filled contours or as a scatter plot.

    Args:
        args (argparse.Namespace): Parsed arguments. The attributes read are ``geo``,
            ``validtime`` (string formatted as YYYYMMDDHH or None), ``variable``, ``inputfile``,
            ``inputtype`` and ``output``, plus the grib, ``sfx_*`` or ``obs_type`` settings of the
            selected input type.

    Raises:
        Exception: If a mandatory setting for the selected input type is missing (variable, file
            pattern, observation type, or the valid time for grib1/grib2/netcdf) or if no field
            could be read.
        NotImplementedError: If ``args.inputtype`` is not one of the handled input types.
    """
    geo_file = args.geo
    validtime = None
    if args.validtime is not None:
        validtime = datetime.strptime(args.validtime, "%Y%m%d%H")
    variable = args.variable
    filepattern = args.inputfile
    inputtype = args.inputtype
    output = args.output

    # Use a context manager so the geometry file is closed as soon as it has been parsed.
    with open(geo_file, "r") as file_handler:
        domain_json = json.load(file_handler)
    geo = surfex.geo.get_geo_object(domain_json)

    contour = True
    var = "field_to_read"
    if inputtype == "grib1":

        par = args.indicatorOfParameter
        lt = _get_level_type(args)
        lev = args.level
        tri = args.timeRangeIndicator

        gribvar = surfex.Grib1Variable(par, lt, lev, tri)
        title = "grib1:" + gribvar.print_keys() + " " + _format_validtime(validtime)
        var_dict = {
            "parameter": par,
            "leveltype": lt,
            "lev": lev,
            "tri": tri
        }

    elif inputtype == "grib2":

        discipline = args.discipline
        parameter_category = args.parameterCategory
        parameter_number = args.parameterNumber
        level_type = _get_level_type(args)
        level = args.level
        type_of_statistical_processing = args.typeOfStatisticalProcessing

        gribvar = surfex.grib.Grib2Variable(discipline, parameter_category, parameter_number, level_type, level,
                                            tsp=type_of_statistical_processing)
        title = inputtype + ": " + gribvar.print_keys() + " " + _format_validtime(validtime)

        var_dict = {
            "discipline": discipline,
            "parameterCategory": parameter_category,
            "parameterNumber": parameter_number,
            "levelType": level_type,
            "level": level,
            "typeOfStatisticalProcessing": type_of_statistical_processing
        }

    elif inputtype == "netcdf":

        if variable is None:
            raise Exception("You must provide a variable")
        if filepattern is None:
            raise Exception("You must provide a filepattern")

        title = "netcdf: "+variable + " " + _format_validtime(validtime)
        var_dict = {
            "name": variable,
            "filepattern": filepattern,
            "fcint": 10800,
            "file_inc": 10800,
            "offset": 0
        }

    elif inputtype == "surfex":

        if variable is None:
            raise Exception("You must provide a variable")
        if filepattern is None:
            raise Exception("You must provide a filepattern")

        basetime = args.sfx_basetime
        patches = args.sfx_patches
        layers = args.sfx_layers
        datatype = args.sfx_datatype
        interval = args.sfx_interval
        geo_sfx_input = args.sfx_geo_input

        sfx_var = surfex.SurfexFileVariable(variable, validtime=validtime, patches=patches, layers=layers,
                                            basetime=basetime, interval=interval, datatype=datatype)

        title = inputtype + ": " + sfx_var.print_var()
        var_dict = {
            "varname": variable,
            "filepattern": filepattern,
            "patches": patches,
            "layers": layers,
            "datatype": datatype,
            "interval": interval,
            "basetime": basetime,
            "geo_input": geo_sfx_input,
            "fcint": 10800,
            "file_inc": 10800,
            "offset": 0
        }

    elif inputtype == "obs":

        # Observations are always drawn as a scatter plot.
        contour = False
        if variable is None:
            raise Exception("You must provide a variable")

        obs_input_type = args.obs_type
        if obs_input_type is None:
            raise Exception("You must provide an obs type")

        var_dict = {
            "filetype": obs_input_type,
            "varname": variable,
            "filepattern": filepattern,
            "filenames": [filepattern],
            "fcint": 10800,
            "file_inc": 10800,
            "offset": 0
        }
        title = inputtype + ": var=" + variable + " type=" + obs_input_type

    else:
        raise NotImplementedError

    # Wrap the variable definition in the nested structure handed to the converter:
    # variable name -> input type -> "converter" -> converter name -> settings.
    defs = {
        var: {
            inputtype: {
                "converter": {
                    "none": var_dict
                }
            }

        }
    }
    converter_conf = defs[var][inputtype]["converter"]

    # The field is read without a cache.
    cache = None
    converter_name = "none"
    converter = surfex.read.Converter(converter_name, validtime, defs, converter_conf, inputtype, validtime)
    field = surfex.ConvertedInput(geo, var, converter).read_time_step(validtime, cache)

    if field is None:
        raise Exception("No field read")

    # Contours are only drawn when the number of points differs from both nlons and nlats
    # (presumably a real 2D grid rather than a list of points - TODO confirm against surfex.geo).
    if geo.npoints != geo.nlons and geo.npoints != geo.nlats:
        if contour:
            field = np.reshape(field, [geo.nlons, geo.nlats])
    else:
        contour = False

    if contour:
        plt.contourf(geo.lons, geo.lats, field)
    else:
        plt.scatter(geo.lonlist, geo.latlist, c=field)

    plt.title(title)
    plt.colorbar()
    if output is None:
        plt.show()
    else:
        print("Saving figure in " + output)
        plt.savefig(output)


if __name__ == "__main__":
    # Run as a script: parse everything after the program name and plot the requested field.
    plot_field(parse_args_plot_field(sys.argv[1:]))
