#!/usr/bin/env python3
import sys
import re
import time
import os
import socket
import struct
import select
import argparse
import subprocess
import traceback

from systemd import journal, daemon

from pathlib import Path

# sshd log lines of interest. Both patterns are applied with .match(), so they
# are anchored at the start of the journal MESSAGE text.
#
# The user name group is deliberately greedy and unrestricted:
#   * user names may contain uppercase letters, '_', '-', '.', etc., and a
#     narrower character class would silently ignore those login attempts;
#   * the user name is chosen by the remote client, so it may itself contain
#     the text " from 1.2.3.4".  A greedy group makes the address group bind
#     to the LAST " from <ip>" in the line, which is the one sshd appends,
#     instead of an address smuggled in through the user name.
# The captured user name is therefore untrusted input and must be treated as
# an opaque string by whatever consumes it.
# NOTE(review): only dotted-quad IPv4 peers are matched; IPv6 is not handled.
RX_FAIL = re.compile(r'Failed password for (?:invalid user )?(.+) from (\d+\.\d+\.\d+\.\d+)')
RX_ACCEPT = re.compile(r'Accepted (?:password|publickey) for (.+) from (\d+\.\d+\.\d+\.\d+)')

def encode_len(s):
    """Return *s* as UTF-8 bytes prefixed with its byte length.

    The prefix is a big-endian unsigned 16-bit integer holding the length of
    the encoded payload (not the number of characters).
    """
    payload = s.encode('utf8')
    prefix = struct.pack(">H", len(payload))
    return prefix + payload

def readall(sock, rdlen):
    """Receive up to *rdlen* bytes from *sock*, looping over short reads.

    Keeps calling ``sock.recv`` until the requested amount has been collected
    or ``recv`` returns an empty result (end of stream).  Returns the bytes
    gathered so far, which may be fewer than *rdlen* if the stream ended.
    """
    pieces = []
    remaining = rdlen
    while remaining > 0:
        piece = sock.recv(remaining)
        if not piece:
            # Peer closed the stream; hand back whatever arrived.
            break
        pieces.append(piece)
        remaining -= len(piece)
    return b''.join(pieces)

def send_command(sockpath, args):
    """Send one command to the control socket and consume its reply.

    The request is a big-endian 32-bit value 1, followed by each string in
    *args* encoded with ``encode_len``, followed by the 16-bit value 65535.
    The reply is read as a 10-byte header (two 32-bit fields and a 16-bit
    length) followed by that many bytes of body; the reply content is read
    and discarded.

    This is best-effort: any ``OSError`` (connection refused, socket missing,
    truncated reply, ...) is printed as a traceback and otherwise ignored so
    that a failing control socket does not stop the caller.

    Args:
        sockpath: filesystem path of the AF_UNIX stream socket.
        args: iterable of strings making up the command.
    """
    reply_fmt = ">IIH"
    reply_hdr_len = struct.calcsize(reply_fmt)

    try:
        data = (struct.pack(">I", 1) + b''.join(encode_len(s) for s in args) +
                struct.pack(">H", 65535))

        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.connect(sockpath)
            sock.sendall(data)

            hdr = readall(sock, reply_hdr_len)
            if len(hdr) < reply_hdr_len:
                # readall() returns short data when the peer closes early.
                # Unpacking that would raise struct.error, which is not an
                # OSError and would escape the handler below; report it as a
                # connection problem instead.
                raise ConnectionError(
                    'short reply from %s: got %d of %d header bytes'
                    % (sockpath, len(hdr), reply_hdr_len))

            # The first two header fields are not used here; only the body
            # length is needed to drain the rest of the reply.
            _tag, _status, leng = struct.unpack(reply_fmt, hdr)
            readall(sock, leng)
            sock.shutdown(socket.SHUT_RDWR)

    except OSError:
        traceback.print_exc()

def read_messages(rdr, path):
    """Drain pending journal entries and report SSH logins to the firewall.

    Repeatedly calls ``rdr.get_next()`` until it returns a falsy entry.  Each
    entry's ``MESSAGE`` text is matched against ``RX_FAIL`` and ``RX_ACCEPT``;
    on a match an ``sshkill badlogin`` or ``sshkill goodlogin`` command with
    the peer address and user name is sent via ``send_command``.

    Entries without a usable ``MESSAGE`` are skipped rather than aborting the
    monitor.

    Args:
        rdr: journal reader providing ``get_next()``.
        path: control socket path passed through to ``send_command``.
    """
    while True:
        msg = rdr.get_next()
        if not msg:
            return

        text = msg.get('MESSAGE')
        if isinstance(text, bytes):
            # NOTE(review): the journal reader appears to hand back raw bytes
            # when a field is not valid UTF-8 - confirm.  Decode leniently so
            # the str patterns can still be applied.
            text = text.decode('utf8', errors='replace')
        if not isinstance(text, str):
            # No message text in this entry; nothing to match.
            continue

        for pattern, event in ((RX_FAIL, 'badlogin'), (RX_ACCEPT, 'goodlogin')):
            m = pattern.match(text)
            if m:
                user, addr = m.groups()
                send_command(path, ('sshkill', event, addr, user))


def wait_messages(rdr, poll):
    """Block until the journal reader reports newly appended entries.

    Each round waits on *poll* and then lets *rdr* process the wakeup.  When
    the result is ``journal.APPEND`` the function returns.  When it is
    ``journal.INVALIDATE`` the reader is repositioned at the tail of the
    journal and waiting continues.  Any other result simply waits again.

    Args:
        rdr: journal reader providing ``process``, ``seek_tail`` and
            ``get_previous``.
        poll: poll object already registered for the reader's descriptor.
    """
    while True:
        poll.poll()
        outcome = rdr.process()

        if outcome == journal.APPEND:
            return

        if outcome == journal.INVALIDATE:
            # Re-anchor at the end of the journal before waiting again.
            rdr.seek_tail()
            rdr.get_previous()

def main():
    """Command-line entry point for the SSH log monitor.

    With ``-i/--install-service`` the script installs itself as a systemd
    unit through ``firewall.service.install_unit`` and exits.  Otherwise it
    follows the journal of the SSH unit, starting at the current tail, and
    forwards matching login events to the firewall control socket forever.

    Options:
        -i, --install-service [PATH]  install the unit (default directory
                                      /lib/systemd/system) and exit.
        -s, --controlsock PATH        firewall control socket
                                      (default /var/run/firewall-control).
        -u, --unit UNIT               systemd unit to watch (default
                                      ssh.service; some distributions name
                                      it sshd.service).
    """
    p = argparse.ArgumentParser(description='Watch the output of SSH and ban IP addresses that send bad passwords')

    p.add_argument('-i', '--install-service',
                   nargs='?', metavar='PATH', const='/lib/systemd/system',
                   help='Install as a systemd service')

    p.add_argument('-s', '--controlsock', dest='controlpath',
                   action='store', metavar='PATH',
                   default='/var/run/firewall-control',
                   help='admin control socket')

    # The unit name used to be hard-coded; the default keeps that behavior.
    p.add_argument('-u', '--unit', dest='unit',
                   action='store', metavar='UNIT',
                   default='ssh.service',
                   help='systemd unit whose journal entries are monitored')

    args = p.parse_args()

    if args.install_service:
        # Imported lazily so normal monitoring does not need this module.
        from firewall import service
        # NOTE(review): argument meanings are inferred from the values passed
        # (target directory, unit name, description, extra unit line, script
        # path, extra arguments) - confirm against firewall.service.
        service.install_unit(
            Path(args.install_service),
            'firewall_ssh_logmon',
            'Firewall SSH Log Monitor',
            'After=firewall.service',
            Path(__file__).resolve(),
            ''
        )
        return

    rdr = journal.Reader()
    rdr.add_match(_SYSTEMD_UNIT=args.unit)
    # Syslog facility 4 is the security/authorization ("auth") facility.
    rdr.add_match(SYSLOG_FACILITY=4)
    rdr.this_boot()
    # Start at the end of the journal so existing entries are not replayed.
    rdr.seek_tail()
    rdr.get_previous()

    poll = select.poll()
    poll.register(rdr.fileno(), rdr.get_events())

    # Tell systemd that start-up has finished.
    daemon.notify('READY=1')

    while True:
        read_messages(rdr, args.controlpath)
        wait_messages(rdr, poll)

# Run the monitor only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
