# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/06_quick.ipynb (unless otherwise specified).

# Public names exported by ``from <module> import *``.
__all__ = ['qtuneSimple', 'default_stats', 'qtuneIterate']

# Cell
from .algorithms import eaSimpleWithExtraLog
from .utils import ConcurrentMap
from .crossover import cxDictUniform
from .mutation import mutDictRand
from deap import base, creator, tools
from functools import partial
import random
import numpy

# Default ``stats`` argument of qtuneSimple / qtuneIterate: column-wise
# (axis=0) mean, standard deviation, minimum and maximum of the individuals'
# fitness values, registered in that order.
default_stats = tools.Statistics(lambda ind: ind.fitness.values)
for _stat_name, _stat_fn in (("avg", numpy.mean),
                             ("std", numpy.std),
                             ("min", numpy.min),
                             ("max", numpy.max)):
    default_stats.register(_stat_name, _stat_fn, axis=0)
# Do not leak the loop variables into the module namespace.
del _stat_name, _stat_fn

def qtuneSimple(params,
                evaluate,
                n_pop=10,
                cxpb=0.6,
                mutpb=0.6,
                ngen=10,
                hof=2,
                elitism=True,
                stats=default_stats,
                crossover=partial(cxDictUniform, indpb=0.6),
                select=partial(tools.selTournament, tournsize=3),
                mutate=partial(mutDictRand, indpb=0.7),
                n_jobs=1,
                seed=None,
                verbose=__debug__):
    """Quick tune using `eaSimpleWithExtraLog`. Just provide the parameter list and the function to tune.
    The function given should accept keyword arguments from the parameter list.
    Check examples below for more information.

    Args:
        params: parameter objects; each must expose a ``name`` attribute and
            support ``next()`` to draw a value (used to build individuals).
        evaluate: fitness function registered as ``toolbox.evaluate``. Its
            result is minimised (the fitness weight is ``-1.0``).
        n_pop: population size.
        cxpb, mutpb, ngen, elitism, stats, verbose: forwarded to
            ``eaSimpleWithExtraLog``.
        hof: hall-of-fame size; ``0``, a negative value or ``None`` disables it.
        crossover, select, mutate: operators registered on the toolbox;
            ``mutate`` additionally receives ``params=params``.
        n_jobs: forwarded to ``ConcurrentMap``, whose map the toolbox uses.
        seed: if not None, passed to ``random.seed`` (seeds the global RNG).

    Returns:
        ``(population, logbook, hof)``, where ``hof`` is a
        ``tools.HallOfFame`` or ``None``.
    """
    if seed is not None:
        random.seed(seed)

    def initParams(cls):
        # One freshly drawn value per parameter, keyed by parameter name.
        return cls({i.name: next(i) for i in cls.params})

    hof = tools.HallOfFame(hof) if hof is not None and hof > 0 else None

    # ``creator.create`` registers classes globally on the creator module.
    # The ``finally`` removes them even when the run fails, so a failed call
    # does not leave stale class definitions behind.
    try:
        creator.create("eptLoss", base.Fitness, weights=(-1.0, ))
        creator.create("eptParameters",
                       dict,
                       params=params,
                       fitness=creator.eptLoss)
        toolbox = base.Toolbox()
        toolbox.register("individual", initParams, creator.eptParameters)
        toolbox.register("population", tools.initRepeat, list,
                         toolbox.individual)
        toolbox.register('evaluate', evaluate)
        toolbox.register("select", select)
        toolbox.register("mate", crossover)
        toolbox.register("mutate", mutate, params=params)
        with ConcurrentMap(n_jobs) as pmap:
            toolbox.register('map', pmap)
            population, logbook = eaSimpleWithExtraLog(
                toolbox.population(n_pop),
                toolbox,
                cxpb=cxpb,
                mutpb=mutpb,
                ngen=ngen,
                halloffame=hof,
                elitism=elitism,
                stats=stats,
                verbose=verbose)
    finally:
        # hasattr: creation may have failed part-way through.
        for cls_name in ("eptLoss", "eptParameters"):
            if hasattr(creator, cls_name):
                delattr(creator, cls_name)
    return population, logbook, hof

# Cell
import queue
import threading


class qtuneIterate:
    """Ask/tell-style tuner: iterate to receive parameter sets, report losses back.

    Calling the instance starts ``eaSimpleWithExtraLog`` in a background
    thread and returns the instance itself, which is an iterator. Each fitness
    evaluation is handed to the consumer as a ``(params, cond)`` tuple. The
    consumer computes the loss and passes it back with
    ``set_result(result, cond)``, which unblocks the evolution thread.
    Iteration stops once the run finishes and ``self.population`` is set.

    NOTE(review): ``result`` becomes the individual's fitness values. It is
    presumably a 1-tuple because ``eptLoss`` has a single weight; confirm
    against ``eaSimpleWithExtraLog``.
    """

    def __init__(self,
                 params,
                 n_pop=10,
                 cxpb=0.6,
                 mutpb=0.6,
                 ngen=10,
                 hof=2,
                 elitism=True,
                 stats=default_stats,
                 crossover=partial(cxDictUniform, indpb=0.6),
                 select=partial(tools.selTournament, tournsize=3),
                 mutate=partial(mutDictRand, indpb=0.7),
                 seed=None,
                 verbose=__debug__):
        """Prepare the toolbox; the run itself starts when the instance is called.

        Arguments mirror ``qtuneSimple`` except that there is no ``evaluate``
        function (evaluation is delegated to the iterating consumer) and no
        ``n_jobs`` (a sequential ``map`` is always used). ``hof`` of ``0``, a
        negative value or ``None`` disables the hall of fame.
        """
        # Pending evaluation requests: (params, cond) tuples.
        self.parameter = queue.Queue()
        # None while a run is in progress (or before the first run).
        self.population = None
        self.logbook = None
        # Single result slot shared with set_result. It is safe only because
        # one evaluation is pending at a time (sequential map below).
        self._result = None

        if seed is not None:
            random.seed(seed)

        def initParams(cls):
            # One freshly drawn value per parameter, keyed by parameter name.
            return cls({i.name: next(i) for i in cls.params})

        def evaluate(params):
            # Publish the request while holding the condition's lock.
            # set_result must take the same lock, so it cannot notify before
            # wait() has atomically released it. Publishing before locking
            # could lose the notification and deadlock this thread.
            cond = threading.Condition()
            with cond:
                self.parameter.put((params, cond))
                cond.wait()
            return self._result

        if hof is not None and hof > 0:
            self.hof = tools.HallOfFame(hof)
        else:
            self.hof = None

        creator.create("eptLoss", base.Fitness, weights=(-1.0, ))
        creator.create("eptParameters",
                       dict,
                       params=params,
                       fitness=creator.eptLoss)
        self.toolbox = base.Toolbox()
        self.toolbox.register("individual", initParams, creator.eptParameters)
        self.toolbox.register("population", tools.initRepeat, list,
                              self.toolbox.individual)
        self.toolbox.register('evaluate', evaluate)
        self.toolbox.register("select", select)
        self.toolbox.register("mate", crossover)
        self.toolbox.register("mutate", mutate, params=params)
        # The shared _result slot supports only one in-flight evaluation, so
        # the sequential built-in map is forced.
        self.toolbox.register('map', map)

        def eaWrapper():
            self.population, self.logbook = eaSimpleWithExtraLog(
                self.toolbox.population(n_pop),
                self.toolbox,
                cxpb=cxpb,
                mutpb=mutpb,
                ngen=ngen,
                halloffame=self.hof,
                elitism=elitism,
                stats=stats,
                verbose=verbose)

        self._target = eaWrapper

    def set_result(self, result, cond):
        """Report ``result`` for the request that carried ``cond`` and wake the evolution thread."""
        with cond:
            self._result = result
            cond.notify_all()

    def get_ctx(self, timeout=None):
        """Return the next pending ``(params, cond)`` evaluation request.

        The queue is polled in slices of at most one second so that the end of
        the run (``self.population`` becoming non-None) is noticed.

        Args:
            timeout: maximum number of seconds to wait, or ``None`` to wait
                until a request arrives or the run finishes.

        Returns:
            The ``(params, cond)`` tuple, or ``None`` if ``timeout`` elapsed
            without a request.

        Raises:
            StopIteration: the run has finished.
        """
        remaining = timeout
        while self.population is None:
            if remaining is not None and remaining <= 0:
                return None  # timed out without a request
            wait = 1 if remaining is None else min(1, remaining)
            try:
                return self.parameter.get(timeout=wait)
            except queue.Empty:
                if remaining is not None:
                    remaining -= wait
        raise StopIteration('Done')

    def __next__(self):
        return self.get_ctx()

    def __iter__(self):
        return self

    def __call__(self, clear_hof=False):
        """Start a run in a background thread and return ``self`` for iteration.

        Args:
            clear_hof: if true, empty the existing hall of fame before the run.
        """
        self.population = None
        if clear_hof and self.hof is not None:
            self.hof.clear()
        threading.Thread(name='eaThread', target=self._target).start()
        return self