#!/usr/bin/env python

import sys, os
import re
from collections import defaultdict
import numpy as np
from subprocess import *
import argparse

# Display names for well-known ports; keys are port numbers as strings.
known_ports = dict([
    ('42424', 'nimbus'),
    ('42426', 'background'),
])


def flatten(l):
    """Concatenate an iterable of iterables into a single flat list."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

# Command-line interface. Options are registered in the order they should
# appear in --help; argparse exposes dashed names with underscores
# (e.g. --line-width -> args.line_width).
parser = argparse.ArgumentParser(description="plot throughput and delay from a mahimahi trace")
# Required positional arguments.
parser.add_argument('log', type=str, help="path to mahimahi trace")
parser.add_argument('ms_per_bin', type=int, help="granularity of x-axis in plot")


# Plot appearance. Options without a default are None when omitted.
parser.add_argument('--title',type=str,help="plot title")
parser.add_argument('--key',help="add key to plot", action='store_true')
parser.add_argument('--xtics',type=float, help="interval between x values on plot")
parser.add_argument('--line-width',type=int)
parser.add_argument('--plot-width',type=int)
parser.add_argument('--font-size',type=int)

# Axis ranges, given as "min:max" strings and interpreted later in the script.
parser.add_argument('--xrange',type=str,help="range of time values to plot, in seconds, \"min:max\", default=*:*")
parser.add_argument('--yrange',type=str,help="range of throughput to plot, in Mbps, \"min:max\", default=*:*")
parser.add_argument('--y2range',type=str,help="range of delay values to plot, in ms, \"min:max\", default=*:*")

# Boolean switches; store_true options are False (never None) when omitted.
parser.add_argument('--no-display',help="do not automatically open plot, just save it", action='store_true')
parser.add_argument('--no-sum',help="do not plot sum throughput", action='store_true')
parser.add_argument('--no-delay',help="do not plot delay", action='store_true')
parser.add_argument('--no-grid',help="do not display grid on plot", action='store_true')
parser.add_argument('--fake',default=False,help="output gnuplot script and parsed data without plotting",action='store_true')
parser.add_argument('--no-port', help="original version of mahimahi without port information",action='store_true')

# Flow grouping, background-traffic annotations and the per-bin delay statistic.
# "%%" in a help string is argparse's escape for a literal percent sign.
parser.add_argument('--agg',type=str,help="path to list of port ranges and flow names to aggregate in plot")
parser.add_argument('--link-dir',type=str,help="direction (up|down) of link log")
parser.add_argument('--bg',type=str,help="path to file specifying background traffic pattern")
parser.add_argument('--plot-expected',help="plot the expected throughput based on the background traffic (requires --bg)",action='store_true')
parser.add_argument('--delay-f',type=str,help="how to calculate delay over each interval: (min|max|avg|X%%) where X is %%tile, e.g. use --delay-f 50%% for median",default="50%")

# --ports and --agg are mutually exclusive; that is enforced after parsing.
parser.add_argument('--nimbus',type=str,help='path to nimbus output file, adds red boxes for incorrect mode to plot')
parser.add_argument('--ports',type=str,help="comma-separated list of ports to identify flows. denote portranges with colons, e.g. 1000:1002 is equivalent to 1000,1001,1002")

# Parses sys.argv at import time; argparse exits the process on invalid usage.
args = parser.parse_args()

def parse_ports(s):
    """Expand a comma-separated port specification into a list of strings.

    Each comma-separated item is either a single port ("80") or an inclusive
    range written "first:last" ("1000:1002" expands to 1000, 1001, 1002).

    Returns the ports as decimal strings, in the order they were given.
    Raises ValueError if an item is not an integer or a two-part range.
    """
    ports = []
    for item in s.split(","):
        if ":" in item:
            first, last = item.split(":")
            ports.extend(range(int(first), int(last) + 1))
        else:
            ports.append(int(item))
    # Return a real list (not a lazy map object) so callers can iterate the
    # result repeatedly and test membership regardless of Python version.
    return [str(port) for port in ports]

log = args.log
ms_per_bin = args.ms_per_bin
ports_to_track = []
agg_ports = {}
agg_names = []
if args.agg is not None and args.ports is not None:
    print "ERROR: can only specify aggregate list OR port list"
    parser.print_help()
    sys.exit(-1)

# Load the aggregate flow list: one "[portlist]\t[key name]" entry per line.
if args.agg is not None:
    if not os.path.exists(args.agg):
        print "ERROR: could not find aggregate flow list: " + args.agg
        parser.print_help()
        sys.exit(-1)
    with open(args.agg) as f:
        try:
            for l in f:
                if not l.strip():
                    # tolerate blank lines instead of rejecting the whole file
                    continue
                p, name = l.strip().split("\t")
                ports = parse_ports(p)
                for port in ports:
                    agg_ports[port] = name
                # a group is only registered once it owns at least one port
                if ports and name not in agg_names:
                    agg_names.append(name)
        # Only malformed content (bad unpacking / non-integer ports) is a
        # format error; anything else (I/O errors, Ctrl-C, ...) propagates.
        except ValueError:
            print "ERROR: error parsing aggregate flow list. format is: [portlist]\\t[key name]"
            parser.print_help()
            sys.exit(-1)

# --ports is either the keyword "all" or a port specification for parse_ports.
if args.ports is not None:
    if args.ports == "all":
        ports_to_track = "all"
    else:
        try:
            ports_to_track = parse_ports(args.ports)
        # parse_ports signals malformed input with ValueError; do not mask
        # unrelated failures behind a "parsing" message.
        except ValueError:
            print "ERROR: error parsing port list."
            parser.print_help()
            sys.exit(-1)

# uplink: True for an uplink log, False for a downlink log, None when
# --link-dir was not given. Matching is by substring of the option value.
if args.link_dir is None:
    uplink = None
elif 'up' in args.link_dir:
    uplink = True
elif 'down' in args.link_dir:
    uplink = False
else:
    print "ERROR: unknown link direction type"
    parser.print_help()
    sys.exit(-1)


if args.bg is not None:
    if not os.path.exists(args.bg):
        print "ERROR: could not find bg file", args.bg
        parser.print_help()
        sys.exit(-1)
else:
    if args.plot_expected or args.nimbus is not None:
        print "ERROR: requires --bg FILE"
        parser.print_help()
        sys.exit(-1)

if args.nimbus is not None:
    if not os.path.exists(args.nimbus):
        print "ERROR: could not find nimbus file", args.nimbus
        parser.print_help()
        sys.exit(-1)

def find_agg_name(flow):
    """Return the aggregate group name for a "src:dst" flow, or None.

    The lookup uses the destination port when `uplink` is truthy and the
    source port otherwise (including when no link direction was given).
    Raises ValueError if `flow` does not split into exactly two parts.
    """
    src, dst = flow.split(":")
    lookup_port = dst if uplink else src
    return agg_ports.get(lookup_port)

def ms_to_bin(ms):
    """Return the index of the `ms_per_bin`-wide bin that contains time `ms`.

    `ms` is a timestamp in milliseconds. Floor division is spelled out so the
    result stays an integer bin index (bins are used as dict keys and as
    range() bounds) instead of depending on Python 2's integer `/`.
    """
    return ms // ms_per_bin

def bin_to_seconds(b):
    """Return bin index `b` converted to seconds, as a string with 3 decimals."""
    bin_ms = b * ms_per_bin
    return "%.3f" % (bin_ms / 1000.0)

def bits_to_mbps(bits, duration=(ms_per_bin / 1000.0)):
    """Convert `bits` transferred over `duration` seconds into Mbit/s.

    `duration` defaults to the length of one bin in seconds; the default is
    computed once, when the function is defined.
    """
    bits_per_second = bits / duration
    return bits_per_second / 1000000.0
###

capacity = defaultdict(int)
arrivals = defaultdict(lambda : defaultdict(int))
departures = defaultdict(lambda : defaultdict(int))
delays = defaultdict(list)
all_delays = []

first_t, last_t, base_t = None, None, None
capacity_sum, arrival_sum, departure_sum = 0, defaultdict(int), defaultdict(int)

xmin,xmax = None,None
if args.xrange:
    xmin,xmax = args.xrange.split(":")
    if xmin != "*": 
        xmin = float(xmin) * 1000.0
    if xmax != "*":
        xmax = float(xmax) * 1000.0

### parse log file
# Leading lines starting with "#" form the header; one of them carries the
# base timestamp that every event time is made relative to. Event lines are
# space separated: "[t] [type] [num_bytes] ..." with type "+" (arrival),
# "-" (departure) or "#" (capacity).

# store_true options are never None, so a plain boolean test is enough.
has_port_info = not args.no_port


def _record(per_bin, totals, tbin, keys, num_bits):
    """Add num_bits to the per-bin and whole-trace counters of every key."""
    for key in keys:
        per_bin[tbin][key] += num_bits
        totals[key] += num_bits


header = True
with open(log) as f:
    for l in f:
        if not l.strip():
            # tolerate blank lines instead of failing to unpack them
            continue

        if header:
            m = re.search(r"^# base timestamp: (\d+)", l)
            if m:
                base_t = int(m.groups()[0])
                continue
            elif l[0] == "#":
                continue
            else:
                header = False
                if base_t is None:
                    sys.exit("missing \"# base timestamp: [t]\" line in log header")

        sp = l.strip().split(" ")
        try:
            t, etype, num_bytes = sp[0:3]
            t = int(t) - base_t
            num_bytes = int(num_bytes)
        except ValueError:
            sys.exit("invalid event format, expected: \"[t] [type] [num_bytes] ...\", got: \"%s\"" % l.strip())

        if (xmin and t < xmin) or (xmax and t > xmax):
            continue

        tbin = ms_to_bin(t)

        # Compare against None: a relative timestamp of 0 is a valid first
        # event and must not be mistaken for "nothing seen yet".
        if first_t is None:
            first_t = t
            last_t = t
        last_t = max(t, last_t)

        num_bits = num_bytes * 8

        if etype == "+":
            keys = ['sum']
            if has_port_info:
                try:
                    flow = sp[3]
                    agg_name = find_agg_name(flow)
                except (IndexError, ValueError):
                    sys.exit("invalid arrival format, expected: \"[t] + [num_bytes] [src:dst]\", got: \"%s\"" % l.strip())
                keys.append(flow)
                if agg_name:
                    keys.append(agg_name)
            _record(arrivals, arrival_sum, tbin, keys, num_bits)
        elif etype == "-":
            keys = ['sum']
            if has_port_info:
                try:
                    flow = sp[3]
                    agg_name = find_agg_name(flow)
                    delay = int(sp[4])
                except (IndexError, ValueError):
                    sys.exit("invalid departure format, expected: \"[t] - [num_bytes] [src:dst] [delay]\", got: \"%s\"" % l.strip())
                keys.append(flow)
                if agg_name:
                    keys.append(agg_name)
            else:
                try:
                    delay = int(sp[3])
                except (IndexError, ValueError):
                    sys.exit("invalid departure format, expected: \"[t] - [num_bytes] [delay]\", got: \"%s\"" % l.strip())
            _record(departures, departure_sum, tbin, keys, num_bits)

            delays[tbin].append(delay)
            all_delays.append(delay)

        elif etype == "#":
            capacity[tbin] += num_bits
            capacity_sum += num_bits
        else:
            sys.exit("unrecognized event type: %s" % etype)
###

# first_t may legitimately be 0, so test for "never set" explicitly.
if first_t is None:
    sys.exit("must have at least one event")
if not all_delays:
    sys.exit("must have at least one departure event")


# Every key seen in any bin: "src:dst" flow ids plus non-flow keys such as
# 'sum' and aggregate group names.
arr_flows = flatten([list(per_bin) for per_bin in arrivals.values()])
dep_flows = flatten([list(per_bin) for per_bin in departures.values()])
all_observed_flows = set(arr_flows + dep_flows)

# flows_to_track fixes the order of the per-flow data columns and plot lines;
# iterating in sorted order keeps that order stable from run to run.
flows_to_track = []
flow_to_name = {}
if ports_to_track == "all":
    for flow in sorted(all_observed_flows):
        if flow == "sum" or flow == "0:0":
            continue
        flows_to_track.append(flow)
        flow_to_name[flow] = flow
else:
    tracked_ports = set(ports_to_track)
    for flow in sorted(all_observed_flows):
        endpoints = flow.split(":")
        if len(endpoints) != 2:
            # 'sum' and other non-flow keys have no src/dst ports
            continue
        src, dst = endpoints
        # Match whole port numbers; a substring test would let port 80
        # select flows on port 8080 and then label them by the wrong side.
        src_tracked = src in tracked_ports
        dst_tracked = dst in tracked_ports
        if not (src_tracked or dst_tracked):
            continue
        flows_to_track.append(flow)
        # Label the flow by its tracked port, or by the full "src:dst" id
        # when both ends are tracked.
        if src_tracked and dst_tracked:
            port = flow
        elif src_tracked:
            port = src
        else:
            port = dst
        flow_to_name[flow] = known_ports.get(port, port)

### print statistics
# Duration is the observed span of events unless --xrange pins the window;
# a "*" bound falls back to the first/last observed event time.
duration = (last_t - first_t) / 1000.0
if args.xrange:
    xmin,xmax = args.xrange.split(":")
    xmin = first_t if xmin == "*" else float(xmin) * 1000.0
    xmax = last_t if xmax == "*" else float(xmax) * 1000.0
    duration = (xmax - xmin) / 1000.0

# Every average below divides by the duration, and utilization divides by the
# average capacity; fail with a readable message instead of a
# ZeroDivisionError traceback (or nonsensical negative rates).
if duration <= 0:
    sys.exit("trace duration must be positive, got %.3f seconds" % duration)
if capacity_sum <= 0:
    sys.exit("must have at least one capacity (#) event")

# Whole-second window bounds, used to shift and bound the plot's x axis.
xmin_s = int(xmin / 1000.0) if xmin else 0
xmax_s = int(xmax / 1000.0) if xmax else int(last_t / 1000.0)

avg_capacity = (capacity_sum / duration) / 1000000.0
avg_ingress = (arrival_sum['sum'] / duration) / 1000000.0
avg_thru = (departure_sum['sum'] / duration) / 1000000.0

all_delays.sort()
ppavg = np.mean(all_delays)
pp50 = np.percentile(all_delays, 50)
pp95 = np.percentile(all_delays, 95)

sys.stderr.write("duration: %.2f seconds\n" % duration)
sys.stderr.write("average capacity: %.2f Mbit/s\n" % avg_capacity)
sys.stderr.write("average ingress: %.2f Mbit/s\n" % avg_ingress)
sys.stderr.write("average throughput: %.2f Mbit/s (%.1f%% utilization)\n" %
        (avg_thru, 100.0 * (avg_thru / avg_capacity)))
sys.stderr.write("per-packet queueing delay: avg/median/95th = %.0f/%.0f/%.0f ms \n" % (ppavg, pp50, pp95))
###

def _write_throughput_line(label, total_bits):
    """Write one tab-indented "label rate utilization" line to stderr."""
    mbps = bits_to_mbps(total_bits, duration)
    percent = 100.0 * (mbps / avg_capacity)
    sys.stderr.write("\t%s %.2f Mbit/s %.1f%%\n" % (label, mbps, percent))


# Throughput of each tracked flow, labelled by its display name.
if flows_to_track:
    sys.stderr.write("per-flow throughput:\n")
    for flow in flows_to_track:
        _write_throughput_line(flow_to_name[flow], departure_sum[flow])

# Throughput of each aggregate group, labelled by the first word of its name.
if agg_names:
    sys.stderr.write("per-group aggregate throughput:\n")
    for agg in agg_names: # TODO sorted?
        _write_throughput_line(agg.split(" ")[0], departure_sum[agg])

### compile data for gnuplot
# Range of bins to emit: the union of every bin that saw any event. Taking the
# union first avoids calling min()/max() on an empty key list when the trace
# has, for example, no arrival events.
all_bins = set(arrivals) | set(departures) | set(capacity)
first_bin = min(all_bins)
last_bin = max(all_bins)
if first_bin == last_bin:
    sys.exit("ms_per_bin=%d is too short for %.2f second trace" % (ms_per_bin, duration))

# delay_f reduces the list of per-packet delays in one bin to a single value.
if args.delay_f == "min":
    delay_f = min
elif args.delay_f == "max":
    delay_f = max
elif args.delay_f == "avg":
    # numpy is imported as np in this file
    delay_f = np.mean
elif "%" in args.delay_f:
    try:
        ptile = float(args.delay_f.split('%')[0])
    except ValueError:
        sys.exit("invalid --delay-f percentile: %s" % args.delay_f)
    if not 0 <= ptile <= 100:
        sys.exit("--delay-f percentile must be between 0 and 100, got: %s" % args.delay_f)
    delay_f = (lambda ds : np.percentile(ds,ptile))
else:
    # without this branch delay_f stays undefined and fails much later
    sys.exit("invalid --delay-f, expected (min|max|avg|X%%), got: %s" % args.delay_f)

TMP_FILE = '/tmp/mm-graph.tmp'

# With --fake the data rows go to stdout; otherwise to the temporary file
# that the gnuplot commands read.
data_out = sys.stdout if args.fake else open(TMP_FILE, 'w')

# Only referenced by the disabled buffer-occupancy debug output.
current_buf_bytes = 0

# One row per bin:
#   <seconds> <sum Mbit/s> <delay> [<Mbit/s per tracked flow>...] [<Mbit/s per group>...]
# Bins without departures (or without delay samples) are written as 0.
for tbin in range(first_bin, last_bin+1):
    has_departures = tbin in departures
    row = "{} {} {}".format(
        bin_to_seconds(tbin),
        bits_to_mbps(departures[tbin]['sum']) if has_departures else 0,
        delay_f(delays[tbin]) if tbin in delays else 0
    )
    for series in flows_to_track + agg_names:
        row += " {}".format(bits_to_mbps(departures[tbin][series]) if has_departures else 0)
    data_out.write(row + "\n")

if not args.fake:
    data_out.close()


# gnuplot is the process that receives the plot script on its stdin.
gnuplot = None
outf = None
if args.fake:
    # Dry run: echo the generated script through `cat` to stderr (stdout is
    # already used for the parsed data rows). Passing the stream object works
    # because Popen accepts any file object with a file descriptor;
    # `subprocess.STDERR` does not exist.
    gnuplot = Popen('cat', stdin=PIPE, stdout=sys.stderr)
else:
    # Derive the output name from the log path by replacing its extension.
    # splitext leaves extension-less paths and dots in directory names intact,
    # where joining on "." would produce an empty or truncated base name.
    outfbase = os.path.splitext(log)[0]
    outfname = outfbase + ".eps"
    outf = open(outfname, 'w')
    gnuplot = Popen('gnuplot', stdin=PIPE, stdout=outf)

# gnuplot ranges for the x, y (throughput) and y2 (delay) axes, as "min:max".
ranges = ["*:*","*:*","*:*"]

# The x axis is shifted so the window starts at 0. Use the already-resolved
# whole-second bounds instead of re-parsing the option text, which crashed on
# an open ("*") or fractional upper bound.
if args.xrange is not None:
    ranges[0] = "0:"+str(xmax_s - xmin_s)

# Default y range: twice the average capacity.
if args.yrange is not None:
    ranges[1] = args.yrange
else:
    ranges[1] = "0:"+str(int(avg_capacity)*2)

# Default y2 range: the 95th percentile delay plus some headroom.
if args.y2range is not None:
    ranges[2] = args.y2range
else:
    ranges[2] = "0:"+str(int(pp95 + 10))

# Plot width: derived from the tick spacing when --xtics is given, otherwise
# 2 (and a default tick spacing is filled in on args for the preamble).
if args.xtics is None:
    args.xtics = (duration / 20.0 / 2)
    width = 2
else:
    width = (duration / args.xtics) / 20.0

# An explicit --plot-width overrides the derived width.
if args.plot_width is not None:
    width = args.plot_width

# Line width scales with the plot width (5 at width 2) unless given explicitly.
if args.line_width is not None:
    lw = args.line_width
else:
    lw = int((width / 2.0) * 5.0)

# Plot title: --title if given, else the log path up to ".log", followed by
# a summary of throughput, utilization and median per-packet delay.
base_title = args.title if args.title is not None else log.split(".log")[0]
utilization_pct = 100.0 * (avg_thru / avg_capacity)
title = base_title + " / avg egress: {:.2f} Mbit/s ({:.1f}%) / median per-pkt delay: {:.0f} ms ".format(
    avg_thru,
    utilization_pct,
    pp50
)

# Script preamble: terminal, size, key, tics, axis ranges, labels, grid, title.
GNUPLOT_PREAMBLE = """
set terminal postscript enhanced color eps font "Helvetica" {font_size}
set size {width},1

{key}

set xtics 0,{xtics},{duration}

set xrange [{ranges[0]}]
set yrange [{ranges[1]}]
set y2range [{ranges[2]}]

set ytics nomirror tc lt 1
set y2tics nomirror tc lt 3

set xlabel "Time (seconds)"
set ylabel "Throughput (Mbit/s)" tc lt 1
set y2label "Per-Pkt Queueing Delay (ms)" tc lt 3

{no_grid}set grid

set style rect fc lt 7 behind fs solid 0.15 noborder

set title '{title}'
"""

preamble_fields = {
    'title': title,
    'duration': int(duration),
    'xtics': int(round(args.xtics)),
    'ranges': ranges,
    'width': width,
    # "unset grid" when --no-grid is given, "set grid" otherwise
    'no_grid': 'un' if args.no_grid else '',
    'key': "set key outside top center horizontal" if args.key else "unset key",
    'font_size': args.font_size or 20,
}
gnuplot.stdin.write(GNUPLOT_PREAMBLE.format(**preamble_fields))

# Annotate the plot from the background traffic file. Each line is space
# separated; the second field selects how the line is handled:
#   contains "=" -> draw a vertical black line at the time in the first field
#   contains "X" -> stop reading
#   otherwise    -> "t label expected mode": place `label` between t and the
#                   time on the next line, and remember the mode for time t
if args.bg:
    time_to_mode = {}
    try:
        with open(args.bg) as f:
            ymax = int(avg_capacity) * 2
            r = f.readlines()
            for i, raw_line in enumerate(r):
                l = raw_line.strip().split(" ")
                marker = l[1]
                if "=" in marker:
                    gnuplot.stdin.write(
                        "set arrow from {t},0 to {t},{ymax} nohead lc rgb \"black\" lw 5\n".format(
                            t=l[0],
                            ymax=ymax
                        )
                    )
                    continue
                if "X" in marker:
                    break

                t,label,expected,mode = l
                t = int(t)
                next_t = int(r[i+1].split(" ")[0])

                # roughly centre the label within its interval
                label_x = (t + ((next_t - t) / 2) - (len(label)/2) - 1)
                gnuplot.stdin.write("set label \"{label}\" at {x},{y}\n".format(
                    label=label,
                    x=label_x,
                    y=int(ymax * 1.05)
                ))

                time_to_mode[t] = mode

    except Exception as e:
        print "ERROR: error parsing bg file --",e
        parser.print_help()
        sys.exit(-1)

    if args.nimbus is not None:
        try:
            with open(args.nimbus) as f:
                # Only the initial mode on the first line is extracted; the
                # remaining lines are read but not used.
                initMode = 'XTCP'
                for line_no, l in enumerate(f):
                    if line_no == 0:
                        initMode = l.split("initMode=")[1].split(" ")[0]

        except Exception as e:
            print "ERROR: error parsing nimbus output file --",e
            parser.print_help()
            sys.exit(-1)

# The first series emitted must start the "plot '<file>'" command; every
# later series uses '' to reuse the same data file.
first_gplot_line = True
if not args.no_sum:
    # column 2 of the data file holds the summed throughput
    sum_source = "plot '{}'".format(TMP_FILE) if first_gplot_line else "''"
    sum_title = "ti 'sum'" if args.key else 'notitle'
    gnuplot.stdin.write(
        "\n{plot} u ($1-{adjust}):2 w lines lw {lw} lt 1 {sum_title},\\\n".format(
            adjust=xmin_s,
            plot=sum_source,
            sum_title=sum_title,
            lw=lw
        )
    )
    first_gplot_line = False

# Per-flow and per-group series start at data column 4 (after time, sum and
# delay) and follow the order in which the columns were written.
col = 4
SERIES_LINE = "{plot} u ($1-{adjust}):{col} w lines lw {lw} lt {lt} {flow},\\\n"

for flow in flows_to_track:
    if args.key:
        series_title = "ti '{}'".format(flow_to_name[flow])
    else:
        series_title = 'notitle'
    gnuplot.stdin.write(SERIES_LINE.format(
        adjust=xmin_s,
        plot=("plot '{}'".format(TMP_FILE) if first_gplot_line else "''"),
        col=col,
        lt=col if not first_gplot_line else 1,
        flow=series_title,
        lw=lw
    ))
    col += 1
    first_gplot_line = False

for agg in agg_names:
    if args.key:
        # a group name may contain a {util} placeholder for its utilization
        util_text = "{:.1f}".format(100.0 * (bits_to_mbps(departure_sum[agg], duration) / avg_capacity))+"%"
        series_title = "ti '{}'".format(agg.format(util=util_text))
    else:
        series_title = 'notitle'
    gnuplot.stdin.write(SERIES_LINE.format(
        adjust=xmin_s,
        plot=("plot '{}'".format(TMP_FILE) if first_gplot_line else "''"),
        col=col,
        lt=col if not first_gplot_line else 1,
        flow=series_title,
        lw=lw
    ))
    col += 1
    first_gplot_line = False

# Delay series: data column 3, drawn against the secondary (y2) axis.
if not args.no_delay:
    delay_source = "plot '{}'".format(TMP_FILE) if first_gplot_line else "''"
    gnuplot.stdin.write(
        "{plot} u ($1-{adjust}):3 w lines lw {lw} lt 3 notitle axes x1y2,\\\n".format(
            adjust=xmin_s,
            plot=delay_source,
            lw=lw
        )
    )
    first_gplot_line = False

# Expected-rate series: read directly from the bg file (columns 1 and 3).
if args.plot_expected:
    expected_title = "ti 'Best Rate'" if args.key else 'notitle'
    gnuplot.stdin.write(
        "'{f}' u 1:3 with steps lw {lw} lc rgb \"gold\" {ti},\\\n".format(
            f=args.bg,
            lw=lw,
            ti=expected_title
        )
    )

# Close the script pipe and wait for the child process to finish.
gnuplot.communicate()
gnuplot.wait()
###

# store_true options are never None, so a plain boolean test is enough.
if not args.fake:
    outf.flush()
    outf.close()
    os.remove(TMP_FILE)

    # Post-process with external tools. Arguments are passed as a list (no
    # shell) so output paths containing spaces or shell metacharacters work.
    post_commands = [["epstopdf", outfname]]
    if not args.no_display:
        post_commands.append(["open", outfname])
    for command in post_commands:
        try:
            call(command)
        except OSError as e:
            # Keep this step best-effort, as it was with os.system: a missing
            # helper program must not turn a finished plot into a failure.
            sys.stderr.write("warning: could not run %s: %s\n" % (command[0], e))