#!/usr/bin/env python3

import argparse
import functools
import glob
import os
import signal
import subprocess
import tarfile

import logzero
import tqdm
from logzero import logger as log

#----------------------------
class data:
    '''
    Container of the settings shared by the functions of this script.

    job_name : Name of the job
    log_level: Level assigned to the logger
    sandbox  : Path to sandbox directory
    timeout  : Timeout for dirac commands

    All attributes are None until get_args() assigns them.
    '''
    job_name = log_level = sandbox = timeout = None
#----------------------------
def get_args():
    '''
    Read the command line options and copy them into the attributes of `data`:

    -n/--name -> data.job_name  (mandatory)
    -s/--sand -> data.sandbox   (defaults to the current working directory)
    -t/--time -> data.timeout   (defaults to 1000)
    -l/--logl -> data.log_level (defaults to logzero.INFO)
    '''
    allowed_levels = [logzero.DEBUG, logzero.INFO, logzero.WARNING]

    cli = argparse.ArgumentParser(description='Used to download outputs of toy fits from the grid')
    cli.add_argument('-n', '--name', required=True, type=str, help='Name of job')
    cli.add_argument('-s', '--sand', default=os.getcwd(), type=str, help='Path to sandbox directory')
    cli.add_argument('-t', '--time', default=1000, type=int, help='Timeout for dirac commands')
    cli.add_argument('-l', '--logl', default=logzero.INFO, type=int, help='Log level', choices=allowed_levels)
    parsed = cli.parse_args()

    # Publish the parsed values through the shared `data` class
    data.job_name, data.sandbox = parsed.name, parsed.sand
    data.log_level, data.timeout = parsed.logl, parsed.time
#-------------------------------------------------------
# NOTE: this call runs while the module is being loaded and has to stay above
# run_command: the decorator @add_timeout(seconds=data.timeout) reads
# data.timeout when run_command is defined, so it must already be set here.
get_args()
#-------------------------------------------------------
def timeout_handler(signum, frame):
    '''
    SIGALRM handler installed by add_timeout, it turns the alarm into a TimeoutError

    Parameters
    ----------
    signum: Signal number, not used
    frame : Stack frame that was interrupted, not used
    '''
    raise TimeoutError("Timeout occurred.")
#------------------------------------------------------------------
def add_timeout(seconds=10):
    '''
    Decorator factory used to limit the running time of a function

    Example:

    @add_timeout(seconds=5)
    def fun():
        ...

    val = fun()

    If fun takes longer than 5 seconds it is interrupted, a warning is logged
    and val=None. Any other exception raised by fun reaches the caller.

    Parameters
    ----------
    seconds (int): Time after which the alarm signal is delivered

    The timeout relies on signal.alarm/SIGALRM, see the documentation of the
    signal module for the platforms and threads where this is available.
    '''
    def decorator_function(original_function):
        @functools.wraps(original_function)
        def wrapper_function(*args, **kwargs):
            previous_handler = signal.signal(signal.SIGALRM, timeout_handler)

            try:
                signal.alarm(seconds)
                return original_function(*args, **kwargs)
            except TimeoutError:
                log.warning(f'Timeout: > {seconds} sec.')
                return None
            finally:
                # Only clean up here: a return statement inside finally would
                # discard whatever exception is propagating
                signal.alarm(0)
                # signal.signal returns None when the old handler was not installed from Python,
                # in that case there is nothing that can be restored
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return wrapper_function
    return decorator_function
#----------------------------
@add_timeout(seconds=data.timeout)
def run_command(cmd, options=None, raise_on_fail=True):
    '''
    Run an external command, its stdout and stderr are written to /tmp/rk_ext_output.log,
    which is overwritten by every call. The call is wrapped by add_timeout with
    data.timeout seconds.

    Parameters
    ----------
    cmd (str): Name of the executable
    options (list): Arguments passed to the executable, None means no arguments
    raise_on_fail (bool): If True, a non-zero exit status raises RuntimeError

    Raises
    ------
    ValueError  : If options is neither None nor a list
    RuntimeError: If the exit status is not zero and raise_on_fail is True

    What reaches the caller is decided by the add_timeout wrapper around this function.
    '''
    if options is None:
        options = []

    if not isinstance(options, list):
        log.error(f'Invalid options argument: {options}')
        raise ValueError(f'Invalid options argument: {options}')

    log.debug('-' * 30)
    log.debug('-' * 30)
    log.debug(f'{cmd:<10}{str(options):<50}')
    log.debug('-' * 30)
    log.debug('-' * 30)

    with open('/tmp/rk_ext_output.log', 'w') as ofile:
        stat = subprocess.run([cmd] + options, stdout=ofile, stderr=ofile)

    if stat.returncode != 0:
        log.error(f'Process returned exit status: {stat.returncode}')
        if raise_on_fail:
            # A bare "raise" is not valid here, there is no exception being handled
            raise RuntimeError(f'Command {cmd} {options} returned exit status: {stat.returncode}')
#----------------------------
def get_ids():
    '''
    Read the job IDs stored, one per line, in sandbox_<job name>/jobids.out.
    The path is relative to the current working directory.

    Returns
    ------
    List of job IDs as strings, without surrounding whitespace and without empty entries

    Raises
    ------
    FileNotFoundError: If the file with the IDs does not exist
    '''
    ids_path = f'sandbox_{data.job_name}/jobids.out'

    if not os.path.isfile(ids_path):
        log.error(f'File not found: {ids_path}')
        raise FileNotFoundError(f'File not found: {ids_path}')

    with open(ids_path) as ifile:
        l_line = ifile.read().splitlines()

    # Blank lines are dropped, an empty ID would otherwise be used by main() and download()
    # to build an output directory and a grid path that do not belong to any job
    l_ids = [line.strip() for line in l_line]
    l_ids = [jobid for jobid in l_ids if jobid]

    return l_ids
#----------------------------
def get_lfns():
    '''
    Read the only *.lfns file of the current working directory and delete it afterwards

    Returns
    ------
    List with the lines of the file

    Raises
    ------
    FileNotFoundError: If the number of *.lfns files found differs from one
    '''
    matches = glob.glob('*.lfns')
    if len(matches) != 1:
        log.error('Did not find one and only one lfns file')
        raise FileNotFoundError

    lfns_path = matches[0]
    with open(lfns_path) as stream:
        contents = stream.read()

    # The file is consumed: once read it is removed from disk
    os.remove(lfns_path)

    return contents.splitlines()
#----------------------------
def skip_download(file_name):
    '''
    Tell whether file_name is already on disk and can be opened as a tar archive

    Parameters
    ----------
    file_name (str): Path to the archive

    Returns
    ------
    True if the file exists and tarfile.open succeeds on it. False if the file
    is missing or tarfile.open raises tarfile.ReadError.
    '''
    if os.path.isfile(file_name):
        log.debug('File found')

        try:
            # Opening is the whole check, nothing is read from the archive here
            with tarfile.open(file_name):
                pass
        except tarfile.ReadError:
            log.warning('File cannot be opened, downloading again')
        else:
            return True

    return False
#----------------------------
def download(jobid):
    '''
    Download the output of one job, unless skip_download finds a usable
    result_jsn.tar.gz in the current working directory

    Parameters
    ----------
    jobid (str): ID of the job, the grid path is built as
                 /lhcb/user/a/acampove/<ID without its last three characters>/<ID>
    '''
    target = 'result_jsn.tar.gz'
    if not skip_download(target):
        grid_path = f'/lhcb/user/a/acampove/{jobid[:-3]}/{jobid}'

        # This command is expected to leave a *.lfns file in the current directory,
        # which get_lfns reads and removes
        run_command('dirac-dms-user-lfns', ['-w', target, '-b', grid_path])

        for lfn in get_lfns():
            run_command('dirac-dms-get-file', [lfn])
#----------------------------
def main():
    '''
    Loop over the job IDs and download the output of each job inside
    <sandbox>/output_<job name>/<job ID>, the directory is created if missing.

    The process changes into each job directory for the download and goes back
    to the directory where it started afterwards, also when the download fails.
    '''
    l_jobid = get_ids()

    start_dir = os.getcwd()
    # Made absolute before any chdir, a relative sandbox path would otherwise be
    # resolved from a different directory after the first job
    out_root = os.path.abspath(f'{data.sandbox}/output_{data.job_name}')

    for jobid in tqdm.tqdm(l_jobid, ascii=' -'):
        job_dir = f'{out_root}/{jobid}'
        os.makedirs(job_dir, exist_ok=True)
        os.chdir(job_dir)
        try:
            download(jobid)
        finally:
            os.chdir(start_dir)
#----------------------------
if __name__ == '__main__':
    # data.log_level has already been set by the module-level get_args() call
    log.setLevel(data.log_level)
    main()

