#!/usr/bin/env python3

from DIRAC.Interfaces.API.Dirac import Dirac
from DIRAC.Interfaces.API.Job   import Job
from DIRAC                      import initialize
from DIRAC                      import gLogger

import os
import glob
import argparse

from logzero import logger as log
from tqdm    import trange

#---------------------------------------
class data:
    '''
    Module-wide configuration shared by the functions of this script.

    The attributes are class-level and are filled in by ``get_args`` from the
    command line. ``mode`` is not exposed on the command line; it defaults to
    ``'wms'``, which is DIRAC's default submission mode for ``Dirac.submitJob``.
    '''
    njobs   = None   # number of grid jobs to submit
    nfits   = None   # number of toy fits (seeds) per job
    name    = None   # job name, also used in the sandbox directory name
    vers    = None   # config version, passed as arguments to the executable
    snd_dir = None   # sandbox directory, <cwd>/sandbox_<name>
    mode    = 'wms'  # submission mode read by main(); previously missing -> AttributeError
#---------------------------------------
def get_banned_sites():
    '''
    Return the DIRAC site names on which the toy jobs must not run.

    A fresh list is built on every call, so callers may modify it freely.
    '''
    return [
        'LCG.PIC.es',
        'LCG.Liverpool.uk',
        'LCG.NCBJ-CIS.pl',
        'LCG.UKI-LT2-RHUL.uk',
        'LCG.NCBJ.pl',
        'LCG.ECDF.uk',
        'LCG.Manchester.uk',
        'LCG.NIPNE-07.ro',
        'LCG.Krakow.pl',
        'LCG.PNPI.ru',
        'LCG.MIT.us',
        'LCG.UKI-LT2-IC-HEP.uk',
        'LCG.USC.es',
        'LCG.JINR.ru',
    ]
#---------------------------------------
def get_job(jobid):
    '''
    Build the DIRAC job for toy-fit job number ``jobid``.

    Parameters
    ----------
    jobid : int
        Job index. Selects the seeds file ``<snd_dir>/seeds/<jobid>.sd`` and,
        zero padded to 3 digits, becomes part of the job name.

    Returns
    -------
    Job
        Configured DIRAC job, not yet submitted.

    Raises
    ------
    FileNotFoundError
        If the seeds file for this job does not exist; the path is attached.
    '''
    seeds_file = f'{data.snd_dir}/seeds/{jobid}.sd'
    if not os.path.isfile(seeds_file):
        log.error(f'Cannot find: {seeds_file}')
        # Attach the path so the traceback says which file is missing
        raise FileNotFoundError(seeds_file)

    j = Job()
    j.setCPUTime(36000)  # presumably seconds (10 h) - confirm against DIRAC Job docs
    j.setBannedSites(get_banned_sites())
    j.setExecutable('rxe_run_toys', arguments=data.vers)
    # NOTE(review): 'rxe_toys' is assumed to exist in the submission directory
    j.setInputSandbox([seeds_file, 'rxe_toys'])
    # The executable is expected to produce these two archives
    j.setOutputData(['results/result_pkl.tar.gz', 'results/result_jsn.tar.gz'])
    j.setName(f'rxext_{data.name}_{jobid:03}')

    return j
#---------------------------------------
def make_seeds():
    '''
    Write one seeds file per grid job under ``<snd_dir>/seeds``.

    File ``<ijob>.sd`` contains ``data.nfits`` consecutive integers, one per
    line, starting at ``ijob * stride``.

    The stride is 1000, widened to ``data.nfits`` when more than 1000 fits per
    job are requested. This guarantees no seed is shared between two jobs.
    For ``nfits <= 1000`` the files are identical to the fixed-1000 layout.
    '''
    log.info('Making seeds')
    seeds_dir = f'{data.snd_dir}/seeds'
    os.makedirs(seeds_dir, exist_ok=True)

    # A fixed stride of 1000 made job i and job i+1 overlap when nfits > 1000,
    # producing duplicated toys.
    stride = max(1000, data.nfits)
    for ijob in range(data.njobs):
        seeds_file = f'{seeds_dir}/{ijob}.sd'
        log.debug(f'Writting {seeds_file}')
        first = ijob * stride
        # 'with' guarantees the file is closed even if a write fails
        with open(seeds_file, 'w') as ofile:
            ofile.writelines(f'{ifit}\n' for ifit in range(first, first + data.nfits))
#---------------------------------------
def get_args():
    '''
    Parse the command line and store the options on the ``data`` class.

    All four options are mandatory. The sandbox directory is derived from the
    job name as ``<cwd>/sandbox_<name>``.
    '''
    parser = argparse.ArgumentParser(description='Used to send toy fit jobs to the grid')

    # (short flag, long flag, value type, help text)
    l_spec = [
        ('-n', '--name' , str, 'Job name'),
        ('-j', '--njobs', int, 'Number of grid jobs'),
        ('-f', '--nfits', int, 'Number of fits per job'),
        ('-v', '--vers' , str, 'Version of config file'),
    ]
    for short_flag, long_flag, arg_type, help_text in l_spec:
        parser.add_argument(short_flag, long_flag, type=arg_type, help=help_text, required=True)

    args = parser.parse_args()

    data.name    = args.name
    data.njobs   = args.njobs
    data.nfits   = args.nfits
    data.vers    = args.vers
    data.snd_dir = f'{os.getcwd()}/sandbox_{args.name}'
#---------------------------------------
def main():
    '''
    Submit the toy-fit jobs and record their DIRAC job IDs.

    Steps:
    - create the sandbox directory;
    - write the seed files;
    - submit one grid job per seed file;
    - write the returned job IDs, one per line, to ``<snd_dir>/jobids.out``.

    ``get_args`` must have filled ``data`` beforehand. Jobs whose submission
    does not return a ``JobID`` are logged and recorded with ID -1.
    '''
    # exist_ok=False: refuse to overwrite a sandbox from a previous submission
    os.makedirs(data.snd_dir, exist_ok=False)

    gLogger.setLevel('warning')
    initialize()
    dirac = Dirac()

    make_seeds()

    # get_args never sets data.mode; fall back to DIRAC's default mode
    # instead of failing with AttributeError on the first submission.
    mode = getattr(data, 'mode', None) or 'wms'

    l_jobid = []
    for ijob in trange(data.njobs):
        job    = get_job(ijob)
        d_info = dirac.submitJob(job, mode=mode)

        try:
            jobid = d_info['JobID']
        except (KeyError, TypeError):
            # presumably an S_ERROR-style result; log it instead of dropping it
            log.error(f'Submission of job {ijob} failed: {d_info}')
            jobid = -1

        l_jobid.append(jobid)

    with open(f'{data.snd_dir}/jobids.out', 'w') as ofile:
        ofile.writelines(f'{jobid}\n' for jobid in l_jobid)
#---------------------------------------
if __name__ == '__main__':
    # get_args() fills the `data` class, which main() then reads
    get_args()
    main()

