#!/usr/bin/env python3

import csv, argparse, operator, sys, datetime, time, random, os.path, json, hppy as hy
from math import log10
#from scipy import stats
from hivclustering import *
from hivclustering.networkbuild import *

from datetime import date 

# Inclusive [first, last] calendar years: print_tns only reports nodes whose
# baseline year lies in this window, and explore_tns_stability passes the two
# values to compute_tns as year_from / year_to.
date_range = [1996,date.today().year]


def print_tns (clusters, network, print_level = None):
    """Print a tab-separated TNS report, one row per node, to stdout.

    clusters    -- mapping of cluster key to an iterable of nodes; the key
                   None marks nodes that are not in any cluster (they get a
                   TNS of 0 and no edge counts).
    network     -- network object used to filter edges by date and to fit
                   the degree distribution for every clustered node.
    print_level -- when not None, nodes whose TNS is at least this value
                   receive the 'focus' attribute.

    Only nodes whose baseline year falls inside the module-level date_range
    are reported.
    """
    columns = ['ID','Year','BaselineDegree','TNS','Outbound_1st','Undirected_1st','Outbound_after1st','Undirected_after1st','BaselineSequence','BaselineVLDate','BaselineVL']
    print ("\t".join(columns))

    for cluster_key, members in clusters.items():
        unclustered = cluster_key is None
        for node in members:
            baseline = node.get_baseline_date (complete = True)
            if not (date_range[0] <= baseline.tm_year <= date_range[1]):
                continue

            baseline_dt = tm_to_datetime(baseline)
            one_year_on = datetime.date.timetuple(baseline_dt + datetime.timedelta (days = 365))

            baseline_degree = 0
            # slots 0-2: edges dated within a year of baseline (in, out, undirected)
            # slots 3-5: edges dated later (in, out, undirected)
            later_edges = [0] * 6

            if unclustered:
                score = 0
            else:
                network.clear_filters()
                network.apply_exact_date_filter (baseline)
                fit = network.fit_degree_distribution ()
                network.get_edge_node_count ()
                linked = network.get_all_edges_linking_to_a_node (node.id,ignore_visible=True,use_direction=False,reduce_edges=False)

                for edge in linked:
                    # edges that fail the treatment-date check are not counted
                    if node.treatment_date and edge.check_exact_date (node.treatment_date) == False:
                        continue

                    if edge.check_exact_date (baseline):
                        baseline_degree += 1
                        continue

                    offset = 0 if edge.check_exact_date (one_year_on) else 3
                    source = edge.compute_direction()
                    if source is None:
                        slot = 2
                    elif source.id == node.id:
                        slot = 1
                    else:
                        slot = 0
                    later_edges [offset + slot] += 1

                score = tns (baseline_degree, None, fit['fitted'][fit['Best']])

            viral_load = node.get_vl_by_date (tm_to_datetime(baseline))
            if viral_load is None:
                vl_fields = ['NULL','NULL']
            else:
                vl_fields = [viral_load[0].strftime("%Y%m%d"), str(viral_load[1])]
            baseline_stamp = tm_to_datetime(baseline).strftime("%Y%m%d")

            row = [node.id,
                   str(baseline.tm_year),
                   str(baseline_degree),
                   str(score),
                   str(later_edges[1]),
                   str(later_edges[2]),
                   str(later_edges[4]),
                   str(later_edges[5]),
                   baseline_stamp,
                   vl_fields[0],
                   vl_fields[1]]
            print ("\t".join(row))

            if print_level is not None and score is not None and score >= print_level:
                node.add_attribute('focus')
           


#-------------------------------------------------------------------------------		
def explore_tns_stability (network, outdeg, from_d, to_d, step):
    """Recompute TNS over a sweep of distance thresholds and print it as CSV.

    The threshold starts at from_d and grows by step while it is <= to_d.
    For each threshold the per-year fitted distributions are collected with
    print_network_evolution and handed to compute_tns.  The node list is
    taken from the first threshold evaluated; one "node , distance , TNS"
    line is printed for every node/threshold pair.
    """
    scores_by_threshold = {}
    reported_nodes = None

    threshold = from_d
    while threshold <= to_d:
        yearly_fits = {}
        print (threshold)
        print_network_evolution (network, yearly_fits, outdeg, threshold, do_print = False)
        scores_by_threshold[threshold] = compute_tns (network, date_range[0], date_range[1], threshold, yearly_fits, outdeg, do_print = False)
        if reported_nodes is None:
            reported_nodes = scores_by_threshold[threshold].keys()
        threshold += step

    print ("Node,distance,TNS")

    for name in reported_nodes:
        for cutoff in scores_by_threshold:
            print (name,',',cutoff,',',scores_by_threshold[cutoff][name]['tns'])

#-------------------------------------------------------------------------------		

def import_edi (file):
    """Read an EDI (estimated date of infection) CSV into a dict keyed by patient.

    file -- an iterable of CSV lines (e.g. an open text file).  The first row
            is a header and must have exactly 14 columns.

    Columns used (0-based): 0 fallback ID, 1 PID (dashes removed; preferred
    when present), 2 genotype date, 3 ARV status ('ARV Naive' marks a naive
    patient), 4 drug start date, 5 disease stage (defaults to 'Chronic'),
    6 EDI, 8 viral load (integer).  Dates are in MM/DD/YYYY format.

    Returns {id: [geno_date, drug_date, stage, edi_date, viral_load, naive]}
    where dates are time.struct_time values or None.

    When the EDI is later than the genotype date, the EDI is retried with the
    genotype year substituted; if it is still later, the record is skipped.

    Raises Exception if the header is missing or does not have 14 columns.
    """
    date_format = '%m/%d/%Y'

    def parse_date (text):
        # empty cells mean "no date recorded"
        return time.strptime (text, date_format) if len (text) else None

    edi_by_id = {}
    ediReader = csv.reader(file)
    header = next(ediReader, None)
    if header is None or len (header) != 14:
        raise Exception ('Expected a .csv file with 14 columns as input')

    for line in ediReader:
        if not line:
            # csv.reader yields [] for blank lines; there is nothing to import
            continue

        patient_id = line[1].replace ('-','') if len (line[1]) else line[0]

        geno_date = parse_date (line[2])
        drug_date = parse_date (line[4])
        stage     = line[5] if len (line[5]) else 'Chronic'
        edi_date  = parse_date (line[6])
        naive     = line[3] == 'ARV Naive'

        if geno_date and edi_date and edi_date > geno_date:
            month_day = time.strftime ("%m/%d",edi_date)
            geno_year = time.strftime ("%Y",geno_date)
            try:
                new_edi_date = time.strptime ("/".join((month_day,geno_year)),date_format)
            except ValueError:
                # month_day came from a valid date, so the only way this can
                # fail is Feb 29 moved into a non-leap year; use Feb 28.
                new_edi_date = time.strptime ("/".join(("02/28",geno_year)),date_format)

            if new_edi_date > geno_date:
                continue
            edi_date = new_edi_date

        viral_load = int (line[8]) if len (line[8]) else None

        edi_by_id [patient_id] = [geno_date, drug_date, stage, edi_date, viral_load, naive]

    return edi_by_id

#-------------------------------------------------------------------------------		


def hex_grb_to_triplets (rgb):
    """Split a six-digit hex colour such as 'FF8000' into [red, green, blue] ints."""
    return [int (rgb[start:start + 2], 16) for start in (0, 2, 4)]
    

def interpolate_rgb_hex (fromC, toC, slider = 0.5):
    """Blend two six-digit hex colours and return the result as '#rrggbb'.

    fromC, toC -- hex colour strings without a leading '#', e.g. 'FFFFFF'.
    slider     -- mixing weight: 0 gives fromC, 1 gives toC.  Values outside
                  [0, 1] are clamped, because they would otherwise push a
                  channel outside 0..255 and yield a malformed colour string.
    """
    # floor is not imported at the top of this file; import it here instead of
    # relying on a star-import happening to expose it.
    from math import floor

    weight = min (1., max (0., slider))
    start = hex_grb_to_triplets (fromC)
    end = hex_grb_to_triplets (toC)

    blended = [floor(start[k] * (1-weight) + end[k] * weight) for k in range (3)]

    return '#%02x%02x%02x' % (blended[0],blended[1],blended[2])

#-------------------------------------------------------------------------------		

def get_dot_string_vl (self, year_vis = None, tns = None):
    """Return the GraphViz node statement for this node.

    self     -- the node being drawn; reads self.id, self.attributes,
                self.get_vl() and self.get_baseline_date().
    year_vis -- when not None, nodes whose baseline date is later than this
                value are emitted with style "invis".
    tns      -- optional mapping of node id to a TNS score.  When given, the
                fill colour is a white-to-red blend driven by the score (white
                if the node has no score); otherwise the fill colour is taken
                from the 'reds9' colour scheme using log10 of the viral load.

    Shape is 'doublecircle' for 'focus' nodes, 'circle' otherwise; 'treated'
    nodes become grey triangles and 'index' nodes green diamonds ('index'
    takes precedence because it is applied last).
    """
    shape = 'doublecircle' if 'focus' in self.attributes else 'circle'
    label = ''#time.strftime("%b %Y", self.get_baseline_date(True))
    color = 'white'

    if tns is not None:
        # a missing or None score leaves the node white instead of crashing
        if self.id in tns and tns[self.id] is not None:
            color = interpolate_rgb_hex ('FFFFFF','FF0000',tns[self.id])
    else:
        viral_load = self.get_vl()
        if viral_load is not None:
            # reds9 only defines indices 1..9; log10 is undefined for
            # non-positive loads, which are drawn with the lightest shade.
            shade = round(log10 (viral_load)) if viral_load > 0 else 1
            color = '/reds9/%d' % min (9, max (1, shade))

    if 'treated' in self.attributes:
        shape = 'triangle'
        color = 'grey'

    if 'index' in self.attributes:
        shape = 'diamond'
        color = 'green'

    if year_vis is not None:
        if self.get_baseline_date () > year_vis:
            return '"%s" [label = "%s", fillcolor = "%s", shape = %s, style = "invis"];\n' % (self.id, label , color, shape)

    return '"%s" [label = "%s", fillcolor = "%s", shape = %s];\n' % (self.id, label , color, shape)

#-------------------------------------------------------------------------------		

def generate_dot_tns (self, file, tns = None, year_vis = None, reduce_edges = True):
    """Write the network as a GraphViz digraph and print transmission summaries.

    self         -- the network being drawn (reads self.edges, self.nodes and
                    self.reduce_edge_set()).
    file         -- writable text stream receiving the DOT source.
    tns          -- optional mapping of node id to TNS score; forwarded to the
                    node drawing routine and used for the TNS/VL summary.
    year_vis     -- when not None, edges failing edge.check_date(year_vis) are
                    written with style "invis".
    reduce_edges -- when True (default) iterate self.reduce_edge_set() instead
                    of every edge.

    For every directed edge a tab-separated row describing the source node is
    printed to stdout; after the graph a CSV summary of source viral loads
    for transmissions within a year of the source baseline is printed.

    Returns a dict with the counts of 'undirected' and 'directed' edges drawn.
    """
    def as_datetime (tm_value):
        # missing dates stay None rather than being passed to tm_to_datetime
        return tm_to_datetime (tm_value) if tm_value else None

    def as_day (dt_value):
        return 'NA' if dt_value is None else dt_value.strftime ("%Y-%m-%d")

    file.write ('digraph G { overlap="prism1000"; sep = "+8"; \n outputorder = edgesfirst;\nnode[style=filled];\n')
    nodes_drawn = {}

    directed = {'undirected':0, 'directed':0}

    tns_vl = {}

    print ("\t".join(['ID', 'BASELINE', 'EDI', 'ARV_DATE', 'EDGE_DATE', 'TREATMENT_SINCE_EDI', 'IS_AFTER_ART']))

    for edge in self.edges if reduce_edges == False else self.reduce_edge_set():
        if not edge.visible:
            continue

        # The original looked these up on a name `network` that is not defined
        # in this function; the network being drawn is `self`.
        p1 = self.nodes[edge.p1]
        p2 = self.nodes[edge.p2]

        edi1 = as_datetime (p1.get_edi())
        edi2 = as_datetime (p2.get_edi())

        # Transmission date: the later of the two EDIs when both are known,
        # otherwise the later of the two edge dates, otherwise unknown.
        transmission_date = None
        if edi1 is not None and edi2 is not None:
            transmission_date = edi2 if edi1 < edi2 else edi1
        elif edge.date1 is not None and edge.date2 is not None:
            ed1 = tm_to_datetime (edge.date1)
            ed2 = tm_to_datetime (edge.date2)
            transmission_date = ed2 if ed2 > ed1 else ed1

        vl1 = None
        vl2 = None
        if transmission_date is not None:
            vl1 = p1.get_vl_by_date (transmission_date)
            vl2 = p2.get_vl_by_date (transmission_date)
            # ignore viral loads measured more than a year from transmission
            if vl1 is not None and abs (vl1[0]-transmission_date) > datetime.timedelta (365):
                vl1 = None
            if vl2 is not None and abs (vl2[0]-transmission_date) > datetime.timedelta (365):
                vl2 = None

        if edge.p1 not in nodes_drawn:
            nodes_drawn[edge.p1] = edge.p1.get_baseline_date()
            file.write (edge.p1.get_dot_string(year_vis,tns))
        if edge.p2 not in nodes_drawn:
            nodes_drawn[edge.p2] = edge.p2.get_baseline_date()
            file.write (edge.p2.get_dot_string(year_vis,tns))

        # Must be computed before the invisibility check below, which uses it
        # (the original read it there before it had been assigned).
        edge_attr = edge.direction()

        if year_vis is not None:
            if edge.check_date (year_vis) == False:
                file.write ('%s [style="invis" arrowhead = "%s"];\n' % (edge_attr[0], edge_attr[1]))
                continue

        source_p = edge.compute_direction()

        if source_p is None:
            directed ['undirected'] += 1
        else:
            directed ['directed'] += 1

        if source_p is not None:
            source_d = tm_to_datetime(source_p.get_baseline_date(True))
            art_date = as_datetime (source_p.treatment_date)
            since_edi = source_p.get_treatment_since_edi()
            if art_date is not None and transmission_date is not None:
                after_art = str (art_date < transmission_date)
            else:
                after_art = 'NA'
            print ("\t".join([source_p.id,
                              source_d.strftime("%Y-%m-%d"),
                              as_day (as_datetime (source_p.get_edi())),
                              as_day (art_date),
                              as_day (transmission_date),
                              str(since_edi.days) if since_edi is not None else 'NA',
                              after_art]))

            source_vl = vl1 if source_p == p1 else vl2
        else:
            source_vl = None

        if source_p and source_vl and tns and transmission_date :
            base_date = source_p.get_baseline_date (complete = True)
            if transmission_date - tm_to_datetime(base_date) <= datetime.timedelta (365):
                if source_p.id not in tns_vl:
                    tns_vl [source_p.id] = {'TNS': tns[source_p.id] if source_p.id in tns else 0.0,
                                            'VL' : [],
                                            'TD' :  transmission_date.strftime("%Y%m%d"),
                                            'VD' : []
                                            }

                tns_vl [source_p.id]['VL'].append (source_vl[1])
                tns_vl [source_p.id]['VD'].append (source_vl[0])

        pw = 2
        if source_vl is not None:
            load = source_vl[1]
            # log10 is undefined for non-positive loads and negative below 1;
            # keep the blend weight inside [0, 1].
            weight = max (0., min (1, log10 (load)/6.)) if load > 0 else 0.
            color = interpolate_rgb_hex ('FFFFFF','0000FF',weight)
            pw = 5
        else:
            color = 'red' if edge.has_attribute ('UDS') else 'grey'
        file.write ('%s [style="bold" color = "%s" label = "%s" arrowhead = "%s" penwidth = "%d"];\n' % (edge_attr[0], color, edge.label(), edge_attr[1], pw))

    file.write ("\n}")
    print ("ID,TNS,VL,Transmission,ViralLoad")
    for i,v in tns_vl.items():
        print (", ".join([str(k) for k in [i, v['TNS'], sum (v['VL'])/len (v['VL']),v['TD'],"|".join([k.strftime('%Y%m%d') for k in v['VD']])]]))
    return directed
 
#-------------------------------------------------------------------------------		

def print_network_evolution (network, store_fitted = None, outdegree = False, distance = None, do_print = True, year_from = 2000, year_to = 2013):
    """Summarise the network year by year and optionally print the table as CSV.

    network      -- network object; its filters are reset for every year.
    store_fitted -- optional dict that receives, per year, the fitted
                    'Waring' degree distribution.
    outdegree    -- fit the out-degree distribution instead of the default.
    distance     -- optional distance threshold applied on top of the date
                    filter.
    do_print     -- print the per-year table when True.
    year_from, year_to -- years covered, as range(year_from, year_to); the
                    defaults reproduce the previously hard-coded 2000..2012.

    Each table row holds: year, nodes, edges, sequences, number of clusters,
    largest cluster size (0 when there is no non-singleton cluster) and the
    Waring rho.
    """
    byYear = []

    for year in range (year_from, year_to):
        network.clear_filters ()
        network.apply_date_filter (year, do_clear= True)
        if distance is not None:
            network.apply_distance_filter (distance, do_clear = False)
        network_stats = network.get_edge_node_count ()
        network.compute_clusters()
        clusters = network.retrieve_clusters()
        if outdegree:
            distro_fit = network.fit_degree_distribution ('outdegree')
        else:
            distro_fit = network.fit_degree_distribution ()

        if store_fitted is not None:
            store_fitted [year] = distro_fit['fitted']['Waring']

        # the None key is excluded from the maximum; default=0 avoids a
        # ValueError when no other cluster exists for this year
        largest = max ((len (clusters[c]) for c in clusters if c is not None), default = 0)
        byYear.append ([year, network_stats['nodes'], network_stats['edges'], network_stats['total_sequences'],len(clusters),largest,distro_fit['rho']['Waring']])

    if do_print:
        print ("\nYear,Nodes,Edges,Sequences,Clusters,MaxCluster,rho")
        for row in byYear:
            print (','.join([str (k) for k in row]))
 
#-------------------------------------------------------------------------------		
def print_degree_distro (network, distro_fit = None):
    """Print observed vs. Waring-fitted degree distribution and fit statistics.

    network    -- network object; only used to fit the degree distribution
                  when distro_fit is not supplied.
    distro_fit -- optional fit result with the keys 'degrees', 'fitted',
                  'rho', 'BIC' and 'p'.  The original body read this name
                  without ever defining it, so it is now an explicit,
                  optional argument.

    Output: a tab-separated table (degree, raw count, predicted count,
    observed fraction, predicted fraction, cumulative observed, cumulative
    predicted) followed by one "rho / BIC / p" line per fitted distribution.
    """
    if distro_fit is None:
        distro_fit = network.fit_degree_distribution ()

    print ("\t".join (['degree','rawcount','rawpred','count','pred','ccount','cpred']))

    observed = distro_fit['degrees']
    predicted = distro_fit['fitted']['Waring']
    total = float(sum(observed))
    cumulative_observed = 0.
    cumulative_predicted = 0.

    for k in range (0, len(observed)):
        cumulative_observed += observed[k]/total
        cumulative_predicted += predicted[k]
        row = [k+1, observed[k], predicted[k]*total, observed[k]/total, predicted[k], cumulative_observed, cumulative_predicted]
        print ("\t".join ([str(p) for p in row]))

    for dname,rho in distro_fit['rho'].items():
        print ("%s : rho = %s, BIC = %s, p = %s" % (dname, 'N/A' if rho is None else "%5.2f" % (rho) , 'N/A' if distro_fit["BIC"][dname] is None else "%7.2f" % distro_fit["BIC"][dname], 'N/A' if distro_fit["p"][dname] is None else "%4.2f" % (distro_fit["p"][dname])))

#-------------------------------------------------------------------------------		
def print_degree_table (network, year_from, year_to, distance = 0.015):
    """Print a tab-separated table of per-node degree and cluster size by year.

    One column per year in range(year_from, year_to).  A cell is '.' when the
    node has no degree entry for that year, otherwise 'degree|cluster_size'.
    """
    sep = '\t'
    years = range (year_from,year_to)

    cells_by_id = {}
    for node in network.nodes:
        cells_by_id [node.id] = []

    for year in years:
        network.clear_filters ()
        network.apply_distance_filter (distance)
        network.apply_date_filter (year,False)
        degrees = network.get_node_degree_list(year,True)
        network.clear_adjacency()
        network.apply_date_filter (year,False)
        sizes = network.cluster_size_by_node()

        for node in degrees:
            if degrees [node] is None:
                cell = '.'
            else:
                cell = str (degrees [node]) + '|' + str (sizes[node])
            cells_by_id [node.id].append(cell)

    year_labels = [str (y) for y in years]
    print  ('ID%s%s' % (sep,sep.join(year_labels)))

    for node_id, cells in cells_by_id.items():
        print ('%s%s%s'%(node_id,sep,sep.join (cells)))
        
#-------------------------------------------------------------------------------		

def tns (degree, cluster_size, fitted = None):
    """Return the sum of the first `degree` entries of `fitted`.

    cluster_size is accepted but not used.  Without a fitted distribution
    the score is undefined and None is returned.
    """
    if fitted is None:
        return None
    return sum(fitted[:degree])
        
#-------------------------------------------------------------------------------		
def compute_tns (network, year_from, year_to, distance = 0.015, fitted = None, outdegree = False, do_print = True):
    """Compute per-node TNS records over range(year_from, year_to).

    For each year the network is filtered by distance and date.  A node whose
    baseline date equals the current year (and the year is not the last one)
    gets a fresh record {'tns', 'deg', 'cls', 'year'}.  In the following year
    that record is updated: 'cls' becomes the degree gain divided by the
    cluster-size gain (0 when the cluster did not grow) and 'deg' becomes the
    degree gain.

    fitted    -- optional mapping year -> fitted distribution passed to tns().
    outdegree -- use column 1 of the degree entry instead of column 3.
    do_print  -- print the records as "id , TNS , deg , cls" lines.

    Returns the dict of records keyed by node id.
    """
    records = {}

    for year in range (year_from,year_to):
        network.clear_filters ()
        network.apply_distance_filter (distance)
        network.apply_date_filter (year,False)

        degrees = network.get_node_degree_list(year,True)

        network.clear_adjacency()
        network.apply_date_filter (year,False)
        sizes = network.cluster_size_by_node()

        column = 1 if outdegree else 3

        for node in degrees:
            if node.get_baseline_date () == year and year < year_to - 1:
                # first sighting: score the node at baseline
                records [node.id] = {}
                baseline_score = tns (degrees[node][column],sizes[node], None if fitted is None else fitted[year])
                records[node.id]['tns'] = baseline_score
                records[node.id]['deg'] = degrees[node][column]
                records[node.id]['cls'] = sizes[node]
                records[node.id]['year'] = float (year)
                continue

            if year > year_from and node.get_baseline_date () == year - 1:
                # one year after baseline: replace stored values with gains;
                # 'cls' must be updated before 'deg' because it reads the old 'deg'
                entry = records[node.id]
                cluster_gain = sizes[node] - entry['cls']
                if cluster_gain > 0:
                    entry['cls'] = (degrees[node][column] - entry['deg'])/cluster_gain
                else:
                    entry['cls'] = 0
                entry['deg'] = degrees[node][column] - entry['deg']

    if do_print:
        print ("id,TNS,deg,cls")
        for node_id in records:
            print (node_id, ',', records[node_id]['tns'], ',', records[node_id]['deg'], ',', records[node_id]['cls'])

    return records
               
#-------------------------------------------------------------------------------		

def main(): 
    """Build the network, report on it and print the per-node TNS table.

    When no distance threshold is configured (settings().threshold is None)
    the function reports on patients with multiple samples, prints the overall
    date span, writes baseline pairwise distances to 'baseline.csv' and exits.
    Otherwise it clusters the network, fits the degree distribution and calls
    print_tns with a focus level of 0.8.
    """
    network = build_a_network ()

    # Install the TNS-specific drawing routines.  generate_dot_tns is stored on
    # the instance, so callers must pass the network explicitly as first argument.
    network.generate_dot = generate_dot_tns
    patient.get_dot_string = get_dot_string_vl

    distro_fit = describe_network (network)


    '''
    early_t = network.get_all_treated_within_range(datetime.timedelta(days = 7*4))
    print ('\t'.join (['ID','EDI','ART','DIFF']))
    for n in early_t:
        print ("\t".join([n.id, tm_to_datetime(n.get_edi()).strftime ("%Y-%m-%d"), tm_to_datetime(n.treatment_date).strftime ("%Y-%m-%d"), str(n.get_treatment_since_edi().days)]))
        n.add_attribute ('eART')
    
    e_cnt = len (early_t)

    for k,v in network.get_node_degree_list(attribute_selector = 'eART',do_direction = True).items():
        print (k,v[0]+v[1])

    obs = e_cnt + sum([k[0] + k[1] for k in list(network.get_node_degree_list(attribute_selector = 'eART',do_direction = True).values())])    
    print (obs)
    p  = 0

    sim_degrees = []

    for k in range (1000):
        network.randomize_attribute ('eART', None)
        sim_degrees.append (e_cnt + sum([k[0] + k[1] for k in list(network.get_node_degree_list(attribute_selector = 'eART',do_direction = True).values())]))
        if sim_degrees[-1] <= obs:
            p += 1
     
    
    print ("EARLY ART degrees:" , describe_vector(sim_degrees), p/(k+1.), e_cnt)
    sys.exit(0)
    '''

    #print_tns (network, 0.8)
    
    '''
    

    with open ('data/TNS_new.txt', 'r') as tns_fh:
        tns_reader = csv.reader (tns_fh,delimiter = '\t')
        next(tns_reader)
        for k in tns_reader:
            precomputed_tns [k[0]] = float (k[3])
    '''
        
    if settings().threshold is None: 
        # NOTE(review): the original read an undefined `network_stats` here.
        # The other functions in this file obtain a dict of that name from
        # get_edge_node_count(); it is assumed to carry the 'multiple_dates'
        # key as well -- confirm against hivclustering.
        network_stats = network.get_edge_node_count ()
        print ("%d nodes with multiple dates" % len (network_stats['multiple_dates']))
        minfo = network.report_multiple_samples (network_stats['multiple_dates'])
        print ("Samples per patient: %d-%d (median %d)" % (minfo['samples']['min'],minfo['samples']['max'],minfo['samples']['median']))
        print ("Duration per patient (wks): %d-%d (median %d)" % (minfo['followup']['min']/7,minfo['followup']['max']/7,minfo['followup']['median']/7))
    
        earliest = None
        latest   = None
    
        for k in network.nodes:
            first  = k.get_baseline_date(True)
            lastd = k.get_latest_date (True)
        
            if earliest is None or first < earliest:
                earliest = first
            if latest is None or lastd > latest:
                latest = lastd
    
        # both stay None when the network has no nodes
        if earliest is not None and latest is not None:
            print ("%d/%d/%d -- %d/%d/%d" % (earliest.tm_year, earliest.tm_mon, earliest.tm_mday, latest.tm_year, latest.tm_mon, latest.tm_mday))
    
        with open ('baseline.csv', 'w') as fh:
            network.spool_pairwise_distances (fh, baseline = True)
    
        sys.exit(0)

    #network.drop_singleton_nodes ()
    network.compute_clusters()
    clusters = network.retrieve_clusters()
    print ("Found %d clusters" % (len(clusters) - (1 if None in clusters else 0)), file = sys.stderr)
    # default = 0 keeps the report working when only the None (singleton) key exists
    print ("Maximum cluster size = %d nodes" % max ((len (clusters[c]) for c in clusters if c is not None), default = 0), file = sys.stderr)

    print ("Fitting the degree distribution to various densities", file = sys.stderr)
    #print (network.get_degree_distribution())
    distro_fit = network.fit_degree_distribution ()
    best_rho = distro_fit['rho'][distro_fit['Best']]
    # rho may be None (see print_degree_distro, which prints 'N/A' for it)
    print ("Best distribution is '%s' with rho = %g" % (distro_fit['Best'], 0.0 if best_rho is None else best_rho))

    print_tns (clusters, network, 0.8)
    #network.apply_date_filter( date_range[0], True, True)
    #network.apply_date_filter( date_range[1], False, False)
    #network.generate_dot (network, settings().dot, tns = precomputed_tns)

# Run the analysis only when this file is executed as a script, not on import.
if __name__=='__main__':
    main()
