#!/usr/bin/python
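# This script was adapted from a benchmark script published on the Intel website:
# https://software.intel.com/sites/default/files/0f/a7/mkl_benchmark.py
# It has been modified to compare Python parallel-pool back ends and different
# ways of passing data to the pool workers.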
import os
import sys
import timeit
import numpy as np
import GTMcore as GTM
from PDielec.DielectricFunction import DielectricFunction
from PDielec.Constants import wavenumber, speed_light_si
from functools import partial
import time
import psutil

# Name of the environment variable that limits the number of threads used for
# computation. Setting it in the shell that runs this script allows you to
# limit the maximum number of threads used.
THREADS_LIMIT_ENV = 'MKL_NUM_THREADS'

def get_pool(ncpus, threading, initializer=None, initargs=None, debugger=None ):
    """Return a worker pool from the back end selected by the global ``parallel_type``.

    Parameters
    ----------
    ncpus : int
        Number of workers in the pool.
    threading : bool
        True requests a thread-based pool, False a process-based one.
    initializer : callable or None
        Passed to the pool as its ``initializer``.
    initargs : tuple or None
        Passed to the pool as ``initargs``; None means none are passed.
    debugger : object or None
        Forwarded unchanged to the back-end specific factory.

    Returns
    -------
    The pool object created by the selected back-end factory.

    Raises
    ------
    ValueError
        If the module-level ``parallel_type`` is not one of 'multiprocess',
        'multiprocessing', 'pathos' or 'ray'.
    """
    # parallel_type is a module-level setting assigned by main(); it is only read here.
    if parallel_type == 'multiprocess':
        return get_multiprocess_pool(ncpus, threading, initializer, initargs, debugger)
    if parallel_type == 'multiprocessing':
        return get_multiprocessing_pool(ncpus, threading, initializer, initargs, debugger)
    if parallel_type == 'pathos':
        return get_pathos_pool(ncpus, threading, initializer, initargs, debugger)
    if parallel_type == 'ray':
        return get_ray_pool(ncpus, threading, initializer, initargs, debugger)
    # Previously this printed a message and then failed with UnboundLocalError
    # on the unassigned 'pool'; fail explicitly instead.
    raise ValueError('Error unknown parallel type {}'.format(parallel_type))


def get_pathos_pool(ncpus, threading, initializer=None, initargs=None, debugger=None ):
    """Build a pathos pool with ``ncpus`` workers.

    A ``pathos.threading.ThreadPool`` is returned when ``threading`` is true,
    otherwise a ``pathos.pools.ProcessPool``. ``initargs`` is only forwarded
    when it is not None. ``debugger`` is accepted for interface compatibility
    and is not used.
    """
    pool_options = {'initializer': initializer}
    if initargs is not None:
        pool_options['initargs'] = initargs
    if not threading:
        from pathos.pools import ProcessPool as PoolClass
    else:
        from pathos.threading import ThreadPool as PoolClass
    return PoolClass(ncpus, **pool_options)

def get_ray_pool(ncpus, threading, initializer=None, initargs=None, debugger=None ):
    """Build a ray multiprocessing-style pool with ``ncpus`` workers.

    The thread variant is imported from ``ray.util.multiprocessing.dummy`` when
    ``threading`` is true, otherwise the ``ray.util.multiprocessing`` Pool is
    used. ``initargs`` is only forwarded when it is not None. ``debugger`` is
    accepted for interface compatibility and is not used.
    """
    import ray
    pool_options = {'initializer': initializer}
    if initargs is not None:
        pool_options['initargs'] = initargs
    if not threading:
        from ray.util.multiprocessing import Pool as RayPool
    else:
        from ray.util.multiprocessing.dummy import Pool as RayPool
    return RayPool(ncpus, **pool_options)

def get_multiprocessing_pool(ncpus, threading, initializer=None, initargs=None, debugger=None ):
    """Build a standard-library ``multiprocessing`` pool with ``ncpus`` workers.

    The global start method is forced to 'spawn' on every call, whichever pool
    kind is requested. ``multiprocessing.dummy.Pool`` (threads) is returned when
    ``threading`` is true, otherwise ``multiprocessing.Pool`` (processes).
    ``initargs`` is only forwarded when it is not None. ``debugger`` is accepted
    for interface compatibility and is not used.
    """
    from multiprocessing import set_start_method
    set_start_method('spawn',force=True)
    pool_options = {'initializer': initializer}
    if initargs is not None:
        pool_options['initargs'] = initargs
    if not threading:
        from multiprocessing import Pool as PoolClass
    else:
        from multiprocessing.dummy import Pool as PoolClass
    return PoolClass(ncpus, **pool_options)

def get_multiprocess_pool(ncpus, threading, initializer=None, initargs=None, debugger=None ):
    """Build a ``multiprocess`` package pool with ``ncpus`` workers.

    ``multiprocess.dummy.Pool`` (threads) is returned when ``threading`` is
    true, otherwise ``multiprocess.Pool`` (processes). ``initargs`` is only
    forwarded when it is not None. ``debugger`` is accepted for interface
    compatibility and is not used.
    """
    pool_options = {'initializer': initializer}
    if initargs is not None:
        pool_options['initargs'] = initargs
    if not threading:
        from multiprocess import Pool as PoolClass
    else:
        from multiprocess.dummy import Pool as PoolClass
    return PoolClass(ncpus, **pool_options)



def init_local(function, superstrateDielectricFunction, substrateDielectricFunction, crystalPermittivityFunction, superstrateDepth, substrateDepth, crystalDepth, mode, theta, phi, psi, angleOfIncidence):
    """Pool initializer: store the model settings as attributes of ``function``.

    Every argument after ``function`` is attached to ``function`` under its
    own parameter name, so the worker function can read the settings back
    from itself. The id of ``function`` is printed so that it can be checked
    which object each worker initialised.
    """
    print('init_local, function id',id(function))
    settings = (
        ('superstrateDielectricFunction', superstrateDielectricFunction),
        ('substrateDielectricFunction', substrateDielectricFunction),
        ('crystalPermittivityFunction', crystalPermittivityFunction),
        ('superstrateDepth', superstrateDepth),
        ('substrateDepth', substrateDepth),
        ('crystalDepth', crystalDepth),
        ('mode', mode),
        ('theta', theta),
        ('phi', phi),
        ('psi', psi),
        ('angleOfIncidence', angleOfIncidence),
    )
    for attribute_name, attribute_value in settings:
        setattr(function, attribute_name, attribute_value)

def set_no_of_threads_on_worker(nthreads):
    '''Set the thread-limit environment variable in the current process.

    Any value already present in the ``THREADS_LIMIT_ENV`` variable is
    reported first, then the variable is overwritten with ``nthreads``.

    Parameters
    ----------
    nthreads : int
        Number of threads; stored in ``os.environ`` as a string.

    NOTE(review): this only updates ``os.environ``; whether an already-loaded
    math library re-reads the variable at this point is not established here.
    '''
    current = os.environ.get(THREADS_LIMIT_ENV)
    if current is not None:
        print("Maximum number of threads used for computation is : %s" % current)
    # Write through the shared constant, so the variable that is reported and
    # the variable that is set can never drift apart.
    os.environ[THREADS_LIMIT_ENV] = str(nthreads)

def set_affinity_on_worker():
    '''Placeholder hook for resetting a worker process's CPU affinity.

    This is currently a no-op and always returns None. The Linux ``taskset``
    call that used to be here is disabled. Note that its mask ``0xff`` only
    covers CPUs 0-7, not "all CPUs" on larger machines.
    '''
    # Disabled implementation, kept for reference:
    #   print("I'm the process %d, setting affinity to all CPUs." % os.getpid())
    #   os.system('taskset -p 0xff %d > /dev/null' % os.getpid())
    return None

def solve_single_crystal_equations_local(v):
    """Solve the single-crystal GTM problem at one wavenumber (function-attribute variant).

    The model settings are read from attributes of this function object. The
    ``init_local`` pool initializer sets them in each worker before tasks run.

    Parameters
    ----------
    v : float
        Wavenumber in cm-1.

    Returns
    -------
    tuple
        ``(v, r, R, t, T, epsilon)``. r, R, t and T are the values returned by
        ``system.calculate_r_t``. epsilon is the permittivity of the first layer,
        or of the substrate when the system has no layers.
    """
    settings = solve_single_crystal_equations_local
    superstrateDielectricFunction = settings.superstrateDielectricFunction
    substrateDielectricFunction   = settings.substrateDielectricFunction
    crystalPermittivityFunction   = settings.crystalPermittivityFunction
    superstrateDepth              = settings.superstrateDepth
    substrateDepth                = settings.substrateDepth
    crystalDepth                  = settings.crystalDepth
    mode                          = settings.mode
    theta                         = settings.theta
    phi                           = settings.phi
    psi                           = settings.psi
    angleOfIncidence              = settings.angleOfIncidence
    # Create the 3 layers. Superstrate/substrate depths are scaled by 1e-6
    # (microns to metres), the crystal depth by 1e-9 (nanometres to metres).
    superstrate = GTM.Layer(thickness=superstrateDepth*1e-6, epsilon1=superstrateDielectricFunction)
    substrate   = GTM.Layer(thickness=substrateDepth*1e-6,   epsilon1=substrateDielectricFunction)
    crystal     = GTM.Layer(thickness=crystalDepth*1e-9,     epsilon=crystalPermittivityFunction)
    # 'Thick slab' uses the crystal itself as the substrate. Every other mode
    # ('Coherent thin film' or anything else) puts the crystal as a single
    # layer between the superstrate and the substrate.
    if mode == 'Thick slab':
        system = GTM.System(substrate=crystal, superstrate=superstrate, layers=[])
    else:
        system = GTM.System(substrate=substrate, superstrate=superstrate, layers=[crystal])
    # Rotate the dielectric constants to the laboratory frame
    system.substrate.set_euler(theta, phi, psi)
    system.superstrate.set_euler(theta, phi, psi)
    for layer in system.layers:
        layer.set_euler(theta, phi, psi)
    # Convert the wavenumber to a frequency: v*1e2 is in m-1, times the speed of light
    freq = v * speed_light_si * 1e2
    system.initialize_sys(freq)
    zeta_sys = np.sin(angleOfIncidence)*np.sqrt(system.superstrate.epsilon[0,0])
    # The return value is unused. The call is kept because calculate_r_t may
    # depend on state it sets (NOTE(review): confirm in GTMcore).
    system.calculate_GammaStar(freq, zeta_sys)
    r, R, t, T = system.calculate_r_t(zeta_sys)
    if len(system.layers) > 0:
        epsilon = system.layers[0].epsilon
    else:
        epsilon = system.substrate.epsilon
    return v,r,R,t,T,epsilon

def solve_single_crystal_equations_global(v):
    """Solve the single-crystal GTM problem at one wavenumber (module-global variant).

    The model settings (superstrateDielectricFunction, substrateDielectricFunction,
    crystalPermittivityFunction, superstrateDepth, substrateDepth, crystalDepth,
    mode, theta, phi, psi, angleOfIncidence) are read from module globals.
    The ``init_global`` pool initializer assigns them in each worker.

    Parameters
    ----------
    v : float
        Wavenumber in cm-1.

    Returns
    -------
    tuple
        ``(v, r, R, t, T, epsilon)``. r, R, t and T are the values returned by
        ``system.calculate_r_t``. epsilon is the permittivity of the first layer,
        or of the substrate when the system has no layers.
    """
    # The settings are only read here, so no 'global' declarations are needed.
    # Create the 3 layers. Superstrate/substrate depths are scaled by 1e-6
    # (microns to metres), the crystal depth by 1e-9 (nanometres to metres).
    superstrate = GTM.Layer(thickness=superstrateDepth*1e-6, epsilon1=superstrateDielectricFunction)
    substrate   = GTM.Layer(thickness=substrateDepth*1e-6,   epsilon1=substrateDielectricFunction)
    crystal     = GTM.Layer(thickness=crystalDepth*1e-9,     epsilon=crystalPermittivityFunction)
    # 'Thick slab' uses the crystal itself as the substrate. Every other mode
    # ('Coherent thin film' or anything else) puts the crystal as a single
    # layer between the superstrate and the substrate.
    if mode == 'Thick slab':
        system = GTM.System(substrate=crystal, superstrate=superstrate, layers=[])
    else:
        system = GTM.System(substrate=substrate, superstrate=superstrate, layers=[crystal])
    # Rotate the dielectric constants to the laboratory frame
    system.substrate.set_euler(theta, phi, psi)
    system.superstrate.set_euler(theta, phi, psi)
    for layer in system.layers:
        layer.set_euler(theta, phi, psi)
    # Convert the wavenumber to a frequency: v*1e2 is in m-1, times the speed of light
    freq = v * speed_light_si * 1e2
    system.initialize_sys(freq)
    zeta_sys = np.sin(angleOfIncidence)*np.sqrt(system.superstrate.epsilon[0,0])
    # The return value is unused. The call is kept because calculate_r_t may
    # depend on state it sets (NOTE(review): confirm in GTMcore).
    system.calculate_GammaStar(freq, zeta_sys)
    r, R, t, T = system.calculate_r_t(zeta_sys)
    if len(system.layers) > 0:
        epsilon = system.layers[0].epsilon
    else:
        epsilon = system.substrate.epsilon
    return v,r,R,t,T,epsilon

def solve_single_crystal_equations_partial( 
    superstrateDielectricFunction,
    substrateDielectricFunction,
    crystalPermittivityFunction,
    superstrateDepth,
    substrateDepth,
    crystalDepth,
    mode,
    theta,
    phi,
    psi,
    angleOfIncidence,
    v):
    """Solve the single-crystal GTM problem at one wavenumber (explicit-argument variant).

    All model settings are passed as arguments. ``setup_partial_func`` binds
    everything except ``v`` with ``functools.partial``.

    Parameters
    ----------
    superstrateDielectricFunction, substrateDielectricFunction : callable
        Passed to ``GTM.Layer`` as ``epsilon1``.
    crystalPermittivityFunction : callable
        Passed to ``GTM.Layer`` as ``epsilon``.
    superstrateDepth, substrateDepth : float
        Thicknesses, scaled by 1e-6 (microns to metres).
    crystalDepth : float
        Crystal thickness, scaled by 1e-9 (nanometres to metres).
    mode : str
        'Thick slab' makes the crystal the substrate. Any other value puts the
        crystal as a layer between superstrate and substrate.
    theta, phi, psi : float
        Euler angles passed to ``set_euler``.
    angleOfIncidence : float
        Angle of incidence in radians.
    v : float
        Wavenumber in cm-1.

    Returns
    -------
    tuple
        ``(v, r, R, t, T, epsilon)``. r, R, t and T are the values returned by
        ``system.calculate_r_t``. epsilon is the permittivity of the first layer,
        or of the substrate when the system has no layers.
    """
    # Create the 3 layers (thicknesses converted to metres as documented above)
    superstrate = GTM.Layer(thickness=superstrateDepth*1e-6, epsilon1=superstrateDielectricFunction)
    substrate   = GTM.Layer(thickness=substrateDepth*1e-6,   epsilon1=substrateDielectricFunction)
    crystal     = GTM.Layer(thickness=crystalDepth*1e-9,     epsilon=crystalPermittivityFunction)
    # 'Coherent thin film' and any other mode share the same layered system
    if mode == 'Thick slab':
        system = GTM.System(substrate=crystal, superstrate=superstrate, layers=[])
    else:
        system = GTM.System(substrate=substrate, superstrate=superstrate, layers=[crystal])
    # Rotate the dielectric constants to the laboratory frame
    system.substrate.set_euler(theta, phi, psi)
    system.superstrate.set_euler(theta, phi, psi)
    for layer in system.layers:
        layer.set_euler(theta, phi, psi)
    # Convert the wavenumber to a frequency: v*1e2 is in m-1, times the speed of light
    freq = v * speed_light_si * 1e2
    system.initialize_sys(freq)
    zeta_sys = np.sin(angleOfIncidence)*np.sqrt(system.superstrate.epsilon[0,0])
    # The return value is unused. The call is kept because calculate_r_t may
    # depend on state it sets (NOTE(review): confirm in GTMcore).
    system.calculate_GammaStar(freq, zeta_sys)
    r, R, t, T = system.calculate_r_t(zeta_sys)
    if len(system.layers) > 0:
        epsilon = system.layers[0].epsilon
    else:
        epsilon = system.substrate.epsilon
    return v,r,R,t,T,epsilon



def init_global():
    """Pool initializer: publish the benchmark's model settings as module globals.

    Assigns the eleven names read by solve_single_crystal_equations_global:
    constant dielectric functions for the superstrate and substrate (2.0),
    a constant tensor for the crystal (2.0 + 0.01j), depths of 80, Euler
    angles (20, 20, 70), an 85 degree angle of incidence (stored in radians)
    and the mode 'Thin Film'.
    """
    global superstrateDielectricFunction, substrateDielectricFunction
    global crystalPermittivityFunction
    global superstrateDepth, substrateDepth, crystalDepth
    global mode, theta, phi, psi, angleOfIncidence
    eps_superstrate = 2.0
    eps_substrate = 2.0
    eps_crystal = 2.0 + 0.01j
    superstrateDielectricFunction = DielectricFunction(epsType='constant', units='hz', parameters=eps_superstrate).function()
    substrateDielectricFunction = DielectricFunction(epsType='constant', units='hz', parameters=eps_substrate).function()
    crystalPermittivityFunction = DielectricFunction(epsType='constant_tensor', units='hz', parameters=eps_crystal).function()
    superstrateDepth = substrateDepth = crystalDepth = 80
    theta, phi, psi = 20, 20, 70
    incidence_degrees = 85
    angleOfIncidence = np.pi / 180.0 * incidence_degrees
    mode = 'Thin Film'

def setup_partial_func():
    """Return solve_single_crystal_equations_partial with every model setting pre-bound.

    The returned ``functools.partial`` object only needs the wavenumber ``v``.
    The bound settings are constant dielectric functions for the superstrate
    and substrate (2.0), a constant tensor for the crystal (2.0 + 0.01j),
    depths of 80, mode 'Thin Film', Euler angles (20, 20, 70) and an 85 degree
    angle of incidence converted to radians.
    """
    superstrate_function = DielectricFunction(epsType='constant', units='hz', parameters=2.0).function()
    substrate_function = DielectricFunction(epsType='constant', units='hz', parameters=2.0).function()
    crystal_function = DielectricFunction(epsType='constant_tensor', units='hz', parameters=2.0 + 0.01j).function()
    incidence_radians = np.pi / 180.0 * 85
    bound_arguments = (
        superstrate_function,
        substrate_function,
        crystal_function,
        80,             # superstrateDepth
        80,             # substrateDepth
        80,             # crystalDepth
        'Thin Film',    # mode
        20, 20, 70,     # theta, phi, psi
        incidence_radians,
    )
    return partial(solve_single_crystal_equations_partial, *bound_arguments)

def setup_local_args():
    """Return the ``initargs`` tuple for the ``init_local`` pool initializer.

    The first element is solve_single_crystal_equations_local (the function
    that will carry the settings as attributes). It is followed by constant
    dielectric functions for the superstrate and substrate (2.0), a constant
    tensor for the crystal (2.0 + 0.01j), depths of 80, mode 'Thin Film',
    Euler angles (20, 20, 70) and an 85 degree angle of incidence in radians.
    """
    superstrate_function = DielectricFunction(epsType='constant', units='hz', parameters=2.0).function()
    substrate_function = DielectricFunction(epsType='constant', units='hz', parameters=2.0).function()
    crystal_function = DielectricFunction(epsType='constant_tensor', units='hz', parameters=2.0 + 0.01j).function()
    incidence_radians = np.pi / 180.0 * 85
    return (
        solve_single_crystal_equations_local,
        superstrate_function,
        substrate_function,
        crystal_function,
        80,             # superstrateDepth
        80,             # substrateDepth
        80,             # crystalDepth
        'Thin Film',    # mode
        20, 20, 70,     # theta, phi, psi
        incidence_radians,
    )



def start_benchmark():
    """Time ``pool.map`` over the configured worker counts and chunk sizes.

    Reads the module-level settings ``variable_type``, ``parallel_type``,
    ``threading``, ``cpu_range`` and ``chunk_range`` (assigned by main()).
    For each worker count a pool is created with get_pool(). For each chunk
    size the solver is mapped over the wavenumbers 1..1999 and the elapsed
    time in seconds is printed. Each pool is always closed and joined, even if
    a map call raises.

    Raises
    ------
    ValueError
        If ``variable_type`` is not 'partial', 'global' or 'local'.
    """
    print("""Benchmark is parallel processing""")
    if THREADS_LIMIT_ENV in os.environ:
        print("Maximum number of threads used for computation is : %s" % os.environ[THREADS_LIMIT_ENV])
    np.seterr(all='ignore')
    print(("-" * 80))
    print("Starting timing with numpy %s\nVersion: %s" % (np.__version__, sys.version))
    # Pick the worker function and its pool initializer for the chosen
    # variable-passing strategy.
    if variable_type == 'partial':
        func = setup_partial_func()
        initializer, initargs = None, None
    elif variable_type == 'global':
        func = solve_single_crystal_equations_global
        initializer, initargs = init_global, None
    elif variable_type == 'local':
        func = solve_single_crystal_equations_local
        initializer, initargs = init_local, setup_local_args()
    else:
        raise ValueError('variable_type not recognised {}'.format(variable_type))
    vs = list(np.arange(1, 2000))
    for ncpus in cpu_range:
        pool = get_pool(ncpus, threading, initializer=initializer, initargs=initargs)
        try:
            for chunksize in chunk_range:
                # perf_counter is monotonic and high resolution, suitable for timing
                start = time.perf_counter()
                # Materialise the results so the whole map is inside the timing
                list(pool.map(func, vs, chunksize=chunksize))
                td = time.perf_counter() - start
                print('parallelisation = {}, variable_type = {}, ncpus = {}, threading = {}, chunksize = {}, time = {}'.format(parallel_type, variable_type, ncpus, threading, chunksize, td))
        finally:
            # Release the workers even when a task fails
            pool.close()
            pool.join()

def usage():
    """Print the command-line help for this benchmark to stdout.

    The option list matches what main() parses. In particular, -chunks takes
    three values (start end step), and both -chunks and -range include the
    end value.
    """
    print('test_parallel [-variable_type partial|global|local] [-parallel_type multiprocessing|multiprocess|pathos|ray] [-threading] [-chunks start end step] [-range start end] [-h]')
    print('        Tests different types of parallelisms')
    print('        All options may also be given with a double dash, e.g. --threading')
    print('        -variable_type choices are (default partial)')
    print('                  partial - variables are passed using the partial function in functools')
    print('                  global  - variables are passed using global variables set by a pool initializer')
    print('                  local   - variables are passed as attributes of the worker function, set by a pool initializer')
    print('        -parallel_type choices are (default multiprocess)')
    print('                  multiprocess    - the multiprocess package (the pathos project fork of multiprocessing)')
    print('                  multiprocessing - the standard python multiprocessing module')
    print('                  pathos          - the pathos ProcessPool / ThreadPool, built on multiprocess')
    print('                  ray             - the ray implementation of the multiprocessing Pool api')
    print('        -threading')
    print('                  Use a pool of threads instead of a pool of processes')
    print('        -chunks   start end step')
    print('                  Defines the range of chunk sizes to be tested; end is inclusive (default 10 20 30 40 50 60)')
    print('        -range    start end')
    print('                  Defines the range of cpus or threads to be tested; end is inclusive (default 1 to the number of physical cores)')
    print('        -h, --help')
    print('                  Print this message and exit')

def main():
    """Parse the command line, store the settings as module globals and run the benchmark.

    Recognised options (single or double dash):
        -variable_type partial|global|local
        -parallel_type multiprocessing|multiprocess|pathos|ray
        -threading
        -chunks start end step   (end inclusive)
        -range  start end        (end inclusive)
        -h / --help

    The parsed values are assigned to the module globals ``variable_type``,
    ``parallel_type``, ``threading``, ``cpu_range`` and ``chunk_range``, which
    start_benchmark() then reads. The process exits with status 0 after -h.
    It exits with status 1 on an unknown option, a missing or non-integer
    value, a zero chunk step, or an unrecognised variable/parallel type.
    """
    global variable_type
    global parallel_type
    global threading
    global cpu_range
    global chunk_range
    possible_variable_types = ['partial', 'global', 'local']
    possible_parallel_types = ['multiprocessing', 'multiprocess', 'pathos', 'ray']
    variable_type = 'partial'
    parallel_type = 'multiprocess'
    # psutil.cpu_count(logical=False) returns None when the number of physical
    # cores cannot be determined; fall back rather than failing on None + 1.
    cpu_count = psutil.cpu_count(logical=False) or os.cpu_count() or 1
    cpu_range = range(1, cpu_count + 1)
    chunk_range = [10, 20, 30, 40, 50, 60]
    threading = False
    tokens = sys.argv[1:]
    position = 0

    def fail(*message):
        """Print an error and the usage text, then exit with status 1."""
        print(*message)
        usage()
        sys.exit(1)

    def next_value(option, convert=str):
        """Consume and return the value following ``option``, converted by ``convert``."""
        nonlocal position
        if position >= len(tokens):
            fail('Missing value for option', option)
        raw = tokens[position]
        position += 1
        try:
            return convert(raw)
        except ValueError:
            fail('Invalid value for option', option, raw)

    while position < len(tokens):
        token = tokens[position]
        position += 1
        if token in ('-variable_type', '--variable_type'):
            variable_type = next_value(token)
        elif token in ('-parallel_type', '--parallel_type'):
            parallel_type = next_value(token)
        elif token in ('-threading', '--threading'):
            threading = True
        elif token in ('-chunks', '--chunks'):
            start = next_value(token, int)
            end = next_value(token, int) + 1   # make the end value inclusive
            step = next_value(token, int)
            if step == 0:
                fail('The chunk step must not be zero')
            chunk_range = range(start, end, step)
        elif token in ('-range', '--range'):
            start = next_value(token, int)
            end = next_value(token, int) + 1   # make the end value inclusive
            cpu_range = range(start, end)
        elif token in ('-h', '--help'):
            usage()
            sys.exit()
        else:
            fail('Unrecognised option', token)
    if variable_type not in possible_variable_types:
        fail('variable_type not recognised', variable_type)
    if parallel_type not in possible_parallel_types:
        fail('parallel_type not recognised', parallel_type)
    if parallel_type == 'ray':
        import ray
        ray.init()
    start_benchmark()

# Run the benchmark only when executed as a script, not when imported.
# The guard is also needed because get_multiprocessing_pool forces the
# 'spawn' start method, which re-imports this module in each worker process.
if __name__ == '__main__':
    main()
