#!/usr/bin/env python
import os
import psutil
import subprocess
import signal
import sys
import time

# Seconds to sleep between successive checks on the child and on our parent.
poll_timeout = 1
# Seconds to wait after terminate() before force-killing leftover processes.
death_timeout = 10

# On Linux, prctl(PR_SET_PDEATHSIG) is also set on the child so that the
# kernel kills it if this wrapper itself dies.
def main():
    """Run the subcommand from ``sys.argv[1:]`` and tear it down if our parent dies.

    The subcommand is started as a child process and polled every
    ``poll_timeout`` seconds.  If it exits on its own, its exit status is
    returned.  If this wrapper's parent process changes (i.e. the original
    parent died and we were re-parented), the child is sent a terminate
    request, given up to ``death_timeout`` seconds to exit, and then every
    process that was a descendant of this wrapper is force-killed.

    On SIGINT the child is killed and the wrapper exits immediately with
    status 0.

    Returns:
        1 if no subcommand was given on the command line, otherwise the
        return code of the subcommand.
    """
    if len(sys.argv) <= 1:
        sys.stderr.write(f"Usage: {sys.argv[0]} <subcommand to run>\n")
        return 1

    if sys.platform == 'linux':
        import ctypes

        def preexec_fn():
            # Runs in the child between fork and exec: ask the kernel to
            # deliver SIGKILL to the child when this wrapper dies.
            PR_SET_PDEATHSIG = 1
            ctypes.cdll["libc.so.6"].prctl(PR_SET_PDEATHSIG, int(signal.SIGKILL))
    else:
        # No hook needed; None (unlike a no-op callable) is also accepted by
        # Popen on platforms that do not support preexec_fn.
        preexec_fn = None

    # Remember who launched us.  Comparing against this value (rather than
    # against PID 1) also detects re-parenting to a sub-reaper, and does not
    # misfire when the wrapper is legitimately started by PID 1.
    original_ppid = os.getppid()

    child = subprocess.Popen(sys.argv[1:], close_fds=False, preexec_fn=preexec_fn)

    def sigint(sig, frame):
        # Interactive interrupt: take the child down and leave at once,
        # skipping normal interpreter shutdown.
        child.kill()
        child.wait()
        os._exit(0)

    signal.signal(signal.SIGINT, sigint)

    # Let the process die naturally
    while True:
        if child.poll() is not None:
            return child.returncode
        if os.getppid() != original_ppid:
            break
        time.sleep(poll_timeout)

    # Our parent is gone.  Snapshot every descendant *before* signalling the
    # child, while the whole tree is still reachable from this process.
    proc = psutil.Process()
    descendants = proc.children(recursive=True)

    # Send TERM to our direct subprocess
    child.terminate()

    # Give it death_timeout seconds to exit on its own.  wait() sleeps
    # between checks instead of spinning on poll().
    try:
        child.wait(timeout=death_timeout)
    except subprocess.TimeoutExpired:
        # Still alive: it is force-killed below along with the other
        # descendants (the snapshot includes the direct child).
        pass

    # Now kill any leftover processes
    for desc in descendants:
        try:
            desc.kill()
        except psutil.NoSuchProcess:
            # Already gone, which is the outcome we wanted.
            pass

    child.wait()
    return child.returncode


# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
