#!/usr/bin/env python3
import argparse
import subprocess
import json
import sys
import pickle
from itertools import cycle

from rich.progress import Progress
from rich import print
from deap import creator
from deap import base
from deap import tools

from gevo import __version__
from gevo import irind
from gevo.evolve import evolution

class program(evolution):
    """Explore the fitness of every single-edit mutation of an LLVM-IR kernel.

    On construction the unmodified program is evaluated once per testcase to
    obtain a baseline fitness. ``edittest`` then applies each edit from the
    mutation list on its own, evaluates the mutant and reports its fitness
    relative to the baseline.

    ``_testcase``, ``evaluate`` and ``editFitMap`` are not defined in this
    class; they are expected to come from ``evolution``.
    """
    # Parameters
    # NOTE(review): this file is opened when the class body executes and is
    # never closed here; presumably the base class writes to it -- confirm.
    log = open('debug_log', 'w')
    cudaPTX = 'gevo.ptx'

    def __init__(self, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Load the LLVM-IR source, the mutation list and the golden testcases.

        Args:
            kernel: Kernel name(s) of interest; stored as ``self.kernels`` and
                forwarded to every testcase.
            bin: Application binary; stored as ``self.appBinary`` and
                forwarded to every testcase.
            profile: Parsed profile mapping. The keys ``'verify'`` and
                ``'args'`` are read here; each entry of ``'args'`` has a
                ``'value'`` list and optionally a ``'bond'`` list whose first
                element is the index of the argument it is tied to.
            timeout: Stored as ``self.timeout``.
            fitness: Stored as ``self.fitness_function``.
            llvm_src_filename: Path of the LLVM-IR file to mutate.
            err_rate: Stored as ``self.err_rate``.

        Raises:
            SystemExit: If ``llvm_src_filename`` cannot be read.
            Exception: If ``llvm-mutate -L`` exits with a non-zero status.
        """
        self.kernels = kernel
        self.appBinary = bin
        self.timeout = timeout
        self.fitness_function = fitness
        self.err_rate = err_rate

        try:
            with open(llvm_src_filename, 'r') as f:
                self.initSrcEnc = f.read().encode()
        except IOError:
            print("File {} does not exist".format(llvm_src_filename))
            sys.exit(1)

        # get 1-degree mutation list
        try:
            subprocess.run(['llvm-mutate', '-L'],
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE,
                           input=self.initSrcEnc,
                           check=True)
        except subprocess.CalledProcessError as err:
            print(err.stderr, file=sys.stderr)
            # Chain the cause so the failing command stays in the traceback.
            raise Exception('llvm-mutate error in getting mutation list') from err
        # Presumably this file is produced by the `llvm-mutate -L` run above.
        # SECURITY(review): each line is passed to eval(), and /tmp is
        # world-writable, so whoever can write this file can run arbitrary
        # code here. Consider ast.literal_eval if every edit is a plain
        # literal -- confirm the file format before switching.
        with open('/tmp/llvm_mutate_list.txt', 'r') as lmf:
            self.edits = [eval(line) for line in lmf]

        self.verifier = profile['verify']

        # Negative weights: DEAP treats both fitness components as values to
        # be minimised.
        creator.create("FitnessMin", base.Fitness, weights=(-1.0, -1.0))
        creator.create("Individual", irind.llvmIRrep, fitness=creator.FitnessMin)
        self.toolbox = base.Toolbox()
        self.toolbox.register('individual', creator.Individual, srcEnc=self.initSrcEnc)
        self.toolbox.register('population', tools.initRepeat, list, self.toolbox.individual)

        # Set up testcase
        self.origin = creator.Individual(self.initSrcEnc)
        self.origin.ptx(self.cudaPTX)

        # Build one argument list per testcase. An argument without a 'bond'
        # multiplies the current lists by each of its values (cartesian
        # product, values varying fastest). A bonded argument does not add
        # combinations: it takes the value at the same position as the value
        # already chosen for the argument it is bonded to.
        arg_array = [[]]
        for arg in profile['args']:
            if arg.get('bond') is None:
                arg_array = [combo + [value]
                             for combo in arg_array
                             for value in arg['value']]
            else:
                bonded_arg = arg['bond'][0]
                bonded_values = profile['args'][bonded_arg]['value']
                for combo in arg_array:
                    bonded_idx = bonded_values.index(combo[bonded_arg])
                    combo.append(arg['value'][bonded_idx])

        # Stringify only after bonding, because the lookup above compares the
        # raw profile values.
        arg_array = [[str(e) for e in combo] for combo in arg_array]

        self.testcase = [
            self._testcase(self, idx, kernel, bin, profile['verify'])
            for idx, _ in enumerate(arg_array)
        ]
        print("evalute testcase as golden..", end='', flush=True)
        for i, (tc, arg) in enumerate(zip(self.testcase, arg_array)):
            tc.args = arg
            print("{}..".format(i+1), end='', flush=True)
            tc.evaluate()
        print("done", flush=True)

        # Baseline fitness: mean of the first fitness component and the
        # largest second component over all testcases.
        fits = [tc.fitness[0] for tc in self.testcase]
        errs = [tc.fitness[1] for tc in self.testcase]
        self.origin.fitness.values = (sum(fits)/len(fits), max(errs))
        self.editFitMap[None] = self.origin.fitness.values
        print("Fitness of the original program: {}".format(self.origin.fitness.values))

    def edittest(self):
        """Apply every edit in ``self.edits`` on its own and report its fitness.

        For each edit a fresh individual is built from the original source.
        Edits that cannot be compiled are reported and skipped without being
        counted; edits whose evaluation returns ``None`` for either fitness
        component are counted as failures; all others are counted as passes
        and printed with their fitness relative to the original program.
        When the loop finishes, ``self.editFitMap`` is pickled to
        ``mutationMap.pickle`` (presumably ``evaluate`` fills that map --
        confirm against the base class).
        """
        status = {'pass': 0, 'fail': 0}
        total = len(self.edits)
        origin_fit = self.origin.fitness.values[0]

        def describe(done):
            return f"P:{status['pass']}/F:{status['fail']} ({done}/{total})"

        with Progress(auto_refresh=False) as pbar:
            task1 = pbar.add_task("", total=total)
            for cnt, edit in enumerate(self.edits):
                pbar.update(task1, completed=cnt, refresh=True,
                            description=describe(cnt))
                ind = creator.Individual(self.initSrcEnc)
                ind.edits = [edit]
                # Compared against False explicitly: only an explicit False
                # return is treated as a compile failure.
                if ind.update_from_edits() == False:
                    print(f"{edit}: Cannot compile")
                    continue
                fit, err = self.evaluate(ind)
                if fit is None or err is None:
                    status['fail'] = status['fail'] + 1
                    print(f"{edit}: Failed")
                    continue
                # A mutant reporting a fitness of 0 would otherwise abort the
                # whole sweep with ZeroDivisionError.
                improvement = origin_fit/fit if fit else float('inf')
                print(f"{edit}: {fit:.3f}s ({improvement*100:.2f}%). Error: {err:.3f}")
                status['pass'] = status['pass'] + 1
            # Final refresh so the bar reaches 100% and shows the last tally.
            pbar.update(task1, completed=total, refresh=True,
                        description=describe(total))

        with open("mutationMap.pickle", 'wb') as emfile:
            pickle.dump(self.editFitMap, emfile)

# Command-line entry point: parse the options, load the JSON profile, build
# the analyzer (which evaluates the original program), print the run
# configuration and then evaluate every single-edit mutation.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Explore the fitness space composed by all the 1-degree mutation to the given program")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Default ot execution time. Can be changed to power")
    # The default is a real float so it matches the declared type without
    # relying on argparse converting a string default.
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    parser.add_argument('--version', action='version', version='gevo-' + __version__)
    args = parser.parse_args()

    # Only failures to read or parse the profile are handled here; a bare
    # except would also swallow KeyboardInterrupt and SystemExit.
    # ValueError covers json.JSONDecodeError and decoding errors.
    try:
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except (OSError, ValueError):
        print(sys.exc_info())
        sys.exit(-1)

    alyz = program(
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    print("      Target CUDA program: {}".format(profile['binary']))
    print("Args for the CUDA program:")
    for tc in alyz.testcase:
        print("\t{}".format(" ".join(tc.args)))
    print("           Target kernels: {}".format(" ".join(profile['kernels'])))
    print("       Evaluation Timeout: {}".format(args.timeout))
    print("         Fitness function: {}".format(args.fitness_function))
    print("      Tolerate Error Rate: {}".format(args.err_rate))

    try:
        alyz.edittest()
    except KeyboardInterrupt:
        # On Ctrl-C, kill any process named after the target binary and save
        # whatever results have been collected so far.
        subprocess.run(['killall', profile['binary']])
        with open("mutationMap.pickle", 'wb') as emfile:
            pickle.dump(alyz.editFitMap, emfile)
