#!/usr/bin/env python

import os
import sys
import time
import math
import threading
import signal
import cursor
import curses
from ascii_graph import Pyasciigraph


def _read_counter(path):
    """Return the integer stored in the counter file at *path*."""
    # The context manager closes the handle immediately; this function is
    # called four times per sample, once a second, for the life of the program.
    with open(path) as counter_file:
        return int(counter_file.read().strip())


def rxtxcount(rxpath, txpath, interval=1):
    """Sample the receive/transmit byte counters twice, *interval* seconds apart.

    Args:
        rxpath: path of the file holding the received-bytes counter.
        txpath: path of the file holding the transmitted-bytes counter.
        interval: seconds to wait between the two readings (default 1, which
            makes the deltas below per-second figures).

    Returns:
        A 4-tuple ``(rx_delta, tx_delta, rx_total, tx_total)`` where the
        deltas are the absolute counter change over the interval multiplied
        by 0.001 and the totals are the second reading multiplied by
        0.000001.

    Raises:
        OSError: if a counter file cannot be opened.
        ValueError: if a counter file does not contain an integer.
    """
    rx1 = _read_counter(rxpath)
    tx1 = _read_counter(txpath)
    time.sleep(interval)
    rx2 = _read_counter(rxpath)
    tx2 = _read_counter(txpath)
    return (
        abs(rx1 - rx2) * 0.001,
        abs(tx1 - tx2) * 0.001,
        rx2 * 0.000001,
        tx2 * 0.000001,
    )


# check if the network device exists and is operating
def checkdevice(sysdir, device):
    """Return True if *device* is listed in *sysdir* and its operstate is "up".

    Args:
        sysdir: directory containing one entry per network device.
        device: name of the entry to check.

    Returns:
        True when ``<sysdir>/<device>/operstate`` can be read and contains
        "up"; False when the entry is missing, is not a directory, has no
        readable operstate file, or reports any other state.

    Raises:
        OSError: if *sysdir* itself cannot be listed.
    """
    if device not in os.listdir(sysdir):
        return False

    try:
        with open(os.path.join(sysdir, device, "operstate")) as state_file:
            return state_file.read().strip() == "up"
    except OSError:
        # The directory can hold entries that are not interface directories
        # (plain files, or directories without an operstate file); those are
        # simply "not an operating device" rather than a fatal error.
        return False


# find the first network device that is up
def finddevice(sysdir):
    """Return the name of the first entry in *sysdir* whose operstate is "up".

    Entries are visited in the order ``os.listdir`` yields them. Entries
    whose ``operstate`` file cannot be read (plain files, directories without
    that file) are skipped instead of aborting the search.

    Args:
        sysdir: directory containing one entry per network device.

    Returns:
        The entry name, or None when no entry reports "up".

    Raises:
        OSError: if *sysdir* itself cannot be listed.
    """
    for name in os.listdir(sysdir):
        try:
            with open(os.path.join(sysdir, name, "operstate")) as state_file:
                if state_file.read().strip() == "up":
                    return name
        except OSError:
            continue
    return None


def checkos():
    """Exit the program unless ``sys.platform`` contains "linux".

    Prints "Not a linux machine." before exiting; returns None otherwise.
    """
    if "linux" in sys.platform:
        return
    print("Not a linux machine.")
    sys.exit()


# clean up when sigterm or sigint is sent
def sigterm_handler(signal, frame):
    """Restore the terminal and exit the program.

    Installed as the SIGTERM/SIGINT handler, so the arguments are whatever
    the signal machinery passes; neither is used. (The first parameter
    shadows the ``signal`` module inside this function; the name is kept so
    the signature stays unchanged.)

    Raises:
        SystemExit: always, via ``sys.exit()``.
    """
    cursor.show()
    try:
        curses.echo()
        curses.nocbreak()
        curses.endwin()
    except curses.error:
        # The handler is registered at import time, so a signal can arrive
        # before curses.initscr() has run; the curses calls then raise
        # curses.error and there is no curses state to undo.
        pass
    sys.exit()


def keypress(stdscr):
    """Block until ``stdscr.getch()`` returns the code for "q".

    Polls ``getch()`` and sleeps 0.2 seconds after every result that is not
    the quit key, then returns None once "q" is seen.
    """
    quit_key = ord("q")
    while True:
        if stdscr.getch() == quit_key:
            return
        time.sleep(0.2)


# Route both termination signals through the same clean-up handler.
for _signum in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_signum, sigterm_handler)


def getdevice(sysdir):
    """Pick the network device to monitor from the command line or by probing.

    If the first command-line argument contains "help", the usage line is
    printed and the program exits. Otherwise every command-line argument is
    tried with ``checkdevice`` (the last one that passes wins); when none
    passes, ``finddevice`` is asked for a device that is up.

    Args:
        sysdir: directory containing one entry per network device.

    Returns:
        The name of the selected device.

    Raises:
        SystemExit: after printing usage, when ``checkos`` rejects the
            platform, or when no suitable device is found.
    """
    if len(sys.argv) > 1 and "help" in sys.argv[1]:
        print("usage: netmon [network device]")
        sys.exit()

    checkos()

    device = ""
    # Skip sys.argv[0]: it is the program path, not a device name.
    for candidate in sys.argv[1:]:
        if checkdevice(sysdir, candidate):
            device = candidate

    if not device:
        device = finddevice(sysdir)

    if not device:
        print("no appropriate device found")
        sys.exit()

    return device


# init curses env
def initscreen():
    """Put the terminal into curses mode and return the main window.

    The calls run in a fixed order: hide the cursor, start curses, clear the
    window, enable colours with the terminal's default colours, hide the
    cursor again, then turn off echo, switch to cbreak mode and enable
    ``nodelay`` on the window.

    Returns:
        The window object returned by ``curses.initscr()``.
    """
    cursor.hide()
    screen = curses.initscr()
    screen.clear()

    curses.start_color()
    curses.use_default_colors()

    # Hidden a second time, now that curses has taken over the terminal.
    cursor.hide()
    curses.noecho()
    curses.cbreak()
    screen.nodelay(True)
    return screen


def destroyscreen():
    """Undo the terminal changes made for curses mode.

    Shows the cursor, re-enables echo, leaves cbreak mode and ends the
    curses session, in that order.
    """
    for restore in (cursor.show, curses.echo, curses.nocbreak, curses.endwin):
        restore()


# build the graph renderer
def initgraph(maxvalue, rows, columns):
    """Create a ``Pyasciigraph`` configured for the current scale and width.

    Args:
        maxvalue: passed to the graph as ``force_max_value``.
        rows: accepted for symmetry with the caller; not used here.
        columns: passed to the graph as ``line_length``.

    Returns:
        The configured ``Pyasciigraph`` instance.
    """
    options = dict(
        line_length=columns,
        min_graph_length=1,
        separator_length=1,
        multivalue=True,
        graphsymbol="*",
        float_format="{0:,.0f}",
        force_max_value=maxvalue,
    )
    return Pyasciigraph(**options)


def getcolor(maxvalue, sample):
    """Map the graph scale *maxvalue* to a curses colour constant.

    Bands:
        0 < maxvalue <= 100        -> green
        100 < maxvalue <= 1000     -> yellow
        1000 < maxvalue < 10000    -> red
        maxvalue >= 10000          -> magenta
        anything else (<= 0)       -> cyan

    Args:
        maxvalue: current graph scale.
        sample: accepted for the caller's convenience; not used.

    Returns:
        One of the ``curses.COLOR_*`` constants.
    """
    if 0 < maxvalue <= 100:
        return curses.COLOR_GREEN
    elif 100 < maxvalue <= 1000:
        return curses.COLOR_YELLOW
    elif maxvalue >= 10000:
        # Must be tested before the red band: a plain ">= 1000" check placed
        # first would swallow these values and make magenta unreachable.
        return curses.COLOR_MAGENTA
    elif maxvalue > 1000:
        return curses.COLOR_RED
    else:
        return curses.COLOR_CYAN


def setmaxvalue(sample):
    """Return the power of ten used as the graph scale for *sample*.

    Each of the first two entries of *sample* is rounded up to an integer;
    the result is 10 raised to the larger of the two decimal string lengths,
    so it is always strictly greater than both values.
    """
    first_width = len(str(int(math.ceil(sample[0]))))
    second_width = len(str(int(math.ceil(sample[1]))))
    widest = first_width if first_width >= second_width else second_width
    return 10 ** widest


def drawdata(sample, device, sysdir, stdscr, graph, maxvalue, columns):
    """Render one sample to the curses window.

    Args:
        sample: 4-sequence; entries 0 and 1 are plotted as "KB/s Down" and
            "KB/s Up", entries 2 and 3 are printed as the received and sent
            totals.
        device: device name shown in the graph title.
        sysdir: not used here.
        stdscr: curses window to draw on.
        graph: object whose ``graph(title, data)`` returns the text lines.
        maxvalue: scale handed to ``getcolor`` to choose the highlight colour.
        columns: not used here.

    Rows 0 and 1 are drawn plain. Of rows 2 and 3, the one belonging to the
    larger of the two rates is drawn in colour pair 1; when the rates are
    equal both rows are drawn in cyan. (Rows 2 and 3 are presumably the down
    and up bars, in the order the series are passed; confirm against the
    Pyasciigraph output.)
    """
    down_rate = sample[0]
    up_rate = sample[1]
    lines = graph.graph(
        "device: {0}".format(str(device)),
        [("KB/s Down", down_rate), ("KB/s Up", up_rate)],
    )
    for row in (0, 1):
        stdscr.addstr(row, 0, lines[row])

    # Decide which rows get the highlight and in which colour.
    if down_rate == up_rate:
        highlight_color = curses.COLOR_CYAN
        highlighted_rows = (2, 3)
    elif down_rate > up_rate:
        highlight_color = getcolor(maxvalue, down_rate)
        highlighted_rows = (2,)
    else:
        highlight_color = getcolor(maxvalue, up_rate)
        highlighted_rows = (3,)
    curses.init_pair(1, highlight_color, -1)

    for row in (2, 3):
        if row in highlighted_rows:
            stdscr.addstr(row, 0, lines[row], curses.color_pair(1))
        else:
            stdscr.addstr(row, 0, lines[row])

    stdscr.addstr(4, 0, "Total received: {0:.2f} MBs".format(sample[2]))
    stdscr.addstr(5, 0, "Total sent: {0:.2f} MBs".format(sample[3]))
    stdscr.refresh()


def run(sysdir, stdscr, device, graph, maxvalue, rows, columns):
    """Sample and redraw in a loop until the key-watcher thread ends.

    A daemon thread runs ``keypress(stdscr)``; the loop keeps going while
    that thread is alive. Each iteration handles a terminal resize, takes a
    sample with ``rxtxcount``, recomputes the scale, rebuilds the graph and
    draws it.

    Args:
        sysdir: directory containing one entry per network device.
        stdscr: curses window to draw on.
        device: name of the device to monitor.
        graph: initial graph object; it is rebuilt before every draw, so the
            passed-in instance is never drawn with.
        maxvalue: initial scale; recomputed from every sample.
        rows: terminal height last seen by the caller.
        columns: terminal width last seen by the caller.

    Raises:
        SystemExit: if any exception is raised inside the loop, after the
            terminal has been restored and the exception printed.
    """
    rxbytepath = os.path.join(sysdir, device, "statistics", "rx_bytes")
    txbytepath = os.path.join(sysdir, device, "statistics", "tx_bytes")
    thread = threading.Thread(target=keypress, args=(stdscr,))
    thread.daemon = True
    thread.start()
    try:
        while thread.is_alive():
            if curses.is_term_resized(rows, columns):
                rows, columns = stdscr.getmaxyx()
                curses.resizeterm(rows, columns)
                stdscr.refresh()
            sample = rxtxcount(rxbytepath, txbytepath)
            maxvalue = setmaxvalue(sample)
            # Rebuilt every pass, which also picks up any new width from a
            # resize handled above.
            graph = initgraph(maxvalue, rows, columns)
            drawdata(sample, device, sysdir, stdscr, graph, maxvalue, columns)
    except Exception as e:
        # Leave curses mode before printing: text written while curses still
        # owns the screen is lost or garbled when the terminal is restored.
        destroyscreen()
        print(e)
        sys.exit()


if __name__ == "__main__":
    # Pick the device first so usage/platform/device errors are printed
    # before the terminal is switched into curses mode.
    sysdir = "/sys/class/net/"
    device = getdevice(sysdir)

    stdscr = initscreen()
    rows, columns = stdscr.getmaxyx()
    maxvalue = 100.0

    run(
        sysdir,
        stdscr,
        device,
        initgraph(maxvalue, rows, columns),
        maxvalue,
        rows,
        columns,
    )

    # Reached once run() returns, i.e. after the key-watcher thread ended.
    destroyscreen()
    sys.exit()
