#!/usr/bin/env python3

from DIRAC.Interfaces.API.Dirac import Dirac
from DIRAC.Interfaces.API.Job   import Job
from DIRAC                      import initialize
from DIRAC                      import gLogger

import os
import glob
import argparse

from logzero import logger as log
from tqdm    import trange

#---------------------------------------
class data:
    """Module-wide settings shared by the submission helpers.

    All attributes live on the class itself; no instance is ever created.
    """
    # Filled in from the command line by get_args()
    name  = None
    njobs = None
    nfits = None
    mode  = None
    vfix  = None
    mods  = None

    # Directory searched for *.tar.gz files by get_tarballs(); read from $EXTDIR
    # when the module is imported (a missing variable raises KeyError here).
    ext_dir = os.environ['EXTDIR']
    # Per-submission sandbox directory, set by get_args()
    snd_dir = None
#---------------------------------------
def get_banned_sites():
    """Return the grid sites that jobs built by get_job() must avoid.

    A new list object is created on every call, so callers may modify it.
    """
    return [
        'LCG.PIC.es',
        'LCG.Liverpool.uk',
        'LCG.NCBJ-CIS.pl',
        'LCG.UKI-LT2-RHUL.uk',
        'LCG.NCBJ.pl',
        'LCG.ECDF.uk',
        'LCG.Manchester.uk',
        'LCG.NIPNE-07.ro',
        'LCG.Krakow.pl',
        'LCG.PNPI.ru',
        'LCG.MIT.us',
        'LCG.UKI-LT2-IC-HEP.uk',
        'LCG.USC.es',
        'LCG.JINR.ru',
    ]
#---------------------------------------
def get_tarballs():
    """Return the paths of all ``*.tar.gz`` files found directly in ``data.ext_dir``.

    The result is added to the input sandbox of every job built by get_job().

    Returns:
        list[str]: tarball paths, sorted so the sandbox contents do not depend on
        the arbitrary order in which glob lists the directory.

    Raises:
        FileNotFoundError: if no tarball matches the pattern.
    """
    tar_wc = f'{data.ext_dir}/*.tar.gz'

    l_tarball = sorted(glob.glob(tar_wc))
    if not l_tarball:
        log.error(f'Cannot find any tarball in: {tar_wc}')
        # A bare `raise` outside an except block only yields
        # "RuntimeError: No active exception to reraise"; raise a real error.
        raise FileNotFoundError(f'No tarball matching: {tar_wc}')

    log.info('Tarballs:')
    for tarball in l_tarball:
        log.info(f'    {tarball}')

    return l_tarball
#---------------------------------------
def get_job(jobid):
    seeds_file = f'{data.snd_dir}/seeds/{jobid}.sd'
    if not os.path.isfile(seeds_file):
        log.error(f'Cannot find: {seeds_file}')
        raise FileNotFoundError

    j = Job()
    j.setCPUTime(36000)
    j.setBannedSites(get_banned_sites())

    vfix = ' '.join(data.vfix)
    mods = ' '.join(data.mods)
    j.setExecutable('rxe_run_toys', arguments=f'"all_TOS all_TIS" "{vfix}" "{mods}"')
    j.setInputSandbox([seeds_file, 'rxe_toys'] + get_tarballs())
    j.setOutputData(['results/result_pkl.tar.gz', 'results/result_jsn.tar.gz'])
    j.setName(f'rxext_{data.name}_{jobid:03}')

    return j
#---------------------------------------
def make_seeds(seed_stride=1000):
    """Write one seed file per job to ``<snd_dir>/seeds/<ijob>.sd``.

    Job ``ijob`` gets ``data.nfits`` consecutive integers, one per line, starting
    at ``ijob * stride``. ``stride`` is ``seed_stride`` unless ``data.nfits`` is
    larger. With the original fixed stride of 1000, ``nfits > 1000`` made the seed
    ranges of neighbouring jobs overlap, so the same seed was used twice (presumably
    giving duplicated toys). For ``nfits <= 1000`` the output is unchanged.

    Args:
        seed_stride (int): minimum distance between the first seeds of
            consecutive jobs (default 1000).
    """
    log.info('Making seeds')
    seeds_dir = f'{data.snd_dir}/seeds'
    os.makedirs(seeds_dir, exist_ok=True)

    stride = max(seed_stride, data.nfits)
    for ijob in range(data.njobs):
        seeds_path = f'{seeds_dir}/{ijob}.sd'
        log.debug(f'Writting {seeds_path}')
        first = stride * ijob
        # `with` guarantees the file is closed even if a write fails.
        with open(seeds_path, 'w') as ofile:
            ofile.writelines(f'{ifit}\n' for ifit in range(first, first + data.nfits))
#---------------------------------------
def get_args():
    """Parse the command line and store every option on ``data``.

    Also sets ``data.snd_dir`` to ``<cwd>/sandbox_<name>``.
    """
    parser = argparse.ArgumentParser(description='Used to send toy fit jobs to the grid')
    parser.add_argument('-n', '--name' , type=str, help='Job name', required=True)
    parser.add_argument('-j', '--njobs', type=int, help='Number of grid jobs', required=True)
    parser.add_argument('-f', '--nfits', type=int, help='Number of fits per job', required=True)
    parser.add_argument('-m', '--mode' , type=str, help='Run locally or in the grid', choices=['local', 'wms'], required=True)
    # With nargs='+' an explicit value is a list, but argparse returns the default
    # untouched. The default must therefore also be a list; a plain 'none' string
    # would become 'n o n e' when get_job joins the values with spaces.
    parser.add_argument('-v', '--vfix' , nargs='+', help='Variables to fix if assessing systematics', default=['none'])
    parser.add_argument('-p', '--mods' , nargs='+', help='Variations in the fitting model', default=['none'])
    args = parser.parse_args()

    data.snd_dir = f'{os.getcwd()}/sandbox_{args.name}'
    data.name    = args.name
    data.njobs   = args.njobs
    data.nfits   = args.nfits
    data.mode    = args.mode
    data.vfix    = args.vfix
    data.mods    = args.mods
#---------------------------------------
def main():
    """Create the sandbox, write the seed files and submit one job per seed file.

    The ``JobID`` returned for each submission, or -1 when none is returned, is
    written one per line to ``<snd_dir>/jobids.out`` in submission order.
    """
    # exist_ok=False: never reuse a sandbox, so earlier seeds/job IDs are not overwritten.
    os.makedirs(data.snd_dir, exist_ok=False)

    gLogger.setLevel('warning')
    initialize()
    dirac = Dirac()

    make_seeds()

    # Write and flush each ID as soon as it is known, so the IDs of jobs that were
    # already submitted survive if a later iteration raises.
    with open(f'{data.snd_dir}/jobids.out', 'w') as ofile:
        for ijob in trange(data.njobs):
            job    = get_job(ijob)
            d_info = dirac.submitJob(job, mode=data.mode)

            try:
                grid_id = d_info['JobID']
            except (KeyError, TypeError):
                # No JobID in the reply: report it instead of failing silently.
                log.warning(f'Submission of job {ijob} failed: {d_info}')
                grid_id = -1

            ofile.write(f'{grid_id}\n')
            ofile.flush()
#---------------------------------------
if __name__ == '__main__':
    # Options must be on `data` before main() reads them.
    get_args()
    main()

