#!/usr/bin/env python3
import os,sys
# CODE_PATH = '/raid/taoyang/research/research_everyday/projects/fairness/FairRec_www_2020/slurm/'
import json
import argparse
#import argcomplete
import psutil
import time
import glob
import random
from shutil import move
from progressbar import progressbar
# Command-line interface of the launcher.
# --CODE_PATH and --JSON_PATH are mandatory; every other option has a default.
# The boolean switches --rerun / --seefinished / --recover / --show_backup put
# the script into a maintenance mode in which no job is submitted.
parser = argparse.ArgumentParser(description='Baseline Run')
parser.add_argument('--CODE_PATH', type=str, default=None,required=True,
                    help='Directory of file.')
parser.add_argument('--Cmd_file', type=str, default="main.py",
                    help='exact python file we need to execute')
parser.add_argument('--JSON_PATH', type=str, default=None,required=True,
                    help='Directory of Json setting folder.')
parser.add_argument('--json2args', action='store_true',
                    help='list the param in json file as script param')
# The value is spliced into "~/miniconda3/envs/<value>/bin/python", so it is
# an environment name rather than a python version number.
parser.add_argument('--python_ver',type=str, default="pt_ranking",
                    help='name of the conda environment (under ~/miniconda3/envs/) whose python runs the command')
# parser.add_argument('--test_only',action='store_true',
#                     help='if not give, will exclude json file name containing test. if give, will only use json file containing test')
parser.add_argument('--plain_script',action='store_true',
                    help='if not give, will use slurm')
# --rerun does not launch anything itself: it only renames result markers so
# that a later invocation treats the tasks as unfinished.
parser.add_argument('--rerun',action='store_true',
                    help='back up finished result files (rename *jjson to *jjsonbk) instead of submitting jobs')
parser.add_argument('--seefinished',action='store_true',
                    help='check the finished')
parser.add_argument('--recover',action='store_true',
                    help='if not give, will not recover')
parser.add_argument('--show_backup',action='store_true',
                    help='if not give, will not show_backup')
parser.add_argument('--jobs_limit',type=int, default=10,
                    help='number of jobs when executing in plain script mode')
parser.add_argument('--ntasks',type=int, default=4,
                    help='number of tasks when executing in slurm mode')
parser.add_argument('--black_list',type=str,
                    help='specify the segment of string where task name contain we will not execute. Use + to specify more')
parser.add_argument('--white_list',type=str,
                    help='specify the segment of string where task name contain we will only execute. Use + to specify more')
parser.add_argument('--only_print',action='store_true',
                    help='specify to see if we only print the cmd')
parser.add_argument('--only_unfinished',action='store_true',
                    help='specify to see if we only run only_unfinished')
# The previous help text was a copy of the --ntasks help; this option is only
# read by the plain-script loop to decide from which command on to pause.
parser.add_argument('--initial_n_cmds',type=int, default=0,
                    help='number of initial commands submitted without the pause between submissions (plain script mode)')
parser.add_argument('--secs_each_sub',type=float, default=1.0,
                    help='secs between each submission')
parser.add_argument('--memory_usage',type=int, default=70,
                    help='memory usage, [0,100]')
#argcomplete.autocomplete(parser)
args = parser.parse_args()

# ---- Unpack the parsed options into the module-level names used below ----

# Locations and the command to run.
CODE_PATH = args.CODE_PATH
SETTING_JSON_PATH = args.JSON_PATH
Cmd_file = args.Cmd_file
json_to_args = args.json2args
python_ver = args.python_ver

# Maintenance / inspection switches.
rerun = args.rerun
seefinished = args.seefinished
recover = args.recover
show_backup = args.show_backup
only_print = args.only_print
only_unfinished = args.only_unfinished

# Task-name filters: "+"-separated segments, empty list when the option is absent.
if args.black_list:
    BLACK_LIST = args.black_list.split("+")
else:
    BLACK_LIST = []
if args.white_list:
    WHITE_LIST = args.white_list.split("+")
else:
    WHITE_LIST = []

# Execution mode and throttling.
plain_script = args.plain_script
jobs_limit = args.jobs_limit
ntasks = args.ntasks
secs_each_sub = args.secs_each_sub
memory_usage = args.memory_usage

# Interpreter prefix of every generated command (the trailing space is part of it).
python = "~/miniconda3/envs/" + python_ver + "/bin/python "

# Counter incremented by submit_one_job for each command handled in plain-script mode.
job_count = 1
def _paths_with_suffix(result_dir, suffix):
    """Return the entries directly under ``result_dir`` whose path ends with ``suffix``."""
    return [path for path in glob.glob(result_dir+"/*") if path.endswith(suffix)]


def _job_selected(job_name):
    """Return True if ``job_name`` contains every WHITE_LIST segment and no BLACK_LIST segment."""
    return (all(pass_name in job_name for pass_name in WHITE_LIST)
            and all(skip_name not in job_name for skip_name in BLACK_LIST))


def _json_to_cli_args(json_path):
    """Render the top-level items of a JSON object as `` --name=value`` flags.

    An item whose value is the empty string becomes a bare `` --name`` flag.
    """
    with open(json_path,"r") as modelfile:
        model=json.load(modelfile)
    pieces=[]
    for (name,value) in model.items():
        pieces.append(" --"+name)
        if value!="":
            pieces.append("="+str(value))
    return "".join(pieces)


def _write_job_script(script_path, job_name, model_command):
    """Write the bash script (with its #SBATCH header) that runs ``model_command``."""
    with open(script_path, 'w') as fout:
        # set slurm parameters
        fout.write('#!/bin/bash\n')
        fout.write('#\n')
        fout.write('#SBATCH --job-name='+job_name+'\n')
        fout.write('#SBATCH --partition=titan-giant    # Partition to submit to \n')
        fout.write('#SBATCH --gres=gpu:1\n')
        fout.write('#SBATCH --ntasks='+str(ntasks)+'\n')
        fout.write('#SBATCH --mem=12000    # Memory in MB per cpu allocated\n\n')
        fout.write(model_command+"\n")
        fout.write('exit\n')


def submit_one_job(prefix_path, json_file):
    """Handle one JSON setting file: maintain its result markers or submit a job.

    Args:
        prefix_path: Directory of the setting file relative to SETTING_JSON_PATH,
            either '' or starting with '/' (the first character is dropped
            before joining).
        json_file: File name of the setting file inside that directory.

    Behaviour, in order of precedence (each maintenance mode returns without
    submitting anything):
        * rerun: rename every ``*jjson`` entry of the directory to ``*jjsonbk``.
        * seefinished: print every ``*jjson`` entry.
        * recover: rename ``*jjsonbk`` entries back, unless the target exists.
        * show_backup: print every ``*jjsonbk`` entry.
        * otherwise: if the job name passes WHITE_LIST/BLACK_LIST, create the
          output directory ``<dir>/<json name without .json>``, write
          ``<job_name>.sh`` into it and print the launch command; unless
          only_print is set the command is executed through ``os.system``
          (``sbatch``, or a backgrounded ``bash`` when plain_script is set).

    Standard output and error of the job go to ``<job_name>.o`` / ``<job_name>.e``
    inside the output directory in both execution modes.
    """
    global job_count
    job_name = '_'.join([prefix_path, json_file]).replace('/','').replace('.json','')

    cur_result_dir = os.path.join(SETTING_JSON_PATH, prefix_path[1:])
    cur_json_path = os.path.join(cur_result_dir, json_file)
    cur_output_path = os.path.join(cur_result_dir, json_file.replace('.json',''))

    if rerun:
        for path in _paths_with_suffix(cur_result_dir, "jjson"):
            print(path)
            move(path, path+"bk")
        return
    if seefinished:
        for path in _paths_with_suffix(cur_result_dir, "jjson"):
            print(path)
        return
    if recover:
        for path in _paths_with_suffix(cur_result_dir, "jjsonbk"):
            print(path)
            # Strip the trailing "bk"; never overwrite a newer result file.
            if not os.path.exists(path[:-2]):
                os.rename(path, path[:-2])
        return
    if show_backup:
        for path in _paths_with_suffix(cur_result_dir, "jjsonbk"):
            print(path)
        return

    # Filter before touching the file system so that a skipped task leaves no
    # empty output directory or stale job script behind.
    if not _job_selected(job_name):
        return

    os.makedirs(cur_output_path, exist_ok=True)
    job_path=os.path.join(cur_output_path,job_name)

    # Either pass the JSON file itself or expand its items into CLI flags.
    target = _json_to_cli_args(cur_json_path) if json_to_args else cur_json_path
    model_command = ' '.join([
            python,
            CODE_PATH+"/"+Cmd_file,
            ' ' + target
    ])
    _write_job_script(job_path+'.sh', job_name, model_command)

    if plain_script:
        command = 'bash'+' ' + job_path+'.sh'
        command +=' > '+job_path+'.o'+" 2>"+job_path+'.e'
        job_count+=1
        print(command)
        if not only_print:
            # Run in the background so the caller can go on submitting.
            command+="&"
            os.system(command)
    else:
        # Log files live next to the job script; the directory and the job
        # name must be joined with a separator, otherwise the logs end up
        # outside the output directory under a concatenated name.
        command = 'sbatch -e '+job_path+'.e -o '+job_path+'.o'
        command += ' ' + job_path+'.sh'
        print(command)
        if not only_print:
            os.system(command)

def check_path(paths):
    """Filter ``paths`` down to the directories that hold no ``*.jjson`` file.

    Args:
        paths: Iterable of directory paths.

    Returns:
        A new list, in the original order, of the directories for which the
        glob ``<dir>/*.jjson`` matches nothing.
    """
    return [directory for directory in paths
            if len(glob.glob(directory+"/*.jjson")) == 0]

# Read json file from json path
# (prefix_path, file_name) pairs collected by get_json_file_from_path.
json_file_list = []
def get_json_file_from_path(prefix_path):
    """Recursively collect the ``*.json`` setting files below SETTING_JSON_PATH.

    Args:
        prefix_path: Directory to scan, relative to SETTING_JSON_PATH; '' for
            the root, otherwise a string of the form '/sub/dir'.

    Every accepted file is appended to the module-level ``json_file_list`` as
    ``(prefix_path, file_name)``. A file is accepted when ``prefix_path + name``
    contains all WHITE_LIST segments and none of the BLACK_LIST segments. With
    ``only_unfinished`` set, directories that already contain ``result.jjson``
    contribute no files (their sub-directories are still scanned).
    """
    current_dir = SETTING_JSON_PATH + prefix_path
    # os.listdir returns entries in arbitrary order; sort them so the list
    # (and the seeded shuffle applied to it later) is reproducible.
    entries = sorted(os.listdir(current_dir))
    # Loop-invariant: decided once per directory instead of once per file.
    skip_finished = bool(only_unfinished) and "result.jjson" in entries
    for f in entries:
        if f.endswith('.json'):
            if skip_finished:
                continue
            task_name = prefix_path+f
            selected = (all(pass_name in task_name for pass_name in WHITE_LIST)
                        and all(skip_name not in task_name for skip_name in BLACK_LIST))
            if selected:
                json_file_list.append((prefix_path,f))
        elif os.path.isdir(current_dir + '/' + f):
            get_json_file_from_path(prefix_path + '/' + f)
# Collect every setting file, starting at the root of SETTING_JSON_PATH.
get_json_file_from_path('')
# Shuffle with a fixed seed, then stable-sort by the last character of the
# directory prefix so tasks sharing that character stay together while
# keeping their shuffled relative order.
random.seed(0)
random.shuffle(json_file_list)
# Slice ([-1:]) instead of index ([-1]): setting files located directly in
# the root have the empty prefix '', for which indexing raises IndexError.
json_file_list=sorted(json_file_list,key=lambda x: x[0][-1:])
# Modes in which submit_one_job never launches a process: it only prints the
# command (only_print) or works on result markers and returns.
no_launch = only_print or rerun or seefinished or recover or show_backup
if not plain_script:
    # Slurm mode: hand every task to submit_one_job without throttling.
    for (prefix_path, json_file) in progressbar(json_file_list):
        submit_one_job(prefix_path, json_file)
else:
    # Plain-script mode: jobs run as local background processes, so throttle
    # on system memory and on the number of directories still lacking a
    # *.jjson result file.
    result_list_path=[]
    inital=args.initial_n_cmds
    for ind, (prefix_path, json_file) in enumerate(progressbar(json_file_list)):
        path_cur= '/'.join([SETTING_JSON_PATH, prefix_path])
        if no_launch:
            # Nothing is started, so there is nothing to wait for. Waiting
            # here would hang forever once jobs_limit unfinished directories
            # have been seen, because no job will ever write their results.
            submit_one_job(prefix_path, json_file)
            continue
        # Hold back while the machine's memory usage (percent) is above the limit.
        while psutil.virtual_memory().percent>memory_usage:
            time.sleep(secs_each_sub)
        submit_one_job(prefix_path, json_file)
        result_list_path.append(path_cur)
        # Block until fewer than jobs_limit tracked directories are unfinished.
        # NOTE(review): a job that dies without writing a *.jjson file keeps
        # its slot occupied indefinitely.
        while len(result_list_path)>=jobs_limit:
            result_list_path=check_path(result_list_path)
            time.sleep(secs_each_sub)
        # The first `inital` commands are submitted back to back; afterwards
        # pause between submissions.
        if ind>=inital:
            time.sleep(secs_each_sub)