#!/usr/bin/python3

import argparse
import json
import re
import os
import subprocess
import requests
import docker
from faucetconfrpc.faucetconfrpc_client_lib import FaucetConfRpcClient
from graphviz import Digraph


class GraphDovesnapException(Exception):
    """Error raised when the dovesnap graph cannot be built."""


class GraphDovesnap:
    """Scrape dovesnap, docker and OVS state and render it as a graphviz PNG.

    The ``args`` namespace given to the constructor is read for
    ``status_addrs``, ``key``, ``cert``, ``ca``, ``server``, ``port`` and
    ``output``.
    """

    # Name used to look up the dovesnap plugin container.
    DOVESNAP_NAME = 'dovesnap_plugin'
    # Name used to look up the container in which ovs-* commands are run.
    OVS_NAME = 'dovesnap_ovs'
    # Not referenced by the code visible in this class.
    DRIVER_NAME = 'ovs'
    # Port number substituted when ovs-ofctl reports a port as "LOCAL".
    OFP_LOCAL = 4294967294
    # base_url handed to the docker clients.
    DOCKER_URL = 'unix://var/run/docker.sock'
    # Interface name prefix of the veths treated as patch links between bridges.
    PATCH_PREFIX = 'ovp'
    # Bridge ports whose name starts with this prefix are drawn as VM interfaces.
    VM_PREFIX = 'vnet'
    # Default loopback/mirror port; also the interfaces key checked for a mirror config.
    DOVESNAP_MIRROR = '99'

    def __init__(self, args):
        """Keep the parsed command line arguments for later use.

        Args:
            args: namespace of options (as produced by argparse in ``main``).
        """
        self.args = args

    def _get_named_container(self, client, name, strict=True):
        """Look up a container through the low-level docker API client.

        Args:
            client: object whose ``containers(filters=...)`` returns container
                dicts that carry a ``'Names'`` list.
            name: container name to filter on.
            strict: when False, the first container returned by the filter is
                accepted; when True, one of its names must contain ``name``.

        Returns:
            The first acceptable container dict, or None when there is none.
        """
        candidates = client.containers(filters={'name': name})
        if not strict:
            # Trust the daemon-side name filter and take whatever comes first.
            return next(iter(candidates), None)
        for candidate in candidates:
            if any(name in known_name for known_name in candidate['Names']):
                return candidate
        return None

    def _get_named_container_hi(self, client_hi, name, strict=True):
        """Look up a container through the high-level docker client.

        Args:
            client_hi: object whose ``containers.list(filters=...)`` returns
                container objects with a ``name`` attribute.
            name: container name to filter on.
            strict: when True, only a container whose ``name`` equals ``name``
                exactly is accepted; when False, the first listed one is.

        Returns:
            The first acceptable container object, or None.
        """
        listed = client_hi.containers.list(filters={'name': name})
        if strict:
            listed = (entry for entry in listed if entry.name == name)
        return next(iter(listed), None)

    def _scrape_container_cmd(self, name, cmd, strict=True):
        """Run a command inside a named container and return its output lines.

        Args:
            name: name of the container to execute in.
            cmd: command (argument list) handed to the container's ``exec_run``.
            strict: forwarded to ``_get_named_container_hi``.

        Returns:
            The decoded output split into lines when the command exits with 0.
            None when the container cannot be found, the docker API reports an
            error, or the command exits non-zero.
        """
        client_hi = docker.DockerClient(base_url=self.DOCKER_URL)
        container = self._get_named_container_hi(client_hi, name, strict=strict)
        if container is None:
            # No such container: report "no data" rather than failing with
            # AttributeError on None.
            return None
        try:
            (dump_exit, output) = container.exec_run(cmd)
        except (docker.errors.APIError, subprocess.CalledProcessError):
            # exec_run signals daemon-side failures with APIError; callers
            # already treat None as "nothing scraped".
            return None
        if dump_exit != 0:
            return None
        return output.decode('utf-8').splitlines()

    def _scrape_cmd(self, cmd):
        """Run a host command and return its stdout as a list of lines.

        Args:
            cmd: argument list passed to ``subprocess.check_output``.

        Returns:
            The UTF-8 decoded output split into lines, or an empty list when
            the command exits non-zero or its executable is not found.
        """
        try:
            raw = subprocess.check_output(cmd)
            text = raw.decode('utf-8')
        except (subprocess.CalledProcessError, FileNotFoundError):
            return []
        return text.splitlines()

    def _scrape_ovs(self, cmd):
        """Run ``cmd`` in the OVS container (non-strict name match).

        Returns whatever ``_scrape_container_cmd`` returns for that container.
        """
        ovs_container = self.OVS_NAME
        return self._scrape_container_cmd(ovs_container, cmd, strict=False)

    def _get_vm_options(self, faucetconfrpc_client, network, ofport):
        """Describe the faucet port options that apply to a VM port.

        Args:
            faucetconfrpc_client: object whose ``get_config_file()`` returns
                the faucet configuration as nested dicts.
            network: key of the datapath under ``conf['dps']``.
            ofport: key of the port under that datapath's ``interfaces``.

        Returns:
            Newline-joined option descriptions (``portacl: ...`` and/or
            ``mirror: true``); an empty string when nothing applies, including
            when the datapath or port is missing from the configuration.
        """
        conf = faucetconfrpc_client.get_config_file() or {}
        # Every level may be missing or empty; treat that as "no options"
        # rather than failing the whole graph with a KeyError.
        datapath = (conf.get('dps') or {}).get(network) or {}
        interfaces = datapath.get('interfaces') or {}
        interface = interfaces.get(ofport) or {}

        vm_options = []
        acls_in = interface.get('acls_in', None)
        if acls_in:
            # ACL names are not guaranteed to be strings, so stringify for join.
            vm_options.append('portacl: %s' % ','.join(str(acl) for acl in acls_in))
        # NOTE(review): DOVESNAP_MIRROR is a string key; confirm the config
        # does not key interfaces by integer port number.
        mirror_interface = interfaces.get(self.DOVESNAP_MIRROR, None)
        if mirror_interface:
            mirrored_ports = mirror_interface.get('mirror', [])
            if not isinstance(mirrored_ports, (list, tuple, set, frozenset)):
                # A single mirrored port may be given as a bare value.
                mirrored_ports = [mirrored_ports]
            if ofport in mirrored_ports:
                vm_options.append('mirror: true')
        return '\n'.join(vm_options)

    def _scrape_external_iface(self, name):
        """Build the node label for an external (host) interface.

        Runs ``ifconfig <name>`` and appends the interface's MAC address to the
        label when one can be found.

        Args:
            name: host interface name.

        Returns:
            Newline-joined label: the name, a blank line, 'External Interface'
            and, when available, the MAC address.
        """
        desc = [name, '', 'External Interface']
        output = self._scrape_cmd(['ifconfig', name]) or []
        mac = None
        # Prefer the explicit "ether <mac>" line: address lines (inet/inet6)
        # may be listed before it, so its position is not fixed.
        for line in output:
            fields = line.split()
            if len(fields) > 1 and fields[0] == 'ether':
                mac = fields[1]
                break
        if mac is None and len(output) > 1:
            # Fall back to the historical behaviour (second token of the
            # second line), but without raising on short lines.
            fields = output[1].split()
            if len(fields) > 1:
                mac = fields[1]
        if mac is not None:
            desc.append(mac)
        return '\n'.join(desc)

    def _network_lookup(self, name):
        """Resolve a name with ``nslookup`` and return its first answer.

        The output is scanned for the first ``Name:`` line and the first
        ``Address:`` line after it. This tolerates the answer section starting
        on a different line or listing several addresses, and skips the
        resolver's own ``Address:`` line that precedes the answer.

        Args:
            name: host name to look up.

        Returns:
            ``(hostname, address)`` for the first answer, or ``(None, None)``
            when the lookup produced no usable answer.
        """
        output = self._scrape_cmd(['nslookup', name]) or []
        hostname = None
        for line in output:
            key, sep, value = line.partition(':')
            if not sep:
                continue
            key = key.strip()
            value = value.strip()
            if not value:
                continue
            if hostname is None:
                if key == 'Name':
                    hostname = value
            elif key == 'Address':
                return hostname, value
        return None, None

    def _scrape_vm_iface(self, name):
        """Build the node label for a VM interface attached to a bridge.

        Lists VMs with ``virsh list`` and, for each, its interfaces with
        ``virsh domiflist`` until one owns an interface called ``name``. That
        interface's MAC and, when the VM name resolves, its hostname and
        address are added to the label.

        Args:
            name: host-side interface name of the VM port.

        Returns:
            Newline-joined label. Always contains a blank line,
            'Virtual Machine' and ``name``; the MAC is appended when the owning
            VM is found; the hostname is prepended and ``<address>/24``
            appended when the lookup succeeds.
        """
        desc = ['', 'Virtual Machine', name]
        # Both virsh tables start with a header row and a separator row. Rows
        # with too few columns (e.g. the trailing blank line) are skipped, so
        # the last VM is kept even when no blank line follows it.
        vm_rows = (self._scrape_cmd(['virsh', 'list']) or [])[2:]
        for vm_row in vm_rows:
            vm_fields = vm_row.split()
            if len(vm_fields) < 2:
                continue
            vm_name = vm_fields[1]
            iface_rows = (self._scrape_cmd(['virsh', 'domiflist', vm_name]) or [])[2:]
            iface_macs = {}
            for iface_row in iface_rows:
                iface_fields = iface_row.split()
                if len(iface_fields) >= 5:
                    iface_macs[iface_fields[0]] = iface_fields[4]
            mac = iface_macs.get(name)
            if mac is None:
                continue
            desc.append(mac)
            hostname, address = self._network_lookup(vm_name)
            if hostname is not None:
                desc.insert(0, hostname)
                desc.append(f'{address}/24')
            break
        return '\n'.join(desc)

    def _get_matching_lines(self, lines, re_str):
        """Return the match objects for every line matching a pattern.

        Args:
            lines: iterable of strings.
            re_str: regular expression, applied with ``match`` (anchored at the
                start of each line).

        Returns:
            List of match objects, in input order, for the lines that matched.
        """
        try_match = re.compile(re_str).match
        return [hit for hit in map(try_match, lines) if hit]

    def _scrape_container_iface(self, container_id):
        """List the peered ethernet links inside a container's network namespace.

        Runs ``ip -o link show`` in the namespace named ``container_id`` via
        the dovesnap container.

        Returns:
            List of ``(ifname, mac, iflink, peeriflink)`` tuples, one for each
            ``name@ifN`` link/ether line; empty when the command yields nothing.
        """
        netns_cmd = ['ip', 'netns', 'exec', container_id, 'ip', '-o', 'link', 'show']
        lines = self._scrape_container_cmd(self.DOVESNAP_NAME, netns_cmd, strict=False)
        if lines is None:
            return []
        link_re = r'^(\d+):\s+([^\@]+)\@if(\d+):.+link\/ether\s+(\S+).+$'
        return [
            (found[2], found[4], int(found[1]), int(found[3]))
            for found in self._get_matching_lines(lines, link_re)]

    def _scrape_bridge_ports(self, bridgename):
        """Map port names to OpenFlow port numbers for one OVS bridge.

        Parses ``ovs-ofctl dump-ports-desc <bridgename>`` as run in the OVS
        container. A port reported as ``LOCAL`` is given ``OFP_LOCAL``.

        Returns:
            Dict of port name to integer port number; empty when the command
            yields nothing.
        """
        lines = self._scrape_ovs(['ovs-ofctl', 'dump-ports-desc', bridgename])
        if lines is None:
            return {}
        port_re = r'^\s*(\d+|LOCAL)\((\S+)\).+$'
        port_desc = {}
        for found in self._get_matching_lines(lines, port_re):
            number, port_name = found[1], found[2]
            port_desc[port_name] = int(self.OFP_LOCAL if number == 'LOCAL' else number)
        return port_desc

    def _scrape_all_bridge_ports(self):
        """Collect the port maps of every bridge listed by ``ovs-vsctl list-br``.

        Returns:
            Dict of bridge name to the result of ``_scrape_bridge_ports`` for
            that bridge; empty when the bridge list cannot be scraped.
        """
        lines = self._scrape_ovs(['ovs-vsctl', 'list-br'])
        if lines is None:
            return {}
        # Each bridge name is a line consisting of one whitespace-free token.
        bridgenames = [found[0] for found in self._get_matching_lines(lines, r'^(\S+)$')]
        return {
            bridgename: self._scrape_bridge_ports(bridgename)
            for bridgename in bridgenames}

    def _scrape_host_veth(self):
        """Return the host's ``ip -o link show type veth`` output lines."""
        veth_cmd = ['ip', '-o', 'link', 'show', 'type', 'veth']
        return self._scrape_cmd(veth_cmd)

    def _scrape_patch_veths(self):
        """Find the host veths whose names start with ``PATCH_PREFIX``.

        Returns:
            Dict of interface name to ``(peer interface name, mac)``.

        Raises:
            AssertionError: if an odd number of patch veths is found.
        """
        lines = self._scrape_host_veth()
        if lines is None:
            return {}
        patch_re = r'^\d+:\s+(%s[^\@]+)\@([^\:\s]+).+link\/ether\s+(\S+).+$' % self.PATCH_PREFIX
        patch_veths = {
            found[1]: (found[2], found[3])
            for found in self._get_matching_lines(lines, patch_re)}
        # Patch veths are expected to come in pairs.
        assert len(patch_veths) % 2 == 0
        return patch_veths

    def _get_lb_port(self, network):
        """Return the network's ``ovs.bridge.lbport`` option.

        Falls back to ``DOVESNAP_MIRROR`` when the option is not set.
        """
        options = network['Options']
        return options.get('ovs.bridge.lbport', self.DOVESNAP_MIRROR)

    def _get_container_args(self, container_inspect):
        """Parse a container's command line into an option dict.

        Leading dashes are stripped from every argument and the remainder is
        split on the first ``=`` only, so values that contain ``=`` are kept
        whole.

        Args:
            container_inspect: container inspect dict; ``['Config']['Cmd']`` is
                read and may be a list of argument strings or None.

        Returns:
            Dict of option name to value; options given without ``=`` map to
            an empty string.
        """
        args = {}
        # Cmd may be None rather than a list; treat that as "no arguments".
        for arg_str in container_inspect['Config']['Cmd'] or []:
            option, _, value = arg_str.lstrip('-').partition('=')
            args[option] = value
        return args

    def output_graph(self, nodes, edges):
        """Render the collected nodes and edges to a PNG.

        Args:
            nodes: dict of node id to a list of label lines.
            edges: iterable of ``(node_a, node_b, ports)``; the edge label is
                the ports joined with ' : '.

        The image is rendered using ``self.args.output`` as the file name
        passed to graphviz; the file at exactly that path is removed afterwards.
        """
        dot = Digraph()
        dot.format = 'png'
        for node_id in nodes:
            label = '\n'.join(nodes[node_id])
            dot.node(node_id, label)
        for src, dst, ports in edges:
            dot.edge(src, dst, ' : '.join(map(str, ports)))
        output_base = self.args.output
        dot.render(output_base)
        # leave only PNG
        os.remove(output_base)

    def build_graph(self):
        """Scrape dovesnap, docker and OVS state and render it as a graph.

        Network state is fetched from ``http://<addr>/networks`` for every
        address in ``self.args.status_addrs``. Each network becomes a node
        linked to its containers, and to the other ports found on its bridge
        (NAT, VM interfaces, external interfaces, patch links to other
        bridges). The result is handed to ``output_graph``.

        Raises:
            GraphDovesnapException: if docker does not answer a ping, the
                dovesnap container cannot be found, or the faucetconfrpc
                client object is falsy.
            requests.RequestException: if a status server cannot be reached,
                does not answer in time, or returns an HTTP error status.
        """
        # Upper bound, in seconds, on waiting for a status server, so an
        # unresponsive one cannot hang the tool forever.
        status_timeout = 30
        networks_json = {}
        for status_addr in self.args.status_addrs.split(','):
            status_url = 'http://%s/networks' % status_addr
            resp = requests.get(status_url, timeout=status_timeout)
            # Fail with a clear HTTP error instead of trying to parse an error page.
            resp.raise_for_status()
            networks_json.update(json.loads(resp.text))
        client = docker.APIClient(base_url=self.DOCKER_URL)
        if not client.ping():
            raise GraphDovesnapException('cannot connect to docker')
        dovesnap = self._get_named_container(client, self.DOVESNAP_NAME)
        if not dovesnap:
            raise GraphDovesnapException('cannot find dovesnap container')
        faucetconfrpc_client = FaucetConfRpcClient(
            self.args.key, self.args.cert, self.args.ca, ':'.join([self.args.server, self.args.port]))
        if not faucetconfrpc_client:
            raise GraphDovesnapException('cannot connect to faucetconfrpc')
        nodes = {}
        edges = []

        # TODO: get options list from status server
        dovesnap_args = self._get_container_args(client.inspect_container(dovesnap['Id']))
        # TODO: poll OVS from other hosts.
        patch_veths = self._scrape_patch_veths()
        all_port_desc = self._scrape_all_bridge_ports()
        unresolved_links = []
        # Maps bridge name -> network id (despite the variable's name).
        network_id_name = {}
        for network_id, network in networks_json.items():
            network_name = network['NetworkName']
            bridgename = network['BridgeName']
            mode = network['Mode']
            network_inspect = client.inspect_network(network_id)
            options = ['%s: %s' % (option.split('.')[-1], optionval)
                for option, optionval in network_inspect['Options'].items()]
            network_id_name[bridgename] = network_id
            nodes[network_id] = [network_name, bridgename] + options
            container_ports = set()
            for container in network['Containers'].values():
                container_id = container['Id']
                container_name = container['Name']
                ip = container.get('HostIP', None)
                mac = container['MacAddress']
                ofport = container['OFPort']
                labels = container['Labels']
                # TODO: need inside-container eth name from status server
                ifname = 'eth0'
                display_labels = ['%s: %s' % (label.split('.')[-1], labelval)
                    for label, labelval in labels.items()]
                host_label = [container_name, '', 'Container', ifname, mac]
                if ip:
                    host_label.append(ip)
                host_label.extend(display_labels)
                container_ports.add(ofport)
                nodes[container_id] = host_label
                edges.append((network_id, container_id, [ofport]))
            # A bridge that could not be scraped from OVS simply has no extra
            # ports to draw; do not fail the whole graph on it.
            for br_desc, ofport in all_port_desc.get(bridgename, {}).items():
                if ofport in container_ports:
                    continue
                if ofport == self.OFP_LOCAL:
                    if mode == 'nat':
                        edges.append((network_id, 'NAT', [self.OFP_LOCAL]))
                elif br_desc in patch_veths:
                    unresolved_links.append((bridgename, br_desc, ofport))
                else:
                    if br_desc.startswith(self.VM_PREFIX):
                        vm_desc = self._scrape_vm_iface(br_desc)
                        # Status entries are read via 'NetworkName' above; use
                        # it when no 'Name' key is present instead of raising.
                        vm_options = self._get_vm_options(
                            faucetconfrpc_client, network.get('Name', network_name), ofport)
                        nodes[br_desc] = ['', vm_desc, vm_options]
                    else:
                        # TODO: scrape external iface name and MAC from status server
                        external_desc = self._scrape_external_iface(br_desc)
                        nodes[br_desc] = [external_desc]
                    edges.append((network_id, br_desc, [ofport]))

        non_container_bridges = set()

        # wire up links to non container bridges.
        for bridgename, br_desc, ofport in unresolved_links:
            network_id = network_id_name[bridgename]
            peer_br_desc = patch_veths[br_desc][0]
            for peer_bridgename, port_desc in all_port_desc.items():
                if peer_br_desc in port_desc:
                    peer_ofport = port_desc[peer_br_desc]
                    edges.append((network_id, peer_bridgename, [ofport, peer_ofport]))
                    if peer_bridgename not in network_id_name:
                        non_container_bridges.add(peer_bridgename)
                    break

        # resolve any remaining ports, on non container bridges.
        for bridgename in non_container_bridges:
            for br_desc, ofport in all_port_desc[bridgename].items():
                if ofport == self.OFP_LOCAL:
                    continue
                if br_desc in patch_veths:
                    continue
                # The mirror input port points into the bridge; every other
                # port is drawn pointing out of it.
                if br_desc == dovesnap_args.get('mirror_bridge_in', ''):
                    edges.append((br_desc, bridgename, [ofport]))
                else:
                    edges.append((bridgename, br_desc, [ofport]))

        self.output_graph(nodes, edges)


def main(argv=None):
    """Parse command line options and render the dovesnap graph.

    Args:
        argv: optional list of argument strings; when None, argparse reads
            the process command line.
    """
    parser = argparse.ArgumentParser(
        description='Dovesnap Graph - A dot file output graph of VMs, containers, and networks controlled by Dovesnap')
    parser.add_argument('--ca', '-a', default='/opt/faucetconfrpc/faucetconfrpc-ca.crt',
                        help='FaucetConfRPC server certificate authority file')
    parser.add_argument('--cert', '-c', default='/opt/faucetconfrpc/faucetconfrpc.crt',
                        help='FaucetConfRPC server cert file')
    parser.add_argument('--key', '-k', default='/opt/faucetconfrpc/faucetconfrpc.key',
                        help='FaucetConfRPC server key file')
    parser.add_argument('--port', '-p', default='59999',
                        help='FaucetConfRPC server port')
    parser.add_argument('--server', '-s', default='faucetconfrpc',
                        help='FaucetConfRPC server name')
    parser.add_argument('--output', '-o', default='dovesnapviz',
                        help='Output basename of image to write')
    # The value is split on ',' by build_graph, hence "comma separated".
    parser.add_argument('--status_addrs', '-u', default='localhost:9401',
                        help='Comma separated list of dovesnap status URLs to scrape')
    args = parser.parse_args(argv)
    g = GraphDovesnap(args)
    g.build_graph()


# Run the command line entry point only when executed as a script.
if __name__ == "__main__":
    main()
