#! /usr/bin/env python3

"""
Start a GRN GPU Master, run the command with GRN monitoring, then terminate the GPU Master.
"""
import time
from argparse import ArgumentParser
import shlex
import subprocess
import logging

from rich.logging import RichHandler

import grn
from grn.utils.utils import find_gpu_master_address, is_server_available


# Module-level logger, named after this file's path.
log = logging.getLogger(__file__)
# Log records are rendered as the bare message text; run() passes this format
# to logging.basicConfig().
LOGGING_FORMAT = '%(message)s'


def _gpu_master_is_up() -> bool:
    """Return True when a GPU Master address is found and its server responds."""
    try:
        addr = find_gpu_master_address()
    except FileNotFoundError:
        # No address could be located, so treat the master as not running.
        return False
    return bool(is_server_available(addr))


def _wait_for_gpu_master(master_proc, poll_interval=0.5) -> None:
    """Block until the GPU Master launched as ``master_proc`` is reachable.

    Args:
        master_proc: The ``subprocess.Popen`` handle of the ``grn-start`` process.
        poll_interval: Seconds to sleep between availability checks.

    Raises:
        RuntimeError: If ``grn-start`` exits with a non-zero status before the
            server becomes available.
    """
    while not _gpu_master_is_up():
        # A failed launcher will never bring the server up; fail loudly instead
        # of polling forever. A zero/None status keeps waiting.
        returncode = master_proc.poll()
        if returncode:
            raise RuntimeError(
                f'grn-start exited with status {returncode} before the '
                'GPU Master became available.')
        # Sleep on every unsuccessful probe, not only when the address is
        # missing, so a slow-starting server does not cause a busy loop.
        time.sleep(poll_interval)


def run():
    """Run one or more commands under GRN, starting a GPU Master if needed.

    Command-line arguments:
        -d / --debug: launch the GPU Master (``grn-start``) with ``--debug``.
        cmds: one or more command strings. A single command is wrapped in a
            ``grn.job`` and executed; several commands are each delegated to
            their own ``grn-run`` subprocess and awaited.

    If this invocation had to start the GPU Master itself, the master is sent a
    terminate signal once the commands finish, including when they fail or the
    run is interrupted.

    Raises:
        RuntimeError: If the GPU Master launcher exits with a non-zero status
            while its start-up is being awaited.
    """
    parser = ArgumentParser()
    parser.add_argument('-d', '--debug', action='store_true',
        help='Flag to run the GRN GPU Master in debug mode.')
    parser.add_argument('cmds', type=str, nargs='+')
    args = parser.parse_args()

    # If the GPU Master isn't running, then launch it.
    master_proc = None
    if not _gpu_master_is_up():
        logging.basicConfig(
            level=logging.WARNING,
            format=LOGGING_FORMAT,
            handlers=[RichHandler(
                locals_max_string=None,
                tracebacks_word_wrap=False)])

        grn_start_cmd = ['grn-start']
        if args.debug:
            grn_start_cmd.append('--debug')
        master_proc = subprocess.Popen(grn_start_cmd)

    try:
        if master_proc is not None:
            # Wait for the freshly launched server to finish initializing.
            # TODO: Implement a timeout?
            _wait_for_gpu_master(master_proc)

        # If multiple commands are passed, recursively start a subprocess for each one.
        if len(args.cmds) > 1:
            procs = []
            for cmd in args.cmds:
                # The command is wrapped in literal doubled double-quotes; the
                # child's shlex.split() reduces the empty "" pairs to nothing.
                # NOTE(review): presumably this stops the child's argparse from
                # reading a command that starts with '-' as an option -- confirm.
                grn_run_cmd = ['grn-run', f'""{cmd}""']
                procs.append(subprocess.Popen(grn_run_cmd))
            for p in procs:
                p.wait()

        # Once we've reached the single command level, hook up the job to GRN and launch it.
        else:
            @grn.job(jobstr=args.cmds[0])
            def job() -> None:
                subprocess.Popen(shlex.split(args.cmds[0])).wait()
            job()
    finally:
        if master_proc is not None:
            # Only the invocation that started the GPU Master reaches this with
            # a live handle. Signal the master to clean up and terminate, even
            # if the job raised or the run was interrupted.
            master_proc.terminate()

# Invoke the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    run()