#!/usr/bin/env python

"""
Utility for increasing RX ring buffers: finds and sets a compromise value
that avoids dropped/missed packets while keeping latency low.
"""

from sys import argv
from os import getenv, system
from unittest import TestCase
from unittest import main as test_main
from subprocess import Popen, PIPE

__author__ = 'Oleg Strizhechenko <oleg.strizhechenko@gmail.com>'


class RxBuffersIncreaser(object):

    """
    Picks and applies an RX ring buffer size for a network device.

    1. Determines what sizes of buffers are available (``ethtool -g``)
    2. Evaluates the one that fits our purposes
    3. Applies it (``ethtool -G``)
    """

    def __init__(self, dev=None, upper_bound=2048):
        # Device name handed to ethtool.
        self.dev = dev
        # Size we do not want to grow the ring beyond on our own.
        self.upper_bound = upper_bound
        # Current and maximum RX ring sizes; filled in by investigate().
        self.current = 0
        self.maximum = 0
        # Size chosen by the last apply() call (None until then).
        self.prefered = None

    @staticmethod
    def _parse_rx_sizes(ethtool_output):
        """
        Extract (maximum, current) RX ring sizes from ``ethtool -g`` text.

        The first line labelled exactly ``RX`` is taken as the maximum and
        the second as the current value, instead of relying on fixed line
        numbers, so extra rows in the output do not shift the result.

        Raises LookupError when fewer than two ``RX:`` lines are present and
        ValueError when a value is not an integer.
        """
        sizes = []
        for line in ethtool_output.splitlines():
            label, _, value = line.partition(':')
            # 'RX Mini' / 'RX Jumbo' must not match, hence the exact compare.
            if label.strip() == 'RX':
                sizes.append(int(value.strip()))
        if len(sizes) < 2:
            raise LookupError(
                'unexpected ethtool -g output: {0!r}'.format(ethtool_output))
        return sizes[0], sizes[1]

    def investigate(self):
        """
        Get maximum and current rx ring buffer values via ethtool.

        Stores them in self.maximum and self.current.
        Raises LookupError (carrying ethtool's stderr) when ethtool exits
        with a non-zero status.
        """
        # universal_newlines=True makes the pipes return text on Python 3 too.
        process = Popen(['ethtool', '-g', self.dev], stdout=PIPE, stderr=PIPE,
                        universal_newlines=True)
        ethtool_buffers, stderr = process.communicate()
        if process.returncode != 0:
            raise LookupError(stderr)
        self.maximum, self.current = self._parse_rx_sizes(ethtool_buffers)

    def determine(self, current=None, maximum=None):
        """
        Evaluate the most fitting rx ring buffer size.

        current/maximum default to the values stored on the instance.
        Returns:
          * current, when it already exceeds upper_bound (never shrink);
          * maximum, when the device cannot reach upper_bound;
          * otherwise half of maximum capped at upper_bound, but never
            less than current.
        """
        if current is None:
            current = self.current
        if maximum is None:
            maximum = self.maximum
        if current > self.upper_bound:
            return current
        if maximum < self.upper_bound:
            return maximum
        # Floor division keeps the result an integer on Python 3 as well.
        return max(current, min(self.upper_bound, maximum // 2))

    def apply(self):
        """
        Do the whole job, applying a new buffer size if required.

        The chosen size is stored in self.prefered; ethtool -G is only run
        when it differs from the current size.
        """
        self.investigate()
        # The result of determine() has to be kept, otherwise the comparison
        # below is made against None and 'rx None' is sent to ethtool.
        self.prefered = self.determine()
        if self.prefered == self.current:
            return
        # Argument list instead of a shell string: the device name is never
        # interpreted by a shell.
        Popen(['ethtool', '-G', self.dev, 'rx', str(self.prefered)]).wait()


class RxBuffersIncreaserTest(TestCase):

    """ determine() checks for devices with different maximum ring sizes """

    def setUp(self):
        self.rxbi = RxBuffersIncreaser(upper_bound=2048)

    def _check(self, maximum, cases):
        """ assert determine(current, maximum) == expected for every case """
        for current, expected in cases:
            self.assertEqual(self.rxbi.determine(current, maximum), expected)

    def test_4096(self):
        # (current, expected) pairs for a device whose maximum is 4096
        self._check(4096, (
            (256, 2048),
            (512, 2048),
            (2048, 2048),
            (3072, 3072),
            (4096, 4096),
        ))

    def test_511(self):
        # maximum below the upper bound: the maximum itself is expected
        self._check(511, (
            (200, 511),
            (511, 511),
            (400, 511),
        ))

    def test_8096(self):
        # (current, expected) pairs for a device whose maximum is 8096
        self._check(8096, (
            (200, 2048),
            (2048, 2048),
            (3000, 3000),
            (8096, 8096),
        ))


def main():
    """ resize the rx ring of the device named by the first CLI argument """
    # RX_UPPER_BOUND in the environment overrides the default of 2048
    limit = int(getenv('RX_UPPER_BOUND', '2048'))
    # only the part of the argument before the first ';' is used
    device = argv[1].partition(';')[0]
    increaser = RxBuffersIncreaser(device, limit)
    increaser.apply()


if __name__ == '__main__':
    # With arguments: do the real work; without any: run the unit tests.
    if len(argv) != 1:
        main()
    else:
        test_main()
