#!/usr/bin/env python
# encoding: utf-8
""" Submits the same job across multiple nodes
on a cluster with a shared filesystem, and no queuing,
handling killed jobs and crashes.
"""

import argparse
import multiprocessing as mp
import shlex
import subprocess as sp
from os import getcwd
from sys import stdout
from traceback import print_exc

import psutil

from matador.utils.print_utils import print_notify, print_warning


class Hive:
    """ Implements the spawning of workers.

    One local worker process is created per unique node; each worker
    launches the command on its node over ssh, after changing into the
    directory this script was started from.

    Keyword arguments:
        command (str): command line to run on every node. Every occurrence
            of ``$ALL_CORES`` is replaced by a logical core count.
        nodes (iterable of str): node identifiers, each appended to
            ``prefix`` to form a hostname. Duplicates are skipped with a
            warning; a missing or ``None`` value means "no nodes".
        prefix (str): hostname prefix, defaults to ``'node'``.

    All keyword arguments (including any not listed above) are kept
    verbatim in ``self.args``.
    """
    def __init__(self, **kwargs):
        self.args = kwargs
        self.command = self.args.get('command')
        self.prefix = self.args.get('prefix')
        if self.prefix is None:
            self.prefix = 'node'
        # Materialise the nodes exactly once: this tolerates a missing/None
        # value and one-shot iterables (e.g. generators), which would
        # otherwise be exhausted by the duplicate check below.
        nodes = self.args.get('nodes')
        nodes = [] if nodes is None else list(nodes)
        self.nodes = set(nodes)
        if len(self.nodes) < len(nodes):
            print_warning('Skipping duplicate nodes...')

    def spawn(self):
        """ Spawn processes to perform calculations (borrowed from run3).

        Starts one ``multiprocessing.Process`` per node, each running
        ``run_command``. If starting is interrupted or fails, the traceback
        is printed, every worker that is already running is terminated and
        ``SystemExit`` is raised.

        Raises:
            SystemExit: if KeyboardInterrupt, SystemExit or RuntimeError
                occurs while the workers are being started.
        """
        procs = [mp.Process(target=self.run_command, args=[node])
                 for node in self.nodes]
        try:
            for proc in procs:
                proc.start()
        except (KeyboardInterrupt, SystemExit, RuntimeError):
            print_exc()
            for proc in procs:
                # terminate() raises on a process that was never started,
                # which would mask the exit below; only stop live workers.
                if proc.is_alive():
                    proc.terminate()
            # Raise directly rather than calling the site-provided exit()
            # helper, which is not guaranteed to exist (e.g. python -S).
            raise SystemExit('Killing running jobs and exiting...')

    def run_command(self, node):
        """ Parse command and run on particular node.

        Parameters:
            node (str): node identifier, appended to ``self.prefix`` to
                form the ssh hostname.

        Returns:
            subprocess.Popen: handle of the launched ssh process. It is not
                waited on here, so this method returns as soon as ssh has
                been started.
        """
        cwd = getcwd()
        # NOTE(review): the core count is taken on the machine running this
        # script, not on the remote node, so $ALL_CORES only matches the
        # node's core count if the nodes are identical to the local machine.
        compute_command = self.command.replace('$ALL_CORES', str(psutil.cpu_count(logical=True)))
        print_notify('Executing {} on {}{}...'.format(compute_command, self.prefix, node))
        # ssh joins its trailing arguments with spaces and hands the result
        # to the remote shell, so the directory must be shell-quoted to
        # survive spaces or metacharacters in the path.
        process = sp.Popen(['ssh', '{}{}'.format(self.prefix, node),
                            'cd', '{};'.format(shlex.quote(cwd)),
                            compute_command], stdout=stdout, shell=False)
        return process


if __name__ == '__main__':
    # Command-line entry point: parse the command, the node list and the
    # optional hostname prefix, then start one worker per node.
    PARSER = argparse.ArgumentParser(
        prog='oddjob',
        description='Run a single script on multiple nodes, numbered node<n>.',
        epilog='Written by Matthew Evans (2017).')
    PARSER.add_argument('command', type=str,
                        help='command to run, with arguments - must be apostrophized. \
                        Use the $ALL_CORES macro to use, unsurprisingly, all the cores \
                        on the node in question. e.g. oddjob \'pyairss -c $ALL_CORES -ha \
                        <seed>\' -n 1 16')
    PARSER.add_argument('-n', '--nodes', type=str, nargs='+', required=True,
                        help='list node numbers to run job on with space delimiters, e.g. \
                        -n 3 14 15')
    PARSER.add_argument('-p', '--prefix', type=str,
                        help='prefix for hostname')
    PARSER.add_argument('-d', '--debug', action='store_true',
                        help='debug output')
    ARGS = PARSER.parse_args()
    RUNNER = Hive(command=ARGS.command, nodes=ARGS.nodes, debug=ARGS.debug, prefix=ARGS.prefix)
    try:
        RUNNER.spawn()
    except (KeyboardInterrupt, SystemExit):
        # Raise SystemExit directly instead of calling the site-provided
        # exit() helper, which is intended for interactive use and is not
        # available when the site module is disabled.
        raise SystemExit('Exiting top-level...')
