#!python

import sys
import argparse
import logging
import json
import multiprocessing
from datetime import datetime

from urllib.parse import urlparse
from statistics import mode, StatisticsError
from fractions import Fraction

import numpy
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import requests
import requests_cache
import tldextract

from requests_futures.sessions import FuturesSession
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, TooManyRedirects

from sklearn.metrics import auc

from aiu import ArchiveItCollection
from aiu import convert_LinkTimeMap_to_dict
from aiu import generate_archiveit_urits
from aiu import get_uri_responses

# Number of CPUs on this machine; used as the worker count when TimeMaps are
# downloaded concurrently.
cpu_count = multiprocessing.cpu_count()

# Configure the root logger once, at import time, so that every message
# emitted by this script shares the same format and threshold.
loglevel = logging.INFO
logging.basicConfig(
    level=loglevel,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
)
logger = logging.getLogger()

def parse_data_for_mementos_list(timemap_data):
    """Flatten TimeMap data into a sorted list of memento records.

    Args:
        timemap_data: mapping of URI-T to a TimeMap dictionary. Only entries
            containing an ``"original_uri"`` key are used; each of those must
            provide its mementos under ``["mementos"]["list"]`` as
            dictionaries with ``"uri"`` and ``"datetime"`` keys.

    Returns:
        A list of ``(memento_datetime, urim, urir)`` tuples in ascending
        order (tuples compare by memento-datetime first).
    """

    records = []

    for timemap in timemap_data.values():

        # TimeMaps without an original resource cannot be tied to a seed
        if "original_uri" not in timemap:
            continue

        seed_uri = timemap["original_uri"]

        for entry in timemap["mementos"]["list"]:
            memento_uri = entry["uri"]
            records.append((entry["datetime"], memento_uri, seed_uri))

    return sorted(records)

def convert_mementos_list_into_mdts_pct_urim_pct_and_urir_pct(mementos, enddate=None):
    """Convert memento records into cumulative growth-curve percentages.

    Args:
        mementos: sequence of ``(memento_datetime, urim, urir)`` records
            sorted by memento-datetime (the second pass stops at the first
            record that is not earlier than the last datetime, so the order
            matters).
        enddate: optional datetime. When given, only mementos strictly
            earlier than it are counted and it is used as the end of the
            time axis.

    Returns:
        Three parallel lists ``(mdts_pct, urims_pct, urirs_pct)``: the
        fraction of elapsed time, the cumulative fraction of distinct URI-Ms
        and the cumulative fraction of distinct URI-Rs at each memento. A
        final ``1.0`` point is appended to all three when the time axis does
        not otherwise reach 100%.

    Raises:
        ValueError: if there is no memento to process (empty input, or every
            memento falls on or after ``enddate``).
    """

    logger = logging.getLogger(__name__)

    logger.info("counting URI-Ms and URI-Rs")

    # First pass: find the datetimes in range and the number of distinct
    # URI-Ms / URI-Rs. Sets give O(1) membership instead of scanning lists.
    mdts = []
    all_urims = set()
    all_urirs = set()

    for memento in mementos:

        mdt, urim, urir = memento[:3]

        if enddate and mdt >= enddate:
            continue

        mdts.append(mdt)
        all_urims.add(urim)
        all_urirs.add(urir)

    urimtotal = len(all_urims)
    urirtotal = len(all_urirs)

    logger.info("There are {} URI-Rs total".format(urirtotal))
    logger.info("There are {} URI-Ms total".format(urimtotal))

    if not mdts:
        raise ValueError(
            "no mementos available for computing growth percentages")

    firstmdt = min(mdts)

    logger.info("first memento-datetime: {}".format(firstmdt))

    logger.info("Calculating end date, taking specified enddate of {} "
        "into consideration".format(enddate))

    lastmdt = enddate if enddate else max(mdts)

    logger.info("last memento-datetime: {}".format(lastmdt))

    total_seconds = (lastmdt - firstmdt).total_seconds()

    logger.info("Total seconds is {}".format(total_seconds))

    mdts_pct = []
    urims_pct = []
    urirs_pct = []

    if total_seconds == 0:

        if firstmdt != lastmdt:
            raise Exception("something strange happened")

        # Every memento shares one datetime, so there is no time axis to
        # normalise against: report each distinct URI-M at 100%.
        mdts_pct = [1.0] * urimtotal
        urims_pct = [1.0] * urimtotal
        urirs_pct = [1.0] * urimtotal

        logger.warning("only 1 memento-datetime in collection, total seconds is 0")
        logger.warning("creating a list of {} mementos at 100%".format(urimtotal))

    else:

        # Second pass: walk the mementos in time order and record the
        # cumulative fraction of time, URI-Ms and URI-Rs seen so far.
        seen_urims = set()
        seen_urirs = set()

        for memento in mementos:

            mdt, urim, urir = memento[:3]

            # input is sorted, so nothing after this point is in range
            if not mdt < lastmdt:
                break

            mdts_pct.append( (mdt - firstmdt).total_seconds() / total_seconds )

            seen_urims.add(urim)
            urims_pct.append(len(seen_urims) / urimtotal)

            seen_urirs.add(urir)
            urirs_pct.append(len(seen_urirs) / urirtotal)

    logger.info("max datetime percentages: {}".format(max(mdts_pct)))

    # when enddate is set, we may reach the maximum % of URI-Rs
    # and URI-Ms before we get to the maximum % of time
    if max(mdts_pct) < 1.0:
        logger.info("adding an additional data point of 1.0 to all percentages")
        mdts_pct.append(1.0)
        urirs_pct.append(1.0)
        urims_pct.append(1.0)

    logger.info("memento records: {}".format(len(mementos)))
    logger.info("# MDT records: {}".format(len(mdts_pct)))
    logger.info("# URI-R records: {}".format(len(urirs_pct)))
    logger.info("# URI-M records: {}".format(len(urims_pct)))

    return mdts_pct, urims_pct, urirs_pct


def draw_both_axes_pct_growth(
        collection_id, collection_name, collected_by,
        mdts_pct, urims_pct, urirs_pct, outputfile,
        shape_percentage_of_whole=None,
        whole_text_x=None, whole_text_y=None, shape_name=None,
        enddate=None
    ):
    """Plot seed and seed-memento growth curves and save the figure.

    Both curves are drawn against ``mdts_pct`` on axes that run from 0% to
    100% in 10% steps.

    Args:
        collection_id, collection_name, collected_by: values shown in the
            plot title.
        mdts_pct: x values (fraction of elapsed time, 0.0-1.0).
        urims_pct: y values of the "% seed mementos" curve.
        urirs_pct: y values of the "% seeds" curve.
        outputfile: destination handed to ``Figure.savefig``.
        shape_percentage_of_whole: optional fraction; when truthy, an
            annotation stating this percentage and ``shape_name`` is drawn
            at ``(whole_text_x, whole_text_y)``.
        whole_text_x, whole_text_y: annotation position in data coordinates;
            required when ``shape_percentage_of_whole`` is given (0 is a
            valid coordinate).
        shape_name: behavior name used in the annotation text.
        enddate: optional end date mentioned in the x-axis label.

    Raises:
        ValueError: if an annotation is requested without a position.
    """

    # Explicit validation instead of ``assert``: asserts vanish under -O and
    # a truthiness check would wrongly reject a coordinate of 0.
    if shape_percentage_of_whole and (whole_text_x is None or whole_text_y is None):
        raise ValueError(
            "whole_text_x and whole_text_y are required when "
            "shape_percentage_of_whole is given")

    fig, ax = plt.subplots(1)

    try:
        fig.subplots_adjust(wspace=0.4, hspace=0.4)
        fig.set_figheight(10)
        fig.set_figwidth(10)

        ax.plot(mdts_pct, urims_pct, label="% seed mementos", color="#ff0000", linewidth=3.0)
        ax.plot(mdts_pct, urirs_pct, label="% seeds", color="#00ff00", linewidth=3.0)

        if shape_percentage_of_whole:
            fmt_pct = "{0:.2f}%".format(shape_percentage_of_whole * 100)
            ax.text(whole_text_x, whole_text_y, "{}\nof collections\nhave the behavior\n{}".format(
                    fmt_pct, shape_name),
                   fontsize=32, color="#777777")

        ax.set_title("Collection {}:\n{}\nCollected by {}".format(
            collection_id, collection_name, collected_by), fontsize=20)

        if enddate:
            ax.set_xlabel(
                "Time Percentage from Start of Collection to {}".format(enddate),
                fontsize=20)
        else:
            ax.set_xlabel("Time Percentage of Life of Collection", fontsize=20)

        ax.set_ylabel("URI Percentage", fontsize=18)

        ax.set_ylim(-.05, 1.05)
        ax.set_xlim(-.05, 1.05)

        ticks = numpy.arange(0, 1.05, 0.1)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)

        # thanks https://stackoverflow.com/questions/31357611/format-y-axis-as-percent
        tick_labels = ['{:3.0f}%'.format(tick * 100) for tick in ticks]
        ax.set_xticklabels(tick_labels)
        ax.set_yticklabels(tick_labels)

        # Tick.label is no longer available in current Matplotlib releases;
        # tick_params sets the label size on both axes in one call.
        ax.tick_params(axis='both', which='major', labelsize=16)

        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc="upper left", fontsize=20)

        fig.savefig(outputfile)

    finally:
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)

def dtconverter(o):
    """Serialize datetime objects for ``json.dump(..., default=dtconverter)``.

    Args:
        o: an object the JSON encoder could not serialize on its own.

    Returns:
        The ``str()`` form of ``o`` when it is a ``datetime``.

    Raises:
        TypeError: for any other type. Returning ``None`` implicitly would
            make the encoder silently write ``null`` in place of the value.
    """

    if isinstance(o, datetime):
        return str(o)

    raise TypeError(
        "Object of type {} is not JSON serializable".format(type(o).__name__))

def process_arguments(args):
    """Parse the command line of the seed statistics script.

    Args:
        args: the full argument vector, program name first (``sys.argv``).

    Returns:
        The ``argparse.Namespace`` with ``collection_id``,
        ``file_listing_urims``, ``outputfile``, ``exclude_pattern``,
        ``growthcurve_filename`` and ``cachefile``.
    """

    # description and formatter_class belong to ArgumentParser, not to
    # str.format(), where they would be silently ignored.
    parser = argparse.ArgumentParser(
        prog="{}".format(args[0]),
        description="*THIS SOFTWARE IS EXPERIMENTAL* Produces seed statistics for an Archive-It collection or a list of URI-Ms.",
        formatter_class=argparse.RawTextHelpFormatter)

    # exactly one input source must be chosen
    group = parser.add_mutually_exclusive_group(required=True)

    group.add_argument('-c', '--collection_id', dest="collection_id", 
        help="The identifier of the Archive-It collection to process.")

    group.add_argument('-i', '--file_listing_urims', dest="file_listing_urims",
        help="The name of a file listing URI-Ms to process for statistics.")

    parser.add_argument('-o', '--outputfile', dest='outputfile',
        help="The name of the file to write the results to.", 
        required=True)

    parser.add_argument('-x', '--exclude-pattern', dest='exclude_pattern',
        help="The pattern of the URI-R to exclude from calculations.",
        default=None, required=False)

    parser.add_argument('-g', '--growth-curve-file', dest='growthcurve_filename',
        help="If present, draw a growth curve and write it to the filename specified.",
        default=None, required=False)

    parser.add_argument('-cf', dest='cachefile',
        help="The SQLite file to use for caching",
        default="/tmp/generate_seed_statistics")

    # parse the vector we were given (minus the program name) rather than
    # implicitly re-reading sys.argv
    return parser.parse_args(args[1:])

def calculate_domain_diversity(uri_list):
    """Return the diversity of registered domains among the given URIs.

    The score is ``(u - 1) / (n - 1)`` where ``u`` is the number of distinct
    registered domains and ``n`` the number of URIs: 0.0 when every URI
    shares one domain, 1.0 when every URI has its own.

    Args:
        uri_list: iterable of URI strings.

    Returns:
        The diversity score as a float. With fewer than two URIs the formula
        is undefined (it divides by zero for one URI), so 0.0 is returned.
    """

    uris = list(uri_list)

    # nothing to compare against: treat as no diversity
    if len(uris) < 2:
        return 0.0

    domains = [tldextract.extract(uri).registered_domain for uri in uris]

    return (len(set(domains)) - 1) / (len(domains) - 1)

def path_depth(uri):
    """Return the depth of a URI's path.

    The depth is the number of non-empty ``/``-separated path segments, plus
    one when the URI carries a query string.

    Args:
        uri: the URI string to measure.

    Returns:
        The depth as an int; 0 for a URI with no path segments and no query.
    """

    parsed = urlparse(uri)

    segment_count = sum(1 for segment in parsed.path.split('/') if segment)
    query_bonus = 1 if parsed.query else 0

    return segment_count + query_bonus

def calculate_path_depth_diversity(uri_list):
    """Return the diversity of path depths among the given URIs.

    The score is ``(u - 1) / (n - 1)`` where ``u`` is the number of distinct
    path depths (see ``path_depth``) and ``n`` the number of URIs.

    Args:
        uri_list: iterable of URI strings.

    Returns:
        The diversity score as a float. With fewer than two URIs the formula
        is undefined (it divides by zero for one URI), so 0.0 is returned.
    """

    uris = list(uri_list)

    # nothing to compare against: treat as no diversity
    if len(uris) < 2:
        return 0.0

    depths = [path_depth(uri) for uri in uris]

    return (len(set(depths)) - 1) / (len(depths) - 1)

def most_frequent_seed_uri_path_depth(uri_list):
    """Return the most common path depth among the given URIs.

    Args:
        uri_list: iterable of URI strings.

    Returns:
        The most frequent depth (see ``path_depth``), or ``None`` when
        ``uri_list`` is empty. When ``statistics.mode`` cannot decide on a
        single value and raises ``StatisticsError``, the largest depth among
        the most frequent ones is returned.
    """

    depths = [path_depth(uri) for uri in uri_list]

    # no URIs, no mode; previously this ended in an unrelated IndexError
    if not depths:
        return None

    try:
        return mode(depths)
    except StatisticsError:
        # Tie-break exactly as before: highest count first, then the
        # largest depth. Only distinct depths need to be ranked.
        return max(set(depths), key=lambda depth: (depths.count(depth), depth))

def calculate_top_level_path_percentage(uri_list):
    """Return the fraction of URIs that point at the top level of a site.

    A URI is top-level when its ``path_depth`` is 0.

    Args:
        uri_list: iterable of URI strings.

    Returns:
        A float between 0.0 and 1.0; 0.0 when ``uri_list`` is empty (instead
        of dividing by zero).
    """

    depths = [path_depth(uri) for uri in uri_list]

    if not depths:
        return 0.0

    return depths.count(0) / len(depths)

def list_generator(input_list):
    """Yield items from a list for as long as the list is non-empty.

    The list is walked from the start again and again until it is empty, so
    the consumer is expected to remove items from ``input_list`` as it
    finishes with them. Items that are not removed are yielded again on a
    later pass; if the consumer never removes anything the generator never
    terminates.

    Args:
        input_list: the list to iterate; it is read, never modified, here.

    Yields:
        The items of ``input_list``, possibly more than once each.
    """

    logger.debug("list generator called")

    # The outer loop restarts the scan after each full pass: removing an
    # element while the inner loop runs shifts the remaining ones, so a
    # single pass may not visit every item.
    while len(input_list) > 0:
        for item in input_list:
            logger.debug("list now has {} items".format(len(input_list)))
            logger.debug("yielding {}".format(item))
            yield item

def calculate_number_of_mementos(timemap_data):
    """Count the mementos listed across all TimeMaps.

    Args:
        timemap_data: mapping of URI-T to a TimeMap dictionary whose mementos
            are found under ``['mementos']['list']``.

    Returns:
        The total number of mementos. TimeMaps lacking the expected keys are
        logged and contribute nothing to the total.
    """

    memento_total = 0

    for urit, timemap in timemap_data.items():

        try:
            memento_total = memento_total + len(timemap['mementos']['list'])
        except KeyError:
            logger.exception("cannot incorporate mementos into total count from URI-T {}".format(urit))

    return memento_total

def calculate_percentage_querystring(uri_list):
    """Return the fraction of URIs that contain a query string.

    Args:
        uri_list: sized collection of URI strings.

    Returns:
        A float between 0.0 and 1.0; 0.0 when ``uri_list`` is empty (instead
        of dividing by zero).
    """

    total = len(uri_list)

    if total == 0:
        return 0.0

    with_query = sum(1 for uri in uri_list if urlparse(uri).query != '')

    return with_query / total

def calculate_memento_seed_ratio(timemap_data):
    """Return the memento-to-seed ratio in lowest terms, e.g. ``"3:2"``.

    Args:
        timemap_data: mapping of URI-T to TimeMap dictionary; each entry
            counts as one seed.

    Returns:
        A ``"mementos:seeds"`` string reduced to lowest terms.

    Raises:
        ZeroDivisionError: if ``timemap_data`` is empty.
    """

    memcount = calculate_number_of_mementos(timemap_data)
    seedcount = len(timemap_data)

    # Reduce once and read both parts from the same Fraction. Overwriting
    # memcount with the numerator before computing the denominator produced
    # wrong ratios (6 mementos over 4 seeds came out as "3:4").
    ratio = Fraction(memcount, seedcount)

    return "{}:{}".format(ratio.numerator, ratio.denominator)

def calculate_mementos_per_seed(timemap_data):
    """Return the average number of mementos per seed.

    Args:
        timemap_data: mapping of URI-T to TimeMap dictionary; each entry
            counts as one seed.

    Returns:
        Total memento count divided by the number of entries, as a float.
    """

    return calculate_number_of_mementos(timemap_data) / len(timemap_data)

def get_datetimes_list(timemap_data):
    """Collect the memento-datetimes found in all TimeMaps.

    Args:
        timemap_data: mapping of URI-T to a TimeMap dictionary whose mementos
            are found under ``['mementos']['list']``, each with a
            ``'datetime'`` entry.

    Returns:
        A list of every memento's ``'datetime'`` value, in iteration order.
        A TimeMap with a missing key is logged; datetimes gathered from it
        before the failure are kept.
    """

    collected = []

    for urit, timemap in timemap_data.items():

        try:
            for entry in timemap['mementos']['list']:
                collected.append(entry['datetime'])
        except KeyError:
            logger.exception("cannot acquire datetimes from URI-T: {}".format(urit))

    return collected

def get_first_memento_datetime(timemap_data):
    """Return the earliest memento-datetime found in ``timemap_data``."""

    return min(get_datetimes_list(timemap_data))

def get_last_memento_datetime(timemap_data):
    """Return the latest memento-datetime found in ``timemap_data``."""

    return max(get_datetimes_list(timemap_data))

def process_timemaps_for_mementos(urit_list):
    """Download the given TimeMaps concurrently and parse them.

    Args:
        urit_list: the URI-Ts to download.

    Returns:
        A ``(timemap_data, errors_data)`` tuple. ``timemap_data`` maps each
        URI-T answered with HTTP 200 to its parsed TimeMap. ``errors_data``
        maps every other URI-T to ``{"type": "http_error", "data":
        response}`` or ``{"type": "exception", "data": exception}``.
    """

    timemap_data = {}
    errors_data = {}

    with FuturesSession(max_workers=cpu_count) as session:

        futures = get_uri_responses(session, urit_list)

        # Visit each future once and block on its result, while the session
        # is still open. This replaces polling done() in a loop over a list
        # that was being modified during iteration.
        for urit in list(futures.keys()):

            logger.debug("waiting for URI-T {} to finish downloading".format(urit))

            try:
                response = futures[urit].result()

            except ConnectionError as e:

                logger.warning("There was a connection error while attempting "
                    "to download URI-T {}".format(urit))

                errors_data[urit] = {
                    "type": "exception",
                    "data": e
                }
                continue

            except TooManyRedirects as e:

                logger.warning("There were too many redirects while attempting "
                    "to download URI-T {}".format(urit))

                errors_data[urit] = {
                    "type": "exception",
                    "data": e
                }
                continue

            logger.debug("URI-T {} is done, extracting content".format(urit))

            if response.status_code == 200:

                logger.info("adding TimeMap content for URI-T {}".format(
                    urit))

                timemap_data[urit] = convert_LinkTimeMap_to_dict(
                    response.text, skipErrors=True)

            else:

                errors_data[urit] = {
                    "type": "http_error",
                    "data": response
                }

    return timemap_data, errors_data

def get_collection_metadata_and_timemaps(collection_id):
    """Placeholder that is not implemented yet; it does nothing and returns None."""
    return None

if __name__ == '__main__':

    # Workflow:
    #   1. Gather the seeds and their TimeMaps, either from an Archive-It
    #      collection (-c) or from a file listing URI-Ms (-i).
    #   2. For a collection, compute memento statistics and, on request,
    #      draw the growth curve.
    #   3. Compute the seed URI statistics and write everything out as JSON.

    args = process_arguments(sys.argv)

    logger.info("Starting Archive-It seed statistics generation *NOTE THAT THIS IS EXPERIMENTAL CODE*")

    requests_cache.install_cache(args.cachefile, backend='sqlite')
    session = requests.Session()

    output = {}

    # only report a growth curve at the end if one was actually drawn
    growth_curve_written = False

    if args.collection_id is not None:

        logger.info("Using collection ID {}".format(args.collection_id))
        logger.info("extracting information from Archive-It collection {}".format(args.collection_id))

        output['collection_id'] = args.collection_id

        # get the collection metadata
        aic = ArchiveItCollection(collection_id=args.collection_id, session=session)
        seed_uris = aic.list_seed_uris()

        output['collection_name'] = aic.get_collection_name()
        output['collected_by'] = aic.get_collectedby()

        urit_list = generate_archiveit_urits(args.collection_id, seed_uris)

        timemap_data, errors_data = process_timemaps_for_mementos(urit_list)

        # this really only makes sense for Archive-It collections
        logger.info("calculating number of mementos...")
        output["number_of_mementos"] = calculate_number_of_mementos(timemap_data)

        if args.growthcurve_filename is not None:
            logger.info("generating growth curve for collection {}".format(args.collection_id))
            mementos_list = parse_data_for_mementos_list(timemap_data)
            mdts_pct, urims_pct, urirs_pct = \
                convert_mementos_list_into_mdts_pct_urim_pct_and_urir_pct(
                mementos_list)

            draw_both_axes_pct_growth(
                args.collection_id,
                output['collection_name'],
                output['collected_by'],
                mdts_pct, urims_pct, urirs_pct,
                args.growthcurve_filename
            )
            growth_curve_written = True

            output['auc_memento_curve'] = auc(mdts_pct, urims_pct) 
            output['auc_seed_curve'] = auc(mdts_pct, urirs_pct)
            output['auc_memento_minus_diag'] = output['auc_memento_curve'] - 0.5
            output['auc_seed_minus_diag'] = output['auc_seed_curve'] - 0.5
            output['auc_seed_minus_auc_memento'] = output['auc_seed_curve'] - output['auc_memento_curve']

        logger.info("calculating memento to seed ratio...")
        output['memento_seed_ratio'] = calculate_memento_seed_ratio(timemap_data)

        logger.info("calculating mementos per seed...")
        output['mementos_per_seed'] = calculate_mementos_per_seed(timemap_data)

        # each lookup scans every TimeMap, so compute both ends only once
        logger.info("calculating first memento datetime in collection...")
        first_mdt = get_first_memento_datetime(timemap_data)
        output['first_memento_datetime'] = first_mdt

        logger.info("calculating last memento datetime in collection...")
        last_mdt = get_last_memento_datetime(timemap_data)
        output['last_memento_datetime'] = last_mdt

        logger.info("calculating lifespan of collection...")
        output['lifespan_secs'] = (last_mdt - first_mdt).total_seconds()
        output['lifespan_mins'] = output['lifespan_secs'] / 60
        output['lifespan_hours'] = output['lifespan_secs'] / 60 / 60
        output['lifespan_days'] = output['lifespan_secs'] / 60 / 60 / 24
        output['lifespan_weeks'] = output['lifespan_secs'] / 60 / 60 / 24 / 7
        output['lifespan_years'] = output['lifespan_secs'] / 60 / 60 / 24 / 365

    elif args.file_listing_urims is not None:

        logger.info("reading list of URI-Ms from {}".format(args.file_listing_urims))

        if args.growthcurve_filename is not None:
            logger.warning("growth curves are only drawn for Archive-It "
                "collections, ignoring {}".format(args.growthcurve_filename))
        
        seed_uris = []
        urit_list = []

        with open(args.file_listing_urims) as f:
            for line in f:
                line = line.strip()

                # a blank line is not a URI-M; requesting it would fail
                if not line:
                    continue

                logger.info("reading line {}".format(line))

                r = session.get(line)

                try:
                    urir = r.links['original']['url']
                    logger.info("found urir {}".format(urir))
                except KeyError:
                    logger.exception("failed to find URI-R for URI-M {}".format(line))
                    continue

                if args.exclude_pattern is not None and args.exclude_pattern in urir:
                    continue

                seed_uris.append( urir )

                try:
                    urit_list.append( r.links['timemap']['url'] )
                except KeyError:
                    logger.exception("failed to find the URI-T for URI-M {}".format(line))

        timemap_data, errors_data = process_timemaps_for_mementos(urit_list)

        output['input_filename'] = args.file_listing_urims

    else:
        print("Please specify a collection ID with the -c argument or a file listing URI-Ms with the -i argument")
        sys.exit(255)

    logger.info("calculating statistics")

    logger.info("calculating number of seeds...")
    output["number_of_seeds"] = len(seed_uris)
    logger.info("found {} seeds".format(len(seed_uris)))

    logger.info("calculating diversity scores...")
    output['domain_diversity'] = calculate_domain_diversity(seed_uris)

    logger.info("calculating path depth diversity...")
    output['path_depth_diversity'] = calculate_path_depth_diversity(seed_uris)

    logger.info("calculating most frequent path depth...")
    output['most_frequent_path_depth'] = most_frequent_seed_uri_path_depth(seed_uris)

    logger.info("calculating the percentage of top-level URIs...")
    output['top_level_percentage'] = calculate_top_level_path_percentage(seed_uris)

    logger.info("calculating query string percentage...")
    output['query_string_percentage'] = calculate_percentage_querystring(seed_uris)

    with open(args.outputfile, 'w') as f:
        json.dump(output, f, indent=4, default=dtconverter)

    if growth_curve_written:
        logger.info("growth curve has been written to {}".format(args.growthcurve_filename))
        
    logger.info("Data has been written out to {}".format(args.outputfile))
    logger.info("Finished run")    
