#!/usr/bin/env python
from zipfile import ZipFile
from datetime import date
from pandas import read_csv, Series, date_range, concat, merge, to_datetime
from matplotlib.pyplot import show, title
from os import chdir
from pathlib import Path
from argparse import ArgumentParser, RawTextHelpFormatter as ArgFormatter


# Weekday column names of calendar.txt, ordered so that the position of a name
# equals pandas' DatetimeIndex.weekday value for that day (monday == 0).
DAY_NAMES = [
    'monday',
    'tuesday',
    'wednesday',
    'thursday',
    'friday',
    'saturday',
    'sunday',
]


def main(sys_args):
    '''Command line entry point: parses sys_args, then shows the trips-per-day graph.

    The working directory is first moved to the parent of the directory holding this
    script, so a relative --gtfs path is resolved from there.

    Args:
        sys_args (list of str): the command line arguments, without the program name
    '''
    chdir(str(Path(__file__).resolve().parents[1]))
    parser = ArgumentParser(description='Display the number of trips per day of a GTFS', formatter_class=ArgFormatter)
    parser.add_argument('-g', '--gtfs', metavar='gtfs_path', dest='gtfs_path', type=str, required=True,
                        help='Input GTFS file path')
    # Both optional bounds share the same shape: an int in YYYYMMDD format
    optional_bounds = (
        ('-s', '--start', 'start_date', 'The start date of the graph in YYYYMMDD format'),
        ('-e', '--end', 'end_date', 'The end date of the graph in YYYYMMDD format'),
    )
    for short_flag, long_flag, target, help_text in optional_bounds:
        parser.add_argument(short_flag, long_flag, metavar=target, dest=target, type=int, required=False,
                            help=help_text)
    options = parser.parse_args(args=sys_args)
    trips_per_date = compute_date_trip_number(options.gtfs_path)
    plot_dates_trips(trips_per_date, options.start_date, options.end_date)
    show()


# ------------------- Errors


class Error(Exception):
    """Root of the exception hierarchy raised by this module."""


class GTFSError(Error):
    '''Exception raised for errors in the gtfs.

    str(exception) gives "<table>: <message>".

    Attributes:
        table -- table in which the error occurred
        message -- explanation of the error
    '''

    def __init__(self, table, message):
        # Forwarding the arguments keeps args, repr() and pickling working like any exception
        super().__init__(table, message)
        self.table = table
        self.message = message

    def __str__(self):
        return '{}: {}'.format(self.table, self.message)


class TimeError(Error):
    '''Exception raised for errors in time input.

    str(exception) gives "<time>: <message>".

    Attributes:
        time -- input time in which the error occurred
        message -- explanation of the error
    '''

    def __init__(self, time, message):
        # Forwarding the arguments keeps args, repr() and pickling working like any exception
        super().__init__(time, message)
        self.time = time
        self.message = message

    def __str__(self):
        return '{}: {}'.format(self.time, self.message)


class DateError(Error):
    '''Exception raised for errors in dates.

    str(exception) gives "<date>: <message>".

    Attributes:
        date -- date in which the error occurred
        message -- explanation of the error
    '''

    def __init__(self, date, message):
        # Forwarding the arguments keeps args, repr() and pickling working like any exception
        super().__init__(date, message)
        self.date = date
        self.message = message

    def __str__(self):
        return '{}: {}'.format(self.date, self.message)

# ------------------- Zip helpers


def filename_in_zip(zip_path, file_name):
    '''Tells whether a member named file_name exists in a zip archive

    Args:
        zip_path (str or file-like): the zip file, anything accepted by zipfile.ZipFile
        file_name (str): the exact member name you are looking for (e.g. 'calendar.txt')
    Returns:
        (bool): True if the file is in the zip, False otherwise
    '''
    with ZipFile(zip_path) as myzip:
        return file_name in myzip.namelist()


def load_zip_subfile(zip_path, subfile_name):
    '''Reads one CSV member of a zip archive into a pandas DataFrame

    The member is decoded as UTF-8 with an optional byte order mark, and blanks
    following each delimiter are skipped.

    Args:
        zip_path (str or file-like): the zip file, anything accepted by zipfile.ZipFile
        subfile_name (str): the name of the member to read
    Returns:
        (DataFrame): the parsed content of the member
    '''
    with ZipFile(zip_path) as archive, archive.open(subfile_name) as member:
        return read_csv(member, encoding='utf-8-sig', skipinitialspace=True)

# ------------------- Dates helpers


def yyyymmdd_to_date(date_int):
    '''Converts a YYYYMMDD date (e.g. 20240131) into a datetime.date

    Args:
        date_int (int): a date in YYYYMMDD integer format. Any value accepted by int()
            also works: numpy integers, digit strings such as '20240131', or integral
            floats (pandas reads an int column containing blanks as float).
    Returns:
        (date): the corresponding date
    Raises:
        ValueError: if the value cannot be converted to int or is not a valid calendar date
    '''
    year, month_day = divmod(int(date_int), 10000)
    month, day = divmod(month_day, 100)
    return date(year, month, day)


def time_to_seconds_since_midnight(time_string):
    '''Converts HH:MM:SS into seconds since midnight. For example '01:02:03' returns 3723

    Args:
        time_string (string): HH:MM:SS string. The leading zero of the hours may be omitted
        HH may be more than 23 if the time is on the following day
    Returns:
        (int): number of seconds since midnight
    Raises:
        TimeError: if time_string is not made of exactly three ':'-separated integers,
            or is not a string at all (e.g. NaN read by pandas from an empty cell)
    '''
    try:
        hours, minutes, seconds = map(int, time_string.split(':'))
    except (ValueError, AttributeError) as err:
        # ValueError: non-integer part or wrong number of parts; AttributeError: no .split (not a str)
        raise TimeError(time_string, 'Bad HH:MM:SS') from err
    return hours * 3600 + minutes * 60 + seconds

# ------------------- GTFS helpers


def calculate_min_max_calendar(gtfs_path):
    '''Computes the first and last known dates of a gtfs

    calendar.txt and calendar_dates.txt are each read only if present in the archive,
    and a file without any row is ignored.

    Args:
        gtfs_path (str): the path to the gtfs
    Returns:
        (tuple(int, int)): (min_date, max_date). The first and last dates of the gtfs in YYYYMMDD int format
    Raises:
        GTFSError: if neither calendar.txt nor calendar_dates.txt provides any date
    '''
    first_dates = []
    last_dates = []
    if filename_in_zip(gtfs_path, 'calendar.txt'):
        calendar = load_zip_subfile(gtfs_path, 'calendar.txt')
        if not calendar.empty:
            first_dates.append(calendar['start_date'].min())
            last_dates.append(calendar['end_date'].max())
    if filename_in_zip(gtfs_path, 'calendar_dates.txt'):
        calendar_dates = load_zip_subfile(gtfs_path, 'calendar_dates.txt')
        if not calendar_dates.empty:
            first_dates.append(calendar_dates['date'].min())
            last_dates.append(calendar_dates['date'].max())
    if not first_dates:
        raise GTFSError('calendar', 'no date found in calendar.txt or calendar_dates.txt')
    return min(first_dates), max(last_dates)


def count_trips_from_frequencies_row(row):
    '''Number of departures described by one row of a GTFS frequencies file

    Departures at start_time, start_time + headway_secs, ... are counted as long as
    they are not later than end_time (hence the + 1 for the first one).
    NOTE(review): a departure exactly at end_time is counted; confirm this matches
    the GTFS definition of end_time for the feeds being analysed.

    Args:
        row (dict or Series): one frequencies record with 'start_time', 'end_time'
            (HH:MM:SS strings) and 'headway_secs'
    Returns:
        (int): the number of trips
    '''
    window_start = time_to_seconds_since_midnight(row['start_time'])
    window_end = time_to_seconds_since_midnight(row['end_time'])
    whole_headways = (window_end - window_start) // row['headway_secs']
    return whole_headways + 1


def get_dates_from_calendar_for_service(calendar, service_id):
    '''Computes the service dates according to a calendar file for a specified service

    Args:
        calendar (dataframe): a dataframe containing the calendar.txt file of a GTFS
        service_id (int or str): the id of the service
    Returns:
        (Serie): A serie of 1 named 'occurences' with the computed dates as index.
            It is empty (with an empty DatetimeIndex) when the service is not in calendar.
    Raises:
        GTFSError: if the service appears on more than one row of the calendar
    '''
    service = calendar.loc[calendar['service_id'] == service_id, :]
    if len(service) > 1:
        # str(): read_csv parses numeric ids as ints, which cannot be concatenated to a str
        raise GTFSError('calendar', 'the service ' + str(service_id) + ' appears more than one time in the table')
    if len(service) == 0:
        # Typed, datetime-indexed empty Series so it aligns cleanly with other date-indexed series
        days = Series([], index=to_datetime([]), dtype='int64')
    else:
        row = service.to_dict(orient='records')[0]
        start_date = yyyymmdd_to_date(row['start_date'])
        end_date = yyyymmdd_to_date(row['end_date'])
        # DAY_NAMES starts on monday, matching DatetimeIndex.weekday (monday == 0)
        weekmask = [row[day_name] for day_name in DAY_NAMES]
        rng = date_range(start_date, end_date)
        days = Series(rng.weekday, index=rng).apply(lambda weekday: weekmask[weekday])
    days.name = 'occurences'
    return days.loc[days == 1]


def get_dates_from_calendar_dates_for_service(calendar_dates, service_id):
    '''Computes the exception dates according to a calendar_dates file for a specified service

    Args:
        calendar_dates (dataframe or None): a dataframe containing the calendar_dates.txt file of a GTFS,
            or None when the GTFS has no calendar_dates.txt
        service_id (int or str): the id of the service
    Returns:
        (Serie): A serie named 'effect' of exception type (1 or -1) with the computed dates as index.
            exception_type 1 maps to 1; any other value maps to -1.
    '''
    if calendar_dates is None:
        # No calendar_dates.txt means no exception at all for this service
        return Series([], index=to_datetime([]), dtype='int64', name='effect')
    service = calendar_dates.loc[calendar_dates['service_id'] == service_id]
    effect = service['exception_type'].apply(lambda x: 1 if x == 1 else -1)
    # Dates are YYYYMMDD integers in the file: parse the whole column at once
    effect.index = to_datetime(service['date'].astype(str), format='%Y%m%d')
    effect.name = 'effect'
    return effect


def compute_dates_for_service(calendar, calendar_dates, service_id):
    '''Computes the service dates for a specified service taking into account calendar and calendar_dates

    Args:
        calendar (dataframe): a dataframe containing the calendar.txt file of a GTFS
        calendar_dates (dataframe or None): a dataframe containing the calendar_dates.txt file of a GTFS,
            or None when the GTFS has no calendar_dates.txt
        service_id (int or str): the id of the service
    Returns:
        (Serie): A serie named 'occurences' equal to 1 on every date the service runs, with the dates as index
    '''
    calendar_days = get_dates_from_calendar_for_service(calendar, service_id)
    if calendar_dates is None:
        # Without calendar_dates.txt the regular calendar alone defines the service days
        return calendar_days
    calendar_date_days = get_dates_from_calendar_dates_for_service(calendar_dates, service_id)
    service_days = concat([calendar_days, calendar_date_days], axis=1).fillna(0)
    # A service either runs on a date or not: an "added" exception (+1) on a day the
    # calendar already covers must not count that day twice, hence the cap at 1
    service_days['occurences'] = (service_days['occurences'] + service_days['effect']).clip(upper=1)
    return service_days.loc[service_days['occurences'] > 0, 'occurences']


# ------------------- Main function


def _daily_trips_per_service(gtfs_path):
    '''Number of trips run on one day of each service, as a dict {service_id: count}.

    A trip listed in frequencies.txt counts once per departure computed by
    count_trips_from_frequencies_row; any other trip counts once.
    '''
    trips = load_zip_subfile(gtfs_path, 'trips.txt')
    frequencies = None
    if filename_in_zip(gtfs_path, 'frequencies.txt'):
        frequencies = load_zip_subfile(gtfs_path, 'frequencies.txt')
    # A header-only frequencies.txt must not go through DataFrame.apply, which returns
    # a DataFrame instead of a Series on an empty frame
    if frequencies is None or frequencies.empty:
        trips['nb_occurences'] = 1
    else:
        frequencies['nb_occurences'] = frequencies.apply(count_trips_from_frequencies_row, axis=1)
        # A trip with several frequency windows gets one row per window; the sum below adds them up
        trips = merge(trips, frequencies, how='left', on='trip_id')
        # Trips absent from frequencies.txt run exactly once
        trips['nb_occurences'] = trips['nb_occurences'].fillna(1)
    return trips.groupby('service_id')['nb_occurences'].sum().to_dict()


def compute_date_trip_number(gtfs_path):
    '''Computes the number of trips for each date of a gtfs

    Args:
        gtfs_path (str): the path to the gtfs
    Returns:
        (Serie): A serie of trips number with the dates as index. Every date between the first and
            last date of the gtfs is present (0 when no trip runs that day).
    '''
    daily_trips_per_service = _daily_trips_per_service(gtfs_path)

    # Get all services: those of calendar.txt plus those only defined in calendar_dates.txt
    calendar = load_zip_subfile(gtfs_path, 'calendar.txt')
    services = set(calendar['service_id'])
    calendar_dates = None
    if filename_in_zip(gtfs_path, 'calendar_dates.txt'):
        calendar_dates = load_zip_subfile(gtfs_path, 'calendar_dates.txt')
        services |= set(calendar_dates['service_id'])

    gtfs_start_date, gtfs_end_date = calculate_min_max_calendar(gtfs_path)
    dates = date_range(start=yyyymmdd_to_date(gtfs_start_date), end=yyyymmdd_to_date(gtfs_end_date))

    # Seed every date of the gtfs with 0 so that days without any service still appear
    contributions = [Series(0, index=dates)]
    for service_id in services:
        trips_per_day = daily_trips_per_service.get(service_id, 0)
        if trips_per_day == 0:
            # Service declared in the calendar but used by no trip: nothing to add
            continue
        contributions.append(compute_dates_for_service(calendar, calendar_dates, service_id) * trips_per_day)

    # Single concat (instead of one per service), then add up all services of each date
    all_counts = concat(contributions)
    return all_counts.groupby(all_counts.index).sum()

# ------------------- Display helpers


def plot_dates_trips(dates_trips, start_date=None, end_date=None):
    '''Draws the number of trips per day between start_date and end_date

    Every monday of the displayed range gets a dashed grey vertical line. The title
    recalls the first and last dates of dates_trips and its busiest day.

    Args:
        dates_trips (Serie): A serie of trips number with the dates as index
        start_date (int): The start of the graph in YYYYMMDD format. Defaults to the first date of dates_trips.
        end_date (int): The end of the graph in YYYYMMDD format. Defaults to the last date of dates_trips.
    Returns:
        (Axes): The Axes object of the graph that displays the number of trips per day
    Raises:
        DateError: if start_date is after end_date, or if the range contains no date of dates_trips
    '''
    def as_text(day):
        return day.strftime('%Y-%m-%d')

    def requested_range():
        return '(' + as_text(start_date) + ', ' + as_text(end_date) + ')'

    gtfs_start_date = dates_trips.index.min()
    gtfs_end_date = dates_trips.index.max()
    start_date = gtfs_start_date if start_date is None else to_datetime(yyyymmdd_to_date(start_date))
    end_date = gtfs_end_date if end_date is None else to_datetime(yyyymmdd_to_date(end_date))

    if end_date < start_date:
        raise DateError(
            requested_range(),
            'The start_date is after the end_date. Note that if one is not given, the min or max of the gtfs is taken'
        )
    if end_date < gtfs_start_date or start_date > gtfs_end_date:
        raise DateError(
            requested_range(),
            'The chosen time range does not contain any date of the gtfs. '
            + 'The gtfs is between ' + as_text(gtfs_start_date) + ' and ' + as_text(gtfs_end_date)
        )

    in_range = (dates_trips.index >= start_date) & (dates_trips.index <= end_date)
    ax = dates_trips.loc[in_range].plot(figsize=(15, 10), marker='o')
    ax.set_xlim(start_date, end_date)

    shown_days = date_range(start_date, end_date)
    for monday in shown_days[shown_days.weekday == 0]:
        ax.axvline(monday, color='grey', linestyle='--', linewidth=2)

    title('\n'.join([
        'gtfs start date: ' + as_text(gtfs_start_date),
        'gtfs end date: ' + as_text(gtfs_end_date),
        'maximum number of trips: ' + str(int(dates_trips.max())) + ' (' + as_text(dates_trips.idxmax()) + ')',
    ]))
    return ax


if __name__ == '__main__':
    # Forward everything after the program name to the command line entry point
    import sys
    main(sys.argv[1:])
