import os
import os.path
import logging
import random
import subprocess
import shlex
import gzip
import re
import functools
import time
import imp
import sys
# Drop whatever module is currently registered under the name 'pickle' and
# load a fresh copy found via imp.find_module on the module search path.
# NOTE(review): presumably this undoes SCons aliasing 'pickle' to cPickle on
# Python 2 -- confirm. The `del` raises KeyError if 'pickle' has not been
# imported yet, so this relies on something having loaded it beforehand.
del sys.modules['pickle']

sys.modules['pickle'] = imp.load_module('pickle', *imp.find_module('pickle'))

# Binds the freshly loaded module registered in sys.modules just above.
import pickle
import steamroller.scons


# Build options; values are read from steamroller_config.py when provided.
vars = Variables("steamroller_config.py")

# Each plain entry is (name, help text, default); declaration order is kept
# so the options are registered in the same sequence as before.
_option_specs = [
    # Console output
    ("OUTPUT_WIDTH", "Upper limit on how long a debugging line will be before it's truncated", 1000),
    ("VERBOSE", "Whether to print the full commands being executed", False),
    # Experiment definition
    ("DEFAULTS", "General variables (potentially overridden by models and tasks)", {}),
    ("MODELS", "Classification models to compare", []),
    ("TASKS", "Classification tasks", []),
    ("TEST_COUNT", "Data size for testing models", 10000),
    # Grid submission
    BoolVariable("GRID", "Do we have access to a grid via the qsub command?", False),
    ("GRID_RESOURCES", "List of resources to request for a job", []),
    ("GRID_CHECK_INTERVAL", "How many seconds between checking on job status", 30),
]
vars.AddVariables(*_option_specs)

def print_cmd_line(s, target, source, env):
    """Echo a command line, eliding its middle when it exceeds OUTPUT_WIDTH.

    Installed as the environment's PRINT_CMD_LINE_FUNC hook.

    Args:
        s: The command line string to display.
        target: Target node(s) of the action (unused).
        source: Source node(s) of the action (unused).
        env: Mapping providing "OUTPUT_WIDTH", the maximum printed length.

    A line longer than the limit is printed as its leading
    ``OUTPUT_WIDTH - 10`` characters, then "...", then its last 7 characters,
    which adds up to exactly ``OUTPUT_WIDTH`` characters for limits >= 10.
    """
    width = int(env["OUTPUT_WIDTH"])
    if len(s) > width:
        # Clamp so a limit below 10 yields an empty head instead of a
        # negative slice bound (which would cut from the wrong end).
        head = max(width - 10, 0)
        # Parenthesised single-argument print works on Python 2 and 3 alike.
        print(s[:head] + "..." + s[-7:])
    else:
        print(s)


# Construction environment: inherits the caller's OS environment, produces
# gzip-compressed .tgz tarballs, and loads the steamroller builders.
env = Environment(
    variables=vars,
    ENV=os.environ,
    TARFLAGS="-c -z",
    TARSUFFIX=".tgz",
    tools=["default", steamroller.scons.generate],
)

# Route command echoing through the truncating printer, and decide rebuilds
# by comparing timestamps.
env["PRINT_CMD_LINE_FUNC"] = print_cmd_line
env.Decider("timestamp-newer")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Fallback settings consulted when a task does not specify its own.
defaults = env["DEFAULTS"]

# For every task: count the data, create a train/test split per
# (training size, fold), train and apply every model on each split, then
# evaluate, collate resource usage and plot the results.
for task in env["TASKS"]:
    classified_items = []
    train_resource_list = []
    apply_resource_list = []
    model_list = []
    task_name = task["name"]
    # NOTE(review): train_proportion is resolved here but never passed to any
    # builder in this loop -- confirm whether CreateSplit should receive it.
    train_proportion = task.get("train_proportion", defaults.get("train_proportion", .9))
    input_file = env.File(task["file"])
    count_file, _ = env.GetCount("work/${TASK_NAME}_total.txt.gz", input_file, TASK_NAME=task_name)

    # Task-level settings fall back to DEFAULTS; resolved once per task.
    sizes = task.get("sizes", defaults.get("sizes", []))
    fold_count = task.get("folds", defaults.get("folds", 1))

    for train_count in sizes:
        for fold in range(1, fold_count + 1):

            train, test, _ = env.CreateSplit(["work/${TASK_NAME}_train_${FOLD}_${TRAIN_COUNT}_${TEST_COUNT}.txt.gz",
                                              "work/${TASK_NAME}_test_${FOLD}_${TRAIN_COUNT}_${TEST_COUNT}.txt.gz"],
                                             count_file, FOLD=fold, TRAIN_COUNT=train_count, TASK_NAME=task_name)

            for model in env["MODELS"]:
                model_name = model["name"]
                train_builder = env["BUILDERS"]["Train%s" % model_name]
                apply_builder = env["BUILDERS"]["Apply%s" % model_name]
                # Same per-model override (falling back to the global
                # GRID_RESOURCES) for both steps; the apply step previously
                # fell back to an empty list instead.
                grid_resources = model.get("grid_resources", env["GRID_RESOURCES"])

                model_file, resources = train_builder(
                    env,
                    "work/${TASK_NAME}_${MODEL_NAME}_${TRAIN_COUNT}_${FOLD}.model.gz",
                    [train, input_file],
                    FOLD=fold, TRAIN_COUNT=train_count, TASK_NAME=task_name, MODEL_NAME=model_name,
                    GRID_RESOURCES=grid_resources,
                )
                train_resource_list.append(resources)
                model_list.append(model_file)

                # Keep the apply builder's second target: it is what belongs
                # in apply_resource_list. The training `resources` node was
                # appended there before, so the "apply" plots showed training
                # figures. Assumes the second target mirrors the train
                # builder's resource record -- TODO confirm.
                classified, apply_resource_node = apply_builder(
                    env,
                    "work/${TASK_NAME}_${MODEL_NAME}_${TRAIN_COUNT}_${FOLD}_probabilities.txt.gz",
                    [model_file, test, input_file],
                    FOLD=fold, TRAIN_COUNT=train_count, TASK_NAME=task_name, MODEL_NAME=model_name,
                    GRID_RESOURCES=grid_resources,
                )
                apply_resource_list.append(apply_resource_node)
                classified_items.append(classified)

    # Nothing was classified (no sizes or no models): skip the summaries.
    if classified_items:
        scores, _ = env.Evaluate("work/%s_scores.txt.gz" % (task_name), classified_items)
        train_resources, _ = env.CollateResources("work/%s_trainresources.txt.gz" % (task_name), train_resource_list)
        apply_resources, _ = env.CollateResources("work/%s_applyresources.txt.gz" % (task_name), apply_resource_list)
        model_sizes, _ = env.ModelSizes("work/%s_modelsizes.txt.gz" % (task_name), model_list)

        # (file stem, data node, field to plot, plot title)
        # NOTE(review): the "10k instances" titles are hard-coded while
        # TEST_COUNT is configurable -- confirm they should stay fixed.
        plot_specs = [
            ("trainmemory", train_resources, "Memory", "Max memory (G) training"),
            ("traincpu", train_resources, "CPU", "CPU Time (s) training"),
            ("applymemory", apply_resources, "Memory", "Max memory (G) applied to 10k instances"),
            ("applycpu", apply_resources, "CPU", "CPU time (s) applied to 10k instances"),
            ("modelsize", model_sizes, "Gigabytes", "Model size (G)"),
            ("fscore", scores, "F_Score", "F-Score"),
        ]
        for plot_stem, plot_source, plot_field, plot_title in plot_specs:
            env.Plot("work/%s_%s_plot.png" % (task_name, plot_stem), plot_source,
                     FIELD=plot_field, TITLE=plot_title)
