#!/usr/bin/env python3
import argparse
import subprocess
import json
import sys
import csv
import filecmp
import os
from io import StringIO
from itertools import cycle

from deap import creator
from deap import base
from deap import tools

from gevo import irind
from gevo.evolve import evolution

class program(evolution):
    """Compare one LLVM-IR file against the original CUDA kernel source.

    The constructor reads the original LLVM-IR, expands the argument sets
    described by the profile into test cases and evaluates each of them once
    to obtain the baseline ("golden") fitness.  ``evaluate`` then measures
    another LLVM-IR file on the same test cases and prints the comparison.

    ``_testcase`` and ``execNVprofRetrive`` are used through ``self`` but are
    not defined here; they are presumably inherited from ``evolution``.
    NOTE(review): ``evolution.__init__`` is never invoked by this class --
    confirm that none of the inherited helpers rely on state it would set.
    """

    # Parameters
    cudaPTX = 'a.ptx'  # path handed to ``ptx()`` for both the original and the tested IR

    def __init__(self, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Load the original kernel and measure its baseline fitness.

        Args:
            kernel: kernel names; stored as ``self.kernels`` and passed to
                every test case.
            bin: the application binary; stored as ``self.appBinary`` and
                passed to every test case.
            profile: parsed profile dictionary.  ``profile['verify']`` and
                ``profile['args']`` are read here.
            timeout: stored as ``self.timeout``.
            fitness: stored as ``self.fitness_function``.
            llvm_src_filename: path of the original LLVM-IR file.
            err_rate: stored as ``self.err_rate``.

        The process exits with status 1 when ``llvm_src_filename`` cannot be
        read.
        """
        self.kernels = kernel
        self.appBinary = bin
        self.timeout = timeout
        self.fitness_function = fitness
        self.err_rate = err_rate

        self.initSrcEnc = self._load_llvm_source(llvm_src_filename)

        self.verifier = profile['verify']

        # Both objectives carry a negative weight, i.e. DEAP minimises them.
        creator.create("FitnessMin", base.Fitness, weights=(-1.0, -1.0))
        creator.create("Individual", irind.llvmIRrep, fitness=creator.FitnessMin)
        self.toolbox = base.Toolbox()
        self.toolbox.register('individual', creator.Individual, srcEnc=self.initSrcEnc)
        self.toolbox.register('population', tools.initRepeat, list, self.toolbox.individual)

        # Set up testcase
        self.origin = creator.Individual(self.initSrcEnc)
        self.origin.ptx(self.cudaPTX)
        arg_array = self._expand_arg_combinations(profile['args'])

        self.testcase = [
            self._testcase(self, idx, kernel, bin, profile['verify'])
            for idx in range(len(arg_array))
        ]
        print("evalute testcase as golden..", end='', flush=True)
        for number, (tc, arg) in enumerate(zip(self.testcase, arg_array), start=1):
            tc.args = arg
            print("{}..".format(number), end='', flush=True)
            tc.evaluate()
        print("done", flush=True)

        # Baseline: average of the first fitness component, worst case of the second.
        self.ofits = [tc.fitness[0] for tc in self.testcase]
        self.oerrs = [tc.fitness[1] for tc in self.testcase]
        self.origin.fitness.values = (sum(self.ofits)/len(self.ofits), max(self.oerrs))
        print("Average Fitness of the original program: {}".format(self.origin.fitness.values[0]))
        print("Individual test cases:")
        for fit, err in zip(self.ofits, self.oerrs):
            print("{}, {}".format(fit, err))

    @staticmethod
    def _load_llvm_source(path):
        """Return the content of ``path`` as bytes; exit with status 1 on IOError."""
        try:
            with open(path, 'r') as f:
                return f.read().encode()
        except IOError:
            print("File {} does not exist".format(path))
            # sys.exit instead of the ``exit`` helper, which only exists when
            # the ``site`` module has been loaded.
            sys.exit(1)

    @staticmethod
    def _expand_arg_combinations(arg_specs):
        """Build the per-test-case argument lists from ``profile['args']``.

        An entry without a ``'bond'`` key multiplies the existing combinations
        by each of its ``'value'`` items (cartesian product, earlier arguments
        varying slowest).  An entry with ``'bond'`` adds no combinations: it
        takes the value at the same position as the value already chosen for
        the argument whose index is ``bond[0]``.

        Returns a list of lists of strings, one inner list per test case.
        """
        combos = [[]]
        for spec in arg_specs:
            if spec.get('bond', None) is None:
                combos = [combo + [value] for combo in combos for value in spec['value']]
            else:
                bonded_arg = spec['bond'][0]
                bonded_values = arg_specs[bonded_arg]['value']
                for combo in combos:
                    combo.append(spec['value'][bonded_values.index(combo[bonded_arg])])

        return [[str(value) for value in combo] for combo in combos]

    def evaluate(self, llvm_file):
        """Measure ``llvm_file`` on every test case and print the comparison.

        Args:
            llvm_file: path of the LLVM-IR file to test.  The process exits
                with status 1 when it cannot be read.

        Returns:
            ``(avg_fitness, max_err)`` on success, ``(None, None)`` when the
            PTX cannot be produced or a test case yields no measurement.
        """
        individual = creator.Individual(self._load_llvm_source(llvm_file))

        # link
        try:
            individual.ptx(self.cudaPTX)
        except Exception as exc:
            # ``Exception`` rather than a bare except, so Ctrl-C still reaches
            # the caller.  The previous handler referenced an undefined edit
            # key and raised NameError instead of reporting the failure.
            print("Failed to generate {} from {}: {}".format(self.cudaPTX, llvm_file, exc))
            return None, None

        with open('a.ll', 'w') as f:
            f.write(individual.srcEnc.decode())

        fits = []
        errs = []
        for tc in self.testcase:
            fitness, err = self.execNVprofRetrive(tc)

            # Remove the result files so the next run starts clean.
            for res_file in self.verifier['output']:
                if os.path.exists(res_file):
                    os.remove(res_file)

            fits.append(fitness)
            errs.append(err)

        # NOTE(review): assumes execNVprofRetrive signals a failed run with
        # None -- confirm.  Without this guard max()/sum() raise TypeError.
        if any(value is None for value in fits + errs):
            print("At least one test case failed for {}".format(llvm_file))
            return None, None

        max_err = max(errs)
        avg_fitness = sum(fits)/len(fits)
        # Third column: baseline fitness as a percentage of the tested fitness.
        print("Average Fitness of tested llvm program: {}, {}, {}".format(avg_fitness, max_err,
            self.origin.fitness.values[0]/avg_fitness*100))
        print("Individual test cases:")
        for fit, err, ofit in zip(fits, errs, self.ofits):
            print("{}, {}, {}".format(fit, err, ofit/fit*100))

        return avg_fitness, max_err

if __name__ == '__main__':
    # Command-line entry point: load the profile, measure the original
    # program, then measure the LLVM-IR file given with -l on the same
    # test cases.
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-l', '--llvm_file', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Default ot execution time. Can be changed to power")
    # Float default to match type=float.
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    args = parser.parse_args()

    try:
        # The context manager closes the profile file even when parsing fails.
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except (OSError, ValueError):
        # OSError: unreadable file; ValueError covers json.JSONDecodeError
        # and decoding errors.
        print(sys.exc_info())
        sys.exit(-1)

    print("      Target CUDA program: {}".format(profile['binary']))
    print("Args for the CUDA program:")
    print("           Target kernels: {}".format(" ".join(profile['kernels'])))
    print("       Evaluation Timeout: {}".format(args.timeout))
    print("         Fitness function: {}".format(args.fitness_function))
    print("      Tolerate Error Rate: {}".format(args.err_rate))
    print("      Target LLVM-IR file: {}".format(args.llvm_file))

    alyz = program(
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    # One line per test case: the argument list it runs with.
    for tc in alyz.testcase:
        print("\t{}".format(" ".join(tc.args)))

    try:
        alyz.evaluate(args.llvm_file)
    except KeyboardInterrupt:
        # The binary name lives in the profile; the parser defines no
        # ``binary`` option, so ``args.binary`` raised AttributeError here.
        subprocess.run(['killall', profile['binary']])
