# Store the base directory in the workflow's config dict; every path in the
# rules below is built from it. The quoted value is a template placeholder
# that is substituted when this file is rendered.
config['base_dir'] = "{{ base_dir }}"


def promote_best(input, output, new_epochs):
    '''
    find the len(output) best configurations in the input, and copy
    them to the given output.

    input is a list of paths pointing to the result files, while output
    is a list of paths pointing to the new configuration files.

    each result file must contain a single number readable by float();
    lower scores are considered better, and results with equal scores
    keep the order they have in input. the configuration belonging to a
    result is the file named 'config' located in the same directory as
    the result file. the best configuration is copied to output[0], the
    second best to output[1], and so on.

    new_epochs is accepted because the callers pass it, but it is
    currently not used by this function.

    raises ValueError if there are fewer results than requested outputs;
    in that case nothing is copied.
    '''
    import os
    import shutil

    scores = []
    for fname in input:
        with open(fname) as f:
            scores.append((fname, float(f.read())))

    # fail before touching the filesystem instead of copying a few files
    # and then running off the end of the ranking
    if len(output) > len(scores):
        raise ValueError(
            'cannot promote %d configurations from only %d results'
            % (len(output), len(scores)))

    # list.sort is stable, so ties are resolved by input order
    scores.sort(key=lambda x: x[1])

    print('Promoted configuration and their scores')
    for (result_path, score), target in zip(scores, output):
        print('    %.4f - %s' % (score, result_path))
        source = os.path.join(os.path.dirname(result_path), 'config')
        shutil.copy(source, target)


# Target rule: requests the single overall-best configuration file in the
# base directory (produced by find_overall_best below).
rule all:
    input:
        expand("{d}/config", d=config['base_dir'])


# Check the results of the last stage of each bracket, and copy the
# configuration with the absolute best (lowest) score to the base directory.
rule find_overall_best:
    input:
        # one expand() per bracket, emitted by the template loop: the result
        # files of configurations 0 .. num_best - 1 in stage max_stage - 1
{% for bracket in brackets %}
        expand("{d}/bracket-{{ bracket.id }}/stage-{{ bracket.max_stage - 1 }}/config-{c}/result",
               d=config['base_dir'], c=range({{ bracket.num_best }})),
{% endfor %}
    output:
        expand("{d}/config", d=config['base_dir'])
    run:
        # promote_best copies len(output) configurations; new_epochs is not
        # used by promote_best, so None is passed
        promote_best(input, output, new_epochs=None)

{% for bracket in brackets %}
# -----------------------------------------------------------------------------------------------------------
# ---  Begin Bracket {{ bracket.id }}
# -----------------------------------------------------------------------------------------------------------

{% for stage in bracket.stages  %}
# run all configurations in stage {{ stage.id }}
#
# The output keeps c as a Snakemake wildcard, so this rule is instantiated
# once per configuration index. The quoted-brace template expressions around
# c render as doubled braces, which expand() then reduces to single braces
# instead of substituting them.
rule run_bracket_{{ bracket.id }}_stage_{{ stage.id }}:
    input:
         # every job of this stage depends on the config files of ALL
         # configurations in the stage, not only on its own
         expand("{d}/bracket-{{ bracket.id }}/stage-{{ stage.id }}/config-{c}/config",
                c=range({{ stage.configs }}), d=config['base_dir'])
    output:
         expand("{d}/bracket-{{ bracket.id }}/stage-{{ stage.id }}/config-{{'{{'}}c{{'}}'}}/result",
                d=config['base_dir'])
    shell:
         # expand() returns a list; [0] takes the single command string.
         # The command calls run.sh from the base directory with the
         # configuration directory and the stage budget as arguments
         # (presumably run.sh writes the result file there; verify against run.sh)
         expand("bash {d}/run.sh {d}/bracket-{{ bracket.id }}/stage-{{ stage.id }}/config-{{'{{'}}wildcards.c{{'}}'}} {b}",
                d=config['base_dir'], b={{ stage.budget }})[0]

{% if stage.promote_count > 0 %}
# promote best configurations to next stage
#
# Reads the result files of all configurations in this stage and copies the
# config files of the promote_count lowest-scoring ones to config-0,
# config-1, ... of the following stage, best first.
rule promote_bracket_{{ bracket.id }}_stage_{{ stage.id }}:
    input:
        expand("{d}/bracket-{{ bracket.id }}/stage-{{ stage.id }}/config-{c}/result",
                c=range({{ stage.configs }}), d=config['base_dir'])
    output:
        expand("{d}/bracket-{{ bracket.id }}/stage-{{ stage.id + 1 }}/config-{c}/config",
                c=range({{ stage.promote_count }}), d=config['base_dir'])
    run:
        # NOTE(review): promote_best accepts new_epochs but does not use it,
        # so the promoted config files are copied unchanged
        promote_best(input, output, new_epochs={{ stage.promote_epochs }})

{% endif %}
{% endfor %}
{% endfor %}
