#!/usr/bin/env python
# encoding: utf-8

""" Submits the same job across multiple nodes
on a cluster with a shared filesystem, and no queuing,
handling killed jobs and crashes.
"""

import argparse
import multiprocessing as mp
import shlex
import subprocess as sp
import time
from os import getcwd
from sys import stdout
from traceback import print_exc

import psutil

from matador.utils.print_utils import print_notify, print_warning


class Hive:
    """ Implements the spawning of workers.

    One `multiprocessing.Process` is created per unique node, and each
    process launches `ssh <prefix><node> cd <cwd>; <command>`.

    Keyword arguments:
        command (str): the command to run on each node. Every occurrence
            of `$ALL_CORES` is replaced by the logical CPU count that
            psutil reports for the machine running this script.
        nodes (iterable): node identifiers that are appended to `prefix`
            to form hostnames. Duplicates are skipped with a warning and
            `None` is treated as "no nodes".
        prefix (str): hostname prefix (DEFAULT: "node").
        sleep (int or float): seconds to wait between consecutive
            submissions (DEFAULT: 1). An explicit `None` also selects
            the default.

    Any other keyword arguments are only stored on `self.args`.

    """

    def __init__(self, **kwargs):
        self.args = kwargs
        self.command = self.args.get("command")

        self.prefix = self.args.get("prefix")
        if self.prefix is None:
            self.prefix = "node"

        # dict.get() only falls back when the key is absent; callers such as
        # argparse pass an explicit None for unset options, which would
        # otherwise reach time.sleep() and raise a TypeError.
        self.sleep = self.args.get("sleep")
        if self.sleep is None:
            self.sleep = 1

        # Materialise once so one-shot iterables (e.g. generators) are not
        # exhausted by the duplicate check before the set is built.
        nodes = self.args.get("nodes")
        nodes = [] if nodes is None else list(nodes)
        self.nodes = set(nodes)
        if len(self.nodes) < len(nodes):
            print_warning("Skipping duplicate nodes...")

    def spawn(self):
        """ Spawn processes to perform calculations (borrowed from run3).

        Starts one process per node, pausing `self.sleep` seconds between
        consecutive starts.

        Raises:
            SystemExit: if a KeyboardInterrupt, SystemExit or RuntimeError
                occurs while the processes are being started; any process
                that is already alive is terminated first.

        """
        procs = [
            mp.Process(target=self.run_command, args=[node]) for node in self.nodes
        ]
        try:
            for index, proc in enumerate(procs):
                # stagger the submissions, without a pointless wait after
                # the final one
                if index > 0:
                    time.sleep(self.sleep)
                proc.start()
        except (KeyboardInterrupt, SystemExit, RuntimeError):
            print_exc()
            for proc in procs:
                # terminate() fails on a process that was never started,
                # so only touch the ones that are actually running
                if proc.is_alive():
                    proc.terminate()
            raise SystemExit("Killing running jobs and exiting...")

    def run_command(self, node):
        """ Parse command and run on particular node.

        Parameters:
            node: identifier appended to `self.prefix` to form the hostname
                handed to ssh.

        Returns:
            subprocess.Popen: handle of the launched ssh process. The
                process is not waited on here.

        """
        cwd = getcwd()
        # NOTE(review): this is the CPU count of the machine running this
        # script, not of the remote node - presumably the nodes are
        # identical; confirm.
        compute_command = self.command.replace(
            "$ALL_CORES", str(psutil.cpu_count(logical=True))
        )
        print_notify(
            "Executing {} on {}{}...".format(compute_command, self.prefix, node)
        )
        process = sp.Popen(
            [
                "ssh",
                "{}{}".format(self.prefix, node),
                "cd",
                # ssh joins its arguments into one string for the remote
                # shell, so the directory must be quoted to survive spaces
                # and shell metacharacters
                "{};".format(shlex.quote(cwd)),
                compute_command,
            ],
            stdout=stdout,
            shell=False,
        )
        return process


if __name__ == "__main__":
    # Command-line entry point: parse the options, build a Hive from them
    # and spawn one worker per requested node.
    PARSER = argparse.ArgumentParser(
        prog="oddjob",
        description="Run a single script on multiple nodes, numbered node<n>.",
        epilog="Written by Matthew Evans (2017).",
    )
    # A positional argument without nargs is always required, so it takes
    # no default.
    PARSER.add_argument(
        "command",
        type=str,
        help="command to run, with arguments - must be apostrophized. \
                        Use the $ALL_CORES macro to use, unsurprisingly, all the cores \
                        on the node in question. e.g. oddjob 'pyairss -c $ALL_CORES -ha \
                        <seed>' -n 1 16",
    )
    # The explicit default keeps None from being forwarded to Hive (and on
    # to time.sleep) when the flag is omitted.
    PARSER.add_argument(
        "-s",
        "--sleep",
        type=int,
        default=1,
        help="sleep for this many seconds in between submissions to nodes (DEFAULT: 1 s)",
    )
    PARSER.add_argument(
        "-n",
        "--nodes",
        type=str,
        nargs="+",
        required=True,
        help="list node numbers to run job on with space delimiters, e.g. \
                        -n 3 14 15",
    )
    PARSER.add_argument("-p", "--prefix", type=str, help="prefix for hostname")
    PARSER.add_argument("-d", "--debug", action="store_true", help="debug output")
    ARGS = PARSER.parse_args()
    RUNNER = Hive(
        command=ARGS.command,
        nodes=ARGS.nodes,
        debug=ARGS.debug,
        prefix=ARGS.prefix,
        sleep=ARGS.sleep,
    )
    try:
        RUNNER.spawn()
    except (KeyboardInterrupt, SystemExit):
        # raise SystemExit directly rather than relying on the site-provided
        # exit() helper, which is not guaranteed to exist
        raise SystemExit("Exiting top-level...")
