#!/usr/bin/env python
import os
import psutil
import subprocess
import signal
import sys
import time

# Seconds to sleep between successive checks of the child's status and of
# our own parent pid while waiting for the child to finish.
poll_timeout = 1
# Seconds the child is given to exit after being sent TERM (once we have
# been orphaned) before any leftover processes are hard-killed.
death_timeout = 10

def close_fds(exclude=frozenset({0, 1, 2})):
    """Close every file descriptor of this process except those in *exclude*.

    On Linux the open descriptors are enumerated from ``/proc/self/fd``.
    On other platforms, or when that directory cannot be read, descriptors
    0-1023 are tried as a best-effort guess.  Failures to close an individual
    descriptor (for example because it is not open) are ignored.

    Args:
        exclude: Container of descriptor numbers to leave open.  Defaults to
            0, 1 and 2 (standard input, output and error).
    """
    fds = None
    if sys.platform == 'linux':
        try:
            fds = [int(fd) for fd in os.listdir("/proc/self/fd")]
        except OSError:
            # /proc may be missing or unreadable (e.g. inside a chroot);
            # fall back to the generic guess below instead of crashing.
            fds = None
    if fds is None:
        # Best effort guess
        fds = range(1024)

    for fd in fds:
        if fd in exclude:
            continue
        try:
            os.close(fd)
        except OSError:
            # Deliberately best-effort: the descriptor may already be closed
            # (the listing above includes the descriptor used to read it).
            pass

# On Linux, could also use prctl to get a notification on parent death
def main():
    """Run a subcommand and tear it down if this wrapper's parent goes away.

    Command line: ``<script> [--term] <subcommand> [args...]``.

    The subcommand is started as a child process.  While it runs, the wrapper
    polls every ``poll_timeout`` seconds.  If the child exits, its return code
    is returned.  If the wrapper's parent pid becomes 1, the child is sent
    TERM, given up to ``death_timeout`` seconds to exit, and then every
    descendant of the wrapper is hard-killed.

    With ``--term`` the child is terminated (SIGTERM) rather than killed
    (SIGKILL) on SIGINT and, on Linux, on death of the wrapper itself.

    Returns:
        The child's return code, or 1 if no subcommand was given.
    """
    # Parse into locals rather than shifting sys.argv, so that the usage
    # message below always reports the real program name.
    command = sys.argv[1:]
    # Guarding on a non-empty list avoids the IndexError that indexing
    # sys.argv[1] unconditionally raises when no arguments are given.
    term = bool(command) and command[0] == "--term"
    if term:
        command = command[1:]

    def terminate(proc):
        if term:
            proc.terminate()
        else:
            proc.kill()
    sig = signal.SIGTERM if term else signal.SIGKILL

    if not command:
        sys.stderr.write(f"Usage: {sys.argv[0]} <subcommand to run>\n")
        return 1

    if sys.platform == 'linux':
        import ctypes

        def preexec_fn():
            # Runs in the child before exec: ask the kernel to deliver `sig`
            # to the child when its parent (this wrapper) dies.
            PR_SET_PDEATHSIG = 1
            ctypes.cdll["libc.so.6"].prctl(PR_SET_PDEATHSIG, int(sig))
    else:
        def preexec_fn():
            pass

    # The child inherits all of our descriptors (close_fds=False); afterwards
    # the wrapper drops its own copies of everything but stdin/stdout/stderr.
    child = subprocess.Popen(command, close_fds=False, preexec_fn=preexec_fn)
    close_fds()

    def sigint(signum, frame):
        # On Ctrl-C: stop the child, reap it, and exit immediately without
        # running any Python-level cleanup.
        terminate(child)
        child.wait()
        os._exit(0)

    signal.signal(signal.SIGINT, sigint)

    # Let the process die naturally
    while True:
        if child.poll() is not None:
            return child.returncode
        # NOTE(review): a parent pid of 1 is taken to mean we were orphaned;
        # this will not trigger when a subreaper other than pid 1 adopts us.
        elif os.getppid() == 1:
            break
        time.sleep(poll_timeout)

    # Try to kill (and also carefully kill any descendents).  The snapshot is
    # taken before signalling so the child's own children are still reachable.
    proc = psutil.Process()
    descendents = proc.children(recursive=True)

    # Send TERM to our direct subprocess, giving it a chance to
    # cleanly exit.
    child.terminate()

    # Wait for it to die on its own, or time out and kill it directly.
    # A blocking wait with a timeout replaces a sleepless poll loop that
    # would otherwise spin a CPU core for up to death_timeout seconds.
    try:
        child.wait(timeout=death_timeout)
    except subprocess.TimeoutExpired:
        pass

    # Now hard-kill any leftover processes
    for desc in descendents:
        try:
            desc.kill()
        except psutil.NoSuchProcess:
            pass

    child.wait()
    return child.returncode


# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
