#!/usr/bin/env python
"""
Copyright (C) 2019  Universite catholique de Louvain, Belgium.

This file is part of CP3SlurmUtils.

CP3SlurmUtils is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

CP3SlurmUtils is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with CP3SlurmUtils.  If not, see <http://www.gnu.org/licenses/>.
"""

import copy
import os
import re
import signal
import subprocess
import sys

from optparse import OptionParser

import CP3SlurmUtils.Py2Py3Helpers
from CP3SlurmUtils import __version__ as version
from CP3SlurmUtils.InteractiveQuestion import interactiveYesNoQuestion
from CP3SlurmUtils.InteractiveQuestion import questionTimeoutHandler
from CP3SlurmUtils.InteractiveQuestion import TimeoutException
from CP3SlurmUtils.SubmitUtils import submit


#================================================================================
# 1. Define/read command line options and arguments
#================================================================================

# Command line interface of the script.
parser = OptionParser(
    usage="Usage: %prog [options]",
    version="%prog {}".format(version),
    description="""Script to resubmit failed jobs on a SLURM cluster that were submitted,
with the slurm_submit script.
The script takes as input a list of slurm job ids, parses the ~/.slurm_jobs file
searching for the input job ids in order to map existing batch scripts to input job ids,
and submits new jobs eventually creating new batch scripts.""",
    add_help_option=True
)

# One entry per option: (flags, keyword arguments for add_option). The options
# are registered in table order, which is also the order used by --help.
_optionTable = (
    (("-d", "--debug"),
     {"action": "store_true", "dest": "debug", "default": False,
      "help": "print all messages, be verbose"}),
    (("-q", "--quiet"),
     {"action": "store_true", "dest": "quiet", "default": False,
      "help": "print only most relevant messages"}),
    (("-j", "--jobs"),
     {"type": "string", "dest": "jobs", "default": "",
      "help": "comma separated list of slurm job ids to resubmit"}),
    (("-t", "--states"),
     {"type": "string", "dest": "states",
      "default": "failed,boot_fail,node_fail,deadline,preempted,timeout",
      "help": "comma separated list of job states:exitcodes for resubmission"}),
    (("--samenodes",),
     {"action": "store_true", "dest": "samenodes", "default": False,
      "help": "resubmit jobs on same nodes"}),
    (("--nodeswhitelist",),
     {"type": "string", "dest": "nodeswhitelist", "default": "",
      "help": "comma separated list of nodes to allow in the allocation"}),
    (("--nodesblacklist",),
     {"type": "string", "dest": "nodesblacklist", "default": "",
      "help": "comma separated list of nodes to exclude from the allocation"}),
    (("-s", "--submit"),
     {"action": "store_true", "dest": "submit", "default": False,
      "help": "submit the new jobs"}),
    (("-y", "--yes"),
     {"action": "store_true", "dest": "yes", "default": False,
      "help": "skip all interactive questions positively"}),
)

for _flags, _settings in _optionTable:
    parser.add_option(*_flags, **_settings)

opts, args = parser.parse_args()

# parser.error() prints the usage together with the message and terminates
# the script by itself, so no explicit sys.exit() is needed after it.
if not opts.jobs:
    parser.error("option --jobs (-j) is mandatory")

if opts.debug and opts.quiet:
    parser.error("options --debug (-d) and --quiet (-q) are mutually exclusive")

if args:
    parser.error("command takes no arguments")


def _exitInvalidJobsValue(value):
    """Report an invalid item of the --jobs option and exit with status 1."""
    print("ERROR: Option --jobs (-j) has an invalid value: '%s'." % (value))
    sys.exit(1)


# Accepted forms of one --jobs item: <job>, <job>_<task>, <job>-<job> and
# <job>_<task>-<job>_<task>. Raw string, so the regex is passed on untouched.
jobIdRegex = re.compile(r'^[1-9][0-9]*(_[1-9][0-9]*)?(-[1-9][0-9]*(_[1-9][0-9]*)?)?$')

# Unique items of the --jobs option, in the order given by the user.
jobIds = []
for v in opts.jobs.split(','):
    if v == '':
        continue
    if not jobIdRegex.match(v):
        _exitInvalidJobsValue(v)
    if v not in jobIds:
        jobIds.append(v)

if not jobIds:
    if not opts.quiet:
        print("No (valid) job ids to resubmit were given.")
    sys.exit(0)

# Same items with every range expanded into its individual job ids / tasks.
jobIdsFull = []

for v in jobIds:
    if '-' not in v:
        jobIdsFull.append(v)
        continue
    jobId1, jobId2 = v.split('-')
    # Both ends of a range must be of the same kind (plain job or array task).
    if ('_' in jobId1) != ('_' in jobId2):
        _exitInvalidJobsValue(v)
    if '_' in jobId1:
        # Range of tasks: both ends must belong to the same job array.
        jobId1, taskId1 = jobId1.split('_')
        jobId2, taskId2 = jobId2.split('_')
        if jobId1 != jobId2:
            _exitInvalidJobsValue(v)
        if int(taskId2) < int(taskId1):
            _exitInvalidJobsValue(v)
        jobIdsFull += [jobId1 + '_' + str(i) for i in range(int(taskId1), int(taskId2)+1)]
    else:
        if int(jobId2) < int(jobId1):
            _exitInvalidJobsValue(v)
        jobIdsFull += [str(i) for i in range(int(jobId1), int(jobId2)+1)]

# From here on jobIds holds the unique job ids without the task part, sorted
# numerically; the requested tasks stay available in jobIdsFull.
jobIds = sorted(set(v.split('_')[0] for v in jobIdsFull), key=int)

slurmJobStates = ['BOOT_FAIL', 'CANCELLED', 'COMPLETED', 'CONFIGURING', 'COMPLETING', 'DEADLINE', 'FAILED', 'NODE_FAIL', 'PENDING', 'PREEMPTED', 'RUNNING', 'RESIZING', 'SUSPENDED', 'TIMEOUT']
# Each --states item is <state> or <state>:<exitcode>.
for state in opts.states.split(','):
    if state.split(':')[0].upper() not in slurmJobStates:
        print("ERROR: Option --states (-t) has an invalid value: '%s'." % (state) + \
              " Valid slurm job states are (case-insensitive): %s" % (slurmJobStates))
        sys.exit(1)
    if ':' in state:
        if not re.match(r'^[0-9]{1,3}$', state.split(':')[1]):
            print("ERROR: Option --states (-t) has an invalid exit code value: '%s'." % (state) + \
                  " Valid exit codes are 1-, 2- or 3-digit integer numbers.")
            sys.exit(1)
resubmitJobStates = [state.upper() for state in opts.states.split(',')]

if opts.nodeswhitelist and opts.nodesblacklist:
    print("ERROR: Options --nodeswhitelist and --nodesblacklist can not be used at the same time.")
    sys.exit(1)

if opts.samenodes and opts.nodesblacklist:
    print("ERROR: Options --samenodes and --nodesblacklist can not be used at the same time.")
    sys.exit(1)

if opts.nodeswhitelist:
    if not re.match(r'^mb-[a-z]{3}[0-9]{3}(,mb-[a-z]{3}[0-9]{3})*$', opts.nodeswhitelist):
        print("ERROR: Option --nodeswhitelist has an invalid value." + \
              " It should be a comma separated list of node names. E.g.: mb-har025,mb-har026,mb-sab034")
        sys.exit(1)

if opts.nodesblacklist:
    # A leading '+' is accepted in front of the node list.
    if not re.match(r'^\+?mb-[a-z]{3}[0-9]{3}(,mb-[a-z]{3}[0-9]{3})*$', opts.nodesblacklist):
        print("ERROR: Option --nodesblacklist has an invalid value." + \
              " It should be a comma separated list of node names. E.g.: mb-har025,mb-har026,mb-sab034")
        sys.exit(1)

#================================================================================
# 2. Parse the user's slurm_jobs file to find the batch script for each job id
#================================================================================

if opts.debug:
    print("Parsing user's slurm_jobs file.")

slurmJobsFile = os.path.expanduser('~/.slurm_jobs')

if not opts.quiet:
    print("Will search the job ids in slurm_jobs file %s" % (slurmJobsFile))

# Job id -> information needed for the resubmission.
resubmitJobs = {}

# Job ids that have no entry in the slurm_jobs file.
notFoundJobIds = []

# Job ids whose batch script does not exist (anymore) on disk.
notFoundBatchScripts = []

partitionRegex = re.compile(r"^#SBATCH +(--partition(=| +)|-p +)([a-z,A-Z,0-9,\,]+)")
qosRegex = re.compile(r"^#SBATCH +--qos(=| +)([a-z,A-Z,0-9,\,]+)")

for jobId in jobIds:
    jobInfo = {}
    found = False
    try:
        with open(slurmJobsFile, 'r') as fd:
            for line in fd:
                # Columns used: [1] job id, [2] batch script, [3] 'jobarray=T|F'.
                fields = line.split()
                if fields[1] == jobId:
                    batchScript = fields[2]
                    isJobArray = fields[3].split('jobarray=')[1] == 'T'
                    found = True
                    break
    except Exception as ex:
        msg = "WARNING: Failed to parse slurm_jobs file."
        msg += "\nError follows:"
        msg += "\n%s" % (str(ex))
        print(msg)
        sys.exit(1)
    if not found:
        notFoundJobIds.append(jobId)
        continue
    if not os.path.isfile(batchScript):
        notFoundBatchScripts.append(jobId)
        continue
    jobInfo['batch_script'] = batchScript
    jobInfo['is_job_array'] = isJobArray
    if isJobArray:
        # Task ids explicitly requested by the user for this job array
        # (stays empty when only the bare job id was given).
        jobInfo['jobs_to_resubmit'] = []
        for v in jobIdsFull:
            parts = v.split('_')
            if parts[0] == jobId and len(parts) > 1:
                if parts[1] not in jobInfo['jobs_to_resubmit']:
                    jobInfo['jobs_to_resubmit'].append(parts[1])
    # The first --partition/-p and --qos directives of the batch script win.
    partition, qos = None, None
    try:
        with open(batchScript, 'r') as fd:
            for line in fd:
                if not partition:
                    m = partitionRegex.match(line)
                    if m:
                        partition = m.group(3)
                if not qos:
                    m = qosRegex.match(line)
                    if m:
                        qos = m.group(2)
                if partition and qos:
                    break
    except IOError as ex:
        msg = "ERROR: Failed to read batch script file %s" % (batchScript)
        msg += "\nError follows:"
        msg += "\n%s" % (str(ex))
        print(msg)
        sys.exit(1)
    if not partition:
        print("ERROR: Did not find sbatch --partition/-p option in batch script %s" % (batchScript))
        sys.exit(1)
    if not qos:
        print("ERROR: Did not find sbatch --qos option in batch script %s" % (batchScript))
        sys.exit(1)
    jobInfo['partition'] = partition
    jobInfo['qos'] = qos
    resubmitJobs[jobId] = jobInfo

if notFoundJobIds:
    print("WARNING: The following job ids will not be resubmitted, because they were not found in the slurm_jobs file: %s" % (notFoundJobIds))
    for jobId in notFoundJobIds:
        jobIds.remove(jobId)

if notFoundBatchScripts:
    print("WARNING: The following job ids will not be resubmitted, because their corresponding batch scripts" + \
          " were not found in the location reported by the slurm_jobs file: %s" % (notFoundBatchScripts))
    for jobId in notFoundBatchScripts:
        jobIds.remove(jobId)

if not resubmitJobs:
    print("Nothing to resubmit.")
    sys.exit(0)

if opts.debug:
    print("User's slurm_jobs file parsed.")

#================================================================================
# 3. Get failed job ids (resubmit only failed jobs)
#================================================================================

if opts.debug:
    print("Getting list of failed jobs.")


# Iterate over a copy, because job ids with nothing to resubmit are removed
# from jobIds inside the loop.
for jobId in list(jobIds):
    failedJobs = []
    resubmitJobs[jobId]['nodes'] = []
    # universal_newlines=True makes communicate() return text instead of bytes
    # on Python 3 as well, so the string parsing below works on both versions.
    process = subprocess.Popen(['sacct', '-j', jobId, '-o', 'jobid,state,exitcode,nodelist', '-p'],
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                               universal_newlines=True)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        print("ERROR: sacct for job %s failed." % (jobId))
        if stdout:
            print('Stdout:\n    %s' % (str(stdout).replace('\n', '\n    ')))
        if stderr:
            print('Stderr:\n    %s' % (str(stderr).replace('\n', '\n    ')))
        sys.exit(1)
    # Drop the header line and the empty string after the last newline.
    sacctLines = stdout.split('\n')[1:-1]
    for line in sacctLines:
        fields = line.strip().split('|')
        if fields[0].endswith('.batch'):
            continue
        jobIdNoStep = fields[0].split('.')[0]
        # sacct can append details to the state (e.g. "CANCELLED by <uid>"),
        # so only the first word is compared with the requested states.
        jobState    = fields[1].split(' ')[0]
        exitCode    = fields[2].split(':')[0]
        nodes       = fields[3]
        for state in resubmitJobStates:
            stateParts = state.split(':')
            if jobState != stateParts[0]:
                continue
            # A state given without exit code matches any exit code.
            if len(stateParts) == 1 or exitCode == stateParts[1]:
                failedJobs.append(jobIdNoStep)
                resubmitJobs[jobId]['nodes'].append(nodes)
    failedJobs = list(set(failedJobs))
    resubmitJobs[jobId]['nodes'] = ','.join(resubmitJobs[jobId]['nodes'])
    if not failedJobs:
        if not opts.quiet:
            print("Job id %s has 0 jobs in a state for resubmission." % (jobId))
        jobIds.remove(jobId)
        del resubmitJobs[jobId]
        continue
    # sacct reports the tasks of a job array as <job id>_<task id>.
    isJobArray = '_' in failedJobs[0]
    if isJobArray != resubmitJobs[jobId]['is_job_array']:
        print("ERROR: Not clear whether job id %s corresponds to a job array or not." % (jobId))
        sys.exit(1)
    if isJobArray:
        failedJobs = [str(t) for t in sorted(int(j.split('_')[1]) for j in failedJobs)]
        resubmitJobs[jobId]['failed_jobs'] = failedJobs
        # All task ids reported by sacct for this job array, failed or not.
        jobArrayTaskIds = list(set([l.strip().split('|')[0].split('.')[0].split('_')[1] for l in sacctLines]))
        resubmitJobs[jobId]['task_ids'] = jobArrayTaskIds

if not resubmitJobs:
    print("Nothing to resubmit.")
    # NOTE(review): the other "Nothing to resubmit." exits of this script use
    # status 0; status 1 is kept here to not change the script's exit codes.
    sys.exit(1)

if opts.debug:
    print("List of failed jobs retrieved.")

#================================================================================
# 4. Create a new batch script (resubmit only failed jobs)
#================================================================================

if opts.debug:
    print("Creating new batch scripts for job arrays.")

nodeOptionsGiven = bool(opts.samenodes or opts.nodesblacklist or opts.nodeswhitelist)

# The --exclude list already present in a batch script is carried over only
# when the user asked to extend it ('+' prefix of --nodesblacklist) or when
# --samenodes is used without a whitelist; otherwise it is dropped.
keepExistingExclude = bool((opts.nodesblacklist and opts.nodesblacklist[0] == "+") or
                           (opts.samenodes and not opts.nodeswhitelist))

# Iterate over a copy, because job arrays left without jobs to resubmit are
# removed from jobIds inside the loop.
for jobId in list(jobIds):
    jobInfo = resubmitJobs[jobId]
    if jobInfo['is_job_array']:
        invalidTaskIds = []
        notFailedJobs = []
        if jobInfo['jobs_to_resubmit']:
            # The user selected specific tasks: keep only those that exist
            # in the job array and are in a state for resubmission.
            validJobs = []
            for j in jobInfo['jobs_to_resubmit']:
                if j not in jobInfo['task_ids']:
                    invalidTaskIds.append(j)
                elif j not in jobInfo['failed_jobs']:
                    notFailedJobs.append(j)
                else:
                    validJobs.append(j)
            jobInfo['jobs_to_resubmit'] = validJobs
        else:
            # No specific tasks selected: resubmit every failed task.
            jobInfo['jobs_to_resubmit'] = list(jobInfo['failed_jobs'])
        if invalidTaskIds and not opts.quiet:
            print("WARNING: Job array %s does not contain the following jobs, so they can not be resubmitted: %s" % (jobId, invalidTaskIds))
        if notFailedJobs and not opts.quiet:
            print("WARNING: The following jobs in job array %s are not in a failed state, so they won't be resubmitted: %s" % (jobId, notFailedJobs))
        if not jobInfo['jobs_to_resubmit']:
            jobIds.remove(jobId)
            del resubmitJobs[jobId]
            continue
    # Name of the new script: <name>_res<N><ext>, N being incremented when the
    # original script is itself a resubmission script.
    origBatchScript = jobInfo['batch_script']
    origBatchScriptBasename, origBatchScriptExt = os.path.splitext(origBatchScript)
    if re.match(r'^.*_res[1-9][0-9]*$', origBatchScriptBasename):
        resubmissionIndex = str(int(origBatchScriptBasename.rsplit('_res', 1)[1]) + 1)
        newBatchScript = origBatchScriptBasename.rsplit('_res', 1)[0] + '_res' + resubmissionIndex + origBatchScriptExt
    else:
        newBatchScript = origBatchScriptBasename + '_res1' + origBatchScriptExt
    newBatchScriptText = ""
    # Must be (re)initialised for every job: a new script is written only if
    # at least one line of the original script was replaced, dropped or added.
    batchScriptChanged = False
    # found[0]: --array directive replaced
    # found[1]: --exclude directive for --samenodes/--nodeswhitelist written
    # found[2]: --exclude directive for --nodesblacklist written
    found = [False]*3
    # Node list of the original --exclude directive. It is read inside the
    # #SBATCH header but used only once the header is over, so it has to
    # survive from one line to the next.
    blacklistedNodes = ""
    try:
        with open(origBatchScript, 'r') as fd:
            # sbatch[0]: an #SBATCH line has been seen
            # sbatch[1]: a non-#SBATCH line has been seen after it, i.e. the
            #            first run of #SBATCH lines (the header) is over
            sbatch = [False, False]
            for line in fd:
                if re.match(r"^#SBATCH +", line):
                    sbatch[0] = True
                else:
                    sbatch[1] = sbatch[0]
                inHeader = sbatch[0] and not sbatch[1]
                skipLine = False
                newLine = ""
                if jobInfo['is_job_array'] and re.match(r"^#SBATCH +(--array(=| +)|-a +)", line):
                    if inHeader:
                        # Replace the array specification by the list of tasks
                        # to resubmit (written once, further ones are dropped).
                        skipLine = True
                        if not found[0]:
                            newLine = "#SBATCH --array=%s\n" % (','.join(jobInfo['jobs_to_resubmit']))
                        found[0] = True
                if nodeOptionsGiven:
                    if re.match(r"^#SBATCH +(--exclude(=| +)|-x +)", line):
                        if inHeader:
                            # The original directive is always dropped; a new
                            # one is written right after the header.
                            skipLine = True
                            if keepExistingExclude:
                                nodesStart = line.find('mb-')
                                # strip() removes the end of line, which must
                                # not end up inside the new directive.
                                blacklistedNodes = line[nodesStart:].strip() if nodesStart >= 0 else ""
                    if opts.samenodes or opts.nodeswhitelist:
                        if sbatch[1] and not found[1]:
                            excludeNodes = []
                            # universal_newlines=True: text output on Python 3 too.
                            process = subprocess.Popen(['sinfo', '--Node', '--noheader', '--format=%n'],
                                                       stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                                       universal_newlines=True)
                            stdout, stderr = process.communicate()
                            if process.returncode == 0 and stdout:
                                allNodes = stdout.rstrip('\n').split('\n')
                                for node in allNodes:
                                    # Substring tests against comma separated strings.
                                    # NOTE(review): only the first 6 characters of the node
                                    # name are compared for --samenodes; confirm this is meant.
                                    if (opts.samenodes and node[:6] not in resubmitJobs[jobId]['nodes']) or (opts.nodeswhitelist and node not in opts.nodeswhitelist):
                                        excludeNodes.append(node)
                                # blacklistedNodes can be non-empty only with
                                # --samenodes and no whitelist (see keepExistingExclude).
                                if blacklistedNodes and set(blacklistedNodes.split(',')) != set(excludeNodes):
                                    excludeItems = [blacklistedNodes] + excludeNodes
                                else:
                                    excludeItems = excludeNodes
                                if excludeItems:
                                    newBatchScriptText += "#SBATCH --exclude=%s\n" % (','.join(excludeItems))
                                    batchScriptChanged = True
                            else:
                                print("WARNING: Failed to get the list of nodes from sinfo;" + \
                                      " no node restriction was added to the batch script for job id %s." % (jobId))
                            found[1] = True
                    if opts.nodesblacklist:
                        if sbatch[1] and not found[2]:
                            # The '+' prefix only means "extend the existing
                            # list"; it is not part of the node list itself.
                            newBlacklist = opts.nodesblacklist.lstrip('+')
                            if blacklistedNodes:
                                newBatchScriptText += "#SBATCH --exclude=%s,%s\n" % (blacklistedNodes, newBlacklist)
                            else:
                                newBatchScriptText += "#SBATCH --exclude=%s\n" % (newBlacklist)
                            batchScriptChanged = True
                            found[2] = True
                if not skipLine:
                    newBatchScriptText += line
                if newLine:
                    newBatchScriptText += newLine
                if skipLine or newLine:
                    batchScriptChanged = True
    except IOError as ex:
        msg = "ERROR: Failed to read batch script file %s" % (origBatchScript)
        msg += "\nError follows:"
        msg += "\n%s" % (str(ex))
        print(msg)
        sys.exit(1)
    if jobInfo['is_job_array'] and not found[0]:
        print("ERROR: Did not find sbatch --array option in batch script %s" % (origBatchScript))
        sys.exit(1)
    # None means: nothing changed, the original batch script is resubmitted.
    resubmitJobs[jobId]['new_batch_script'] = None
    if batchScriptChanged:
        try:
            with open(newBatchScript, 'w') as fd:
                fd.write(newBatchScriptText)
        except IOError as ex:
            msg = "ERROR: Failed to write new batch script file %s" % (newBatchScript)
            msg += "\nError follows:"
            msg += "\n%s" % (str(ex))
            print(msg)
            sys.exit(1)
        resubmitJobs[jobId]['new_batch_script'] = newBatchScript

if not resubmitJobs:
    print("Nothing to resubmit.")
    sys.exit(0)

if opts.debug:
    print("Batch script(s) created.")

#================================================================================
# 5. Print a summary of what will be submitted
#================================================================================

if opts.debug:
    print("Creating summary text.")

# The summary is collected piece by piece and glued together at the end;
# every piece carries its own leading newline where one is needed.
summaryParts = [
    "================================================================================",
    "\n=============  Below there is a summary of what would be submitted  ============",
    "\n================================================================================",
]
for jobId, jobInfo in CP3SlurmUtils.Py2Py3Helpers.iteritems(resubmitJobs):
    if jobInfo['is_job_array']:
        arrayIndices = jobInfo['jobs_to_resubmit']
        summaryParts.append("\nFor job array %s:" % (jobId))
        summaryParts.append(" will submit a new job array consisting of %d jobs." % (len(arrayIndices)))
        summaryParts.append("\n    Job array indices: %s" % (','.join(arrayIndices)))
    else:
        summaryParts.append("\nFor job %s:" % (jobId))
        summaryParts.append(" will submit a new job.")
    if not jobInfo['new_batch_script']:
        # The original batch script is resubmitted unchanged.
        summaryParts.append("\n    Batch script:")
        summaryParts.append("\n        %s" % (jobInfo['batch_script']))
    else:
        summaryParts.append("\n    Original batch script (FYI only):")
        summaryParts.append("\n        %s" % (jobInfo['batch_script']))
        summaryParts.append("\n    New batch script:")
        summaryParts.append("\n        %s" % (jobInfo['new_batch_script']))
summaryParts.extend([
    "\n================================================================================",
    "\n===============================  end of summary  ===============================",
    "\n================================================================================",
])
summary = "".join(summaryParts)

if opts.debug:
    print("Summary text created.")

print(summary)

#================================================================================
# 6. Submit the jobs
#================================================================================

# Without the -s option the script only prepares the batch scripts.
if not opts.submit:
    print("You have not provided the -s option, so jobs were not submitted.")
    if not opts.quiet:
        print("You can always submit the jobs (i.e. batch scripts) yourself with the slurm 'sbatch' command.")
    sys.exit(0)

# Ask for a confirmation, unless the user answered in advance with -y.
if not opts.yes:
    timeout = 5
    print("You provided the -s option, which instructs to submit the jobs.")
    if not opts.quiet:
        print("Please read first the summary given above, then answer whether you really want to submit the jobs.")
        print("N.B.: You can always submit the jobs (i.e. batch scripts) yourself with the slurm 'sbatch' command.")
    question = "Are you sure you want to submit the jobs ? (you have %d minutes to answer)" % (timeout)
    # The alarm interrupts the question once the timeout (in minutes) is over.
    signal.signal(signal.SIGALRM, questionTimeoutHandler)
    signal.alarm(timeout*60)
    try:
        answer = interactiveYesNoQuestion(question, default='no')
    except TimeoutException as ex:
        print(ex)
        answer = False
    signal.alarm(0)
    if not answer:
        print("Jobs were not submitted.")
        sys.exit(0)

for jobId, jobInfo in CP3SlurmUtils.Py2Py3Helpers.iteritems(resubmitJobs):
    # Submit the new batch script when one was written, the original otherwise.
    batchScript = jobInfo['new_batch_script'] or jobInfo['batch_script']
    asJobArray = bool(jobInfo['is_job_array'])
    if asJobArray:
        numJobs = len(jobInfo['jobs_to_resubmit'])
        print("Submitting new job array consisting of %d jobs as a resubmission of job id %s." % (numJobs, jobId))
    else:
        print("Submitting new job as resubmission of job id %s." % (jobId))
    ret = submit(batchScript, partition=jobInfo['partition'], qos=jobInfo['qos'], jobArray=asJobArray)
    # Return value 2 of submit() stops the script with exit status 1.
    if ret == 2:
        sys.exit(1)

sys.exit(0)
