#!/bin/env python

"""
Analysis of CDR3 string samples
"""

import os
import click 
import networkx as nx
import subprocess
import re
import time
import csv 
import pandas as pd
import numpy as np
from glob import glob
import warnings

import imnet

# Spark support is optional. findspark.init() is called before pyspark is
# imported; if either import fails, or a ValueError is raised along the way,
# the script warns and records that Spark is unavailable.
try: 
    import findspark
    findspark.init()
    import pyspark
    SPARK_FOUND = True
except (ImportError, ValueError): 
    # SPARK_FOUND is combined with the --spark/--no-spark option in `analyze`
    warnings.warn('No spark libraries found - make sure SPARK_HOME is set\nProceeding in single-core mode', RuntimeWarning)
    SPARK_FOUND = False

@click.group()
@click.option('--spark-config', default='./spark_config', envvar='SPARK_CONFIG_DIR', help='Spark configuration directory')
@click.option('--spark-master', default='local[4]', help='Spark master')
@click.option('--kind', default='graph', type=click.Choice(['graph','degrees','all']), help='Which kind of output to produce')
@click.option('--outdir', default='./', help='Output directory')
@click.option('--min-ld', default=1, help='Minimum Levenshtein distance')
@click.option('--max-ld', default=1, help='Maximum Levenshtein distance')
@click.option('--sc-cutoff', default=1000, help='For number of strings below this, Spark will not be used')
@click.option('--spark/--no-spark', default=True, help="Whether to use Spark or not")
@click.pass_context
def analyze(ctx, spark_config, spark_master, kind, outdir, min_ld, max_ld, sc_cutoff, spark):
    """Analysis of CDR3 string samples.

    The group-level options are stored in the click context object so that
    every subcommand can read them from ``ctx.obj``.
    """
    # Make sure there is a dict to store the options in, even when the group
    # is invoked without ``obj={}`` (ctx.obj would otherwise be None).
    ctx.ensure_object(dict)

    # set up spark: export the (absolute) configuration directory
    os.environ['SPARK_CONF_DIR'] = os.path.realpath(spark_config)
    ctx.obj['spark_master'] = spark_master

    # runtime options
    ctx.obj['kind'] = kind
    ctx.obj['outdir'] = outdir
    ctx.obj['min_ld'] = min_ld
    ctx.obj['max_ld'] = max_ld
    ctx.obj['sc_cutoff'] = sc_cutoff
    # Spark is only used if it was requested *and* the libraries were importable
    ctx.obj['spark'] = spark and SPARK_FOUND

@analyze.command()
@click.option('--nstrings', default=100, help="Number of strings to generate")
@click.option('--min-length', default=4, help="minimum number of characters per string")
@click.option('--max-length', default=20, help="maximum number of characters per string")
@click.pass_context
def random(ctx, nstrings, min_length, max_length):
    """Run analysis on a randomly generated set of strings for testing"""

    # NOTE(review): min_length and max_length are accepted but not forwarded to
    # the generator -- confirm whether generate_random_sequences supports them.
    strings = imnet.random_strings.generate_random_sequences(nstrings)

    sc = initialize_SparkContext(nstrings, ctx)
    try:
        kind, outdir, min_ld, max_ld = [ctx.obj[i] for i in ['kind','outdir','min_ld','max_ld']]

        # 'all' expands to both output types
        kinds = ['degrees', 'graph'] if kind == 'all' else [kind]

        # There is no input file to borrow a name from (the previous code
        # referenced an undefined `filename`), so name the output after the run.
        outfile = os.path.join(outdir, 'random_%d' % nstrings)

        # NOTE(review): write_output is not defined in the visible part of this
        # file -- confirm where it is supposed to come from.
        for k in kinds:
            write_output(strings, k, outfile, min_ld, max_ld, sc)
    finally:
        # always shut the SparkContext down, even if writing the output failed
        if sc is not None:
            sc.stop()

@analyze.command()
@click.option('--nstrings-min', default=100, help="Minimum number of strings to generate")
@click.option('--nstrings-max', default=100, help="Maximum number of strings to generate")
@click.option('--num-runs', default=1, help="Number of runs between nstrings-min and nstrings-max")
@click.option('--benchmark-file', default='imnet_analyze_benchmark.csv', help="File to store benchmark results")
@click.option('--append/--no-append', default=True, help="Append to benchmark results")
@click.option('--ncores', default=1, help="Number of cores used")
@click.pass_context
def benchmark(ctx, nstrings_min, nstrings_max, num_runs, benchmark_file, append, ncores):
    """Run a series of benchmarks for graph and degree calculations and store the results in a file."""

    # set up output file
    fieldnames = ['nstrings', 'ncores', 'spark', 'type', 'min_ld', 'max_ld', 'dt']

    # Only append when asked to and when there is something to append to;
    # otherwise start a fresh file, which needs a header row.
    appending = append and os.path.exists(benchmark_file)

    kind, min_ld, max_ld = [ctx.obj[i] for i in ['kind','min_ld','max_ld']]

    # the context manager closes the results file even if a benchmark run fails
    with open(benchmark_file, 'a' if appending else 'w') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if not appending:
            writer.writeheader()

        # run the benchmarks on logarithmically spaced string counts
        for nstrings in np.logspace(np.log10(nstrings_min), np.log10(nstrings_max), num_runs):
            # round before converting: 10**log10(n) can land just below n,
            # and plain int() would then truncate to n-1
            nstrings = int(round(nstrings))

            strings = imnet.random_strings.generate_random_sequences(nstrings)

            sc = initialize_SparkContext(nstrings, ctx)
            try:
                # wait for the SparkContext to initialize
                if sc is not None:
                    while sc.defaultParallelism < ncores - 1:
                        # single formatted argument prints the same on Python 2 and 3
                        print('waiting %s %s' % (sc.defaultParallelism, ncores))
                        time.sleep(1)

                if kind == 'graph' or kind == 'all':
                    timein = time.time()
                    mat = imnet.process_strings.distance_matrix(strings, min_ld=min_ld, max_ld=max_ld, sc=sc)
                    if sc is not None:
                        # if using spark, make sure the calculation is actually carried out
                        mat.count()
                    else:
                        # consume the result so the work is included in the timing
                        list(mat)
                    timeout = time.time()
                    writer.writerow({'nstrings': nstrings, 'spark': ctx.obj['spark'], 'ncores': ncores,
                                     'dt': timeout-timein, 'type': 'graph',
                                     'min_ld': min_ld, 'max_ld': max_ld})

                if kind == 'degrees' or kind == 'all':
                    timein = time.time()
                    imnet.process_strings.generate_degrees(strings, min_ld, max_ld, sc)
                    timeout = time.time()
                    writer.writerow({'nstrings': nstrings, 'spark': ctx.obj['spark'], 'ncores': ncores,
                                     'dt': timeout-timein, 'type': 'degrees',
                                     'min_ld': min_ld, 'max_ld': max_ld})
            finally:
                # each run gets its own context; stop it even if the run failed
                if sc is not None:
                    sc.stop()
    
@analyze.command()
@click.argument('input', type=click.File(mode='r'))
@click.option('--string-loc', default=None, help="Column that contains the string in the files")
@click.pass_context
def file(ctx, input, string_loc):
    """Process an individual file with CDR3 strings, one per line, and save gml or degree data to disk."""

    # The line count decides whether a SparkContext is started; counting
    # consumes the file, so rewind it before handing it on.
    nstrings = count_lines(input)
    input.seek(0)

    sc = initialize_SparkContext(nstrings, ctx)
    try:
        # all group-level options are forwarded as keyword arguments
        imnet.process_file(input, sc=sc, string_loc=string_loc, **ctx.obj)
    finally:
        # stop the SparkContext even when processing fails
        if sc is not None:
            sc.stop()

@analyze.command()
@click.argument('input-dir', type=click.Path(exists=True))
@click.option('--string-loc', default=None, help="Column that contains the string in the files")
@click.pass_context 
def directory(ctx, input_dir, string_loc):
    """Process a directory of CDR3 string files"""

    # Validate the output directory before anything expensive is started, so
    # a bad path does not leave a SparkContext running.
    outdir = ctx.obj['outdir']
    if not os.path.exists(outdir):
        raise IOError('Directory %s not found' % outdir)

    # only regular files can be line-counted and processed; skip sub-directories
    files = [path for path in glob(os.path.join(input_dir, '*')) if os.path.isfile(path)]

    # The huge string count makes the size check pass for any realistic
    # cutoff; whether Spark is used is decided per file below.
    sc = initialize_SparkContext(1e10, ctx)
    try:
        for filename in files:
            # small files are processed without Spark
            sc_val = sc if count_lines(filename) > ctx.obj['sc_cutoff'] else None
            imnet.process_file(filename, sc=sc_val, string_loc=string_loc, **ctx.obj)
    finally:
        # stop the shared SparkContext even when a file fails to process
        if sc is not None:
            sc.stop()

##################
# Helper functions
##################
def initialize_SparkContext(nstrings, ctx):
    """
    Helper function that starts a SparkContext when the options call for one

    Inputs
    ------
    nstrings : int
        number of strings that are going to be processed
    ctx : click.Context
        a click context for the current task group; its ``obj`` dictionary
        supplies the 'spark', 'sc_cutoff' and 'spark_master' entries

    Returns
    -------
    A new pyspark.SparkContext if Spark is enabled and `nstrings` is greater
    than the cutoff, otherwise None.
    """
    options = ctx.obj

    # Spark switched off: the cutoff is not even consulted
    if not options['spark']:
        return None

    # not enough strings to go beyond the cutoff
    if not (nstrings > options['sc_cutoff']):
        return None

    # start the spark context
    return pyspark.SparkContext(master=options['spark_master'])

def count_lines(input):
    """
    Count the number of lines in a file

    Inputs
    ------
    input : str or iterable of lines
        either a path, which is opened and closed here, or an already open
        file object, which is read to the end and left open (the caller is
        responsible for rewinding it)

    Returns
    -------
    int : the number of lines; 0 for an empty input
    """
    if isinstance(input, str):
        # the context manager closes the file even if reading fails
        with open(input, 'r') as f:
            return sum(1 for _ in f)

    # summing over a generator yields 0 for an empty file instead of failing
    # on an unbound loop variable
    return sum(1 for _ in input)

# Script entry point: run the click group with an empty dict as the context
# object, which `analyze` fills with the runtime options.
if __name__ == "__main__":
    analyze(obj={})
