#!/usr/bin/env python3

import matplotlib.pyplot as plt
import argparse
import numpy as np
import surfex
from datetime import datetime
import sys


def parse_plot_timeseries_args(argv):
    """Parse the command line arguments of the time series plotting script.

    Args:
        argv (list): Command line arguments without the program name
            (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``filename``,
            ``lon``, ``lat``, ``stid``, ``stationlist``, ``start``, ``end``,
            ``interval`` and ``output``. Options not given are ``None``.

    Note:
        When fewer than three arguments are passed, the help text is printed
        and the process terminates through ``sys.exit()``.
    """
    # The first positional parameter of ArgumentParser is ``prog``; the text has to be
    # passed as ``description``, otherwise it replaces the program name in the usage line.
    parser = argparse.ArgumentParser(description="Plot timeseries from JSON time series file")
    parser.add_argument('filename', type=str, help="JSON time series file")
    parser.add_argument('-lon', type=float, help="Longitude")
    parser.add_argument('-lat', type=float, help="Latitude")
    parser.add_argument('-stid', type=str, help="Station id")
    parser.add_argument('-stationlist', type=str, help="Station list")
    parser.add_argument('-start', type=str, help="Start time (YYYYMMDDHH)")
    parser.add_argument('-end', type=str, help="End time (YYYYMMDDHH)")
    parser.add_argument('-interval', type=int, help="Interval")
    parser.add_argument('-o', '--output', dest="output", type=str,
                        help="Output file for the plot (shown interactively if not given)")

    # The file name alone is never enough (a position or a station is needed as well),
    # so anything shorter than three arguments only shows the help.
    if len(argv) < 3:
        parser.print_help()
        sys.exit()

    return parser.parse_args(argv)


def plot_timeseries_from_json(args):
    """Plot a time series for one position from a JSON time series file.

    The position is taken from ``args.lon``/``args.lat``. If either of them is
    missing, it is looked up from ``args.stid`` in ``args.stationlist``.

    Args:
        args (argparse.Namespace): Parsed arguments providing ``filename``, ``lon``,
            ``lat``, ``stid``, ``stationlist``, ``start``, ``end``, ``interval``
            and ``output`` (see ``parse_plot_timeseries_args``).

    Raises:
        ValueError: If the position is incomplete and no station id is given, or
            a station id is given without a station list. Also raised by
            ``datetime.strptime`` when ``start``/``end`` is not in YYYYMMDDHH format.
    """
    lon = args.lon
    lat = args.lat
    stid = args.stid

    starttime = args.start
    if starttime is not None:
        starttime = datetime.strptime(starttime, "%Y%m%d%H")
    endtime = args.end
    if endtime is not None:
        endtime = datetime.strptime(endtime, "%Y%m%d%H")

    # Both coordinates are required. With only one of them given the position is
    # incomplete and must be resolved from the station list instead of passing None on.
    if lon is None or lat is None:
        if stid is None:
            raise ValueError("You must provide lon and lat or stid")
        if args.stationlist is None:
            raise ValueError("You must provide a stationlist with the stid")
        lons, lats = surfex.Observation.get_pos_from_stid(args.stationlist, [stid])
        lon = lons[0]
        lat = lats[0]

    ts = surfex.TimeSeriesFromJson(args.filename, lons=[lon], lats=[lat], starttime=starttime,
                                   endtime=endtime, interval=args.interval)

    # Only one position was requested, so the first entry of each time step is used
    # (presumably values are indexed [time][position] - confirm against surfex).
    nt = len(ts.times)
    vals = np.array([ts.values[i][0] for i in range(nt)], dtype=float)

    # Fall back to the station id from the command line when the file has none.
    ts_stid = str(ts.stids[0])
    if ts_stid == "NA" and stid is not None:
        ts_stid = stid

    plt.title(f"var= {ts.varname} lon: {lon} lat: {lat} stid: {ts_stid}")
    plt.plot(ts.times, vals)
    if args.output is None:
        plt.show()
    else:
        plt.savefig(args.output)


if __name__ == "__main__":
    # Script entry point: parse the command line (program name stripped) and plot.
    plot_timeseries_from_json(parse_plot_timeseries_args(sys.argv[1:]))
