#!/usr/bin/env python3
import argparse
import subprocess
import json
import sys
import os

from deap import creator
from rich.table import Table
from rich import print as rprint

from gevo import __version__
from gevo.evolve import evolution

class program(evolution):
    """Evaluate a single LLVM-IR file against the test cases of a profile.

    Builds on ``gevo.evolve.evolution``; this subclass only adds
    :meth:`evaluate`, which links the given LLVM-IR, times it on every test
    case and prints a report.
    """

    # Parameters
    # Name of the PTX file the individual under test is linked into.
    cudaPTX = 'gevo.ptx'

    def __init__(self, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Forward the evaluation settings to ``evolution``.

        ``llvm_src_filename`` is accepted for signature compatibility but is
        not forwarded to the base class. ``mutop`` is passed as an empty
        string and ``use_fitness_map`` as ``False``.
        """
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='',
            use_fitness_map=False )

    def evaluate(self, llvm_file):
        """Link ``llvm_file``, run every test case and print the results.

        Each test case is executed three times; the minimum fitness and the
        minimum error of the three runs are kept. Afterwards the source is
        passed through ``optimize_src()`` and every test case is run once more
        to check that the optimized variant still executes.

        Args:
            llvm_file: Path of the LLVM-IR text file to evaluate.

        Exits the process with status 1 when the file cannot be read, when
        there are no test cases, or when any run fails to execute.
        """
        try:
            with open(llvm_file, 'r') as f:
                llvmSrcEnc = f.read().encode()
        except IOError:
            print("File {} does not exist".format(llvm_file))
            sys.exit(1)

        # NOTE(review): self.mgpu, self.testcase, self.verifier, self.origin and
        # self.ofits are presumably set up by evolution.__init__ -- confirm.
        individual = creator.Individual(llvmSrcEnc, self.mgpu)

        # The per-test-case results are needed here, which is why the built-in
        # evaluation function is not used.
        # link
        individual.ptx(self.cudaPTX)

        with open('gevo.ll', 'w') as f:
            f.write(individual.srcEnc.decode())

        fits = []
        errs = []
        for tc in self.testcase:
            runs = [self.execCuprofRetrive(tc)]
            # Only repeat the measurement when the first run succeeded.
            if runs[0][0] is not None:
                runs.extend(self.execCuprofRetrive(tc) for _ in range(2))

            # Any failed run (including a repeat) is fatal; comparing None with
            # numbers in min() would otherwise raise TypeError.
            if any(run_fit is None for run_fit, _ in runs):
                print("Tested llvm program failed to execute")
                sys.exit(1)

            fitness = min(run_fit for run_fit, _ in runs)
            err = min(run_err for _, run_err in runs)

            # Remove the produced output files so the next test case starts clean.
            for res_file in self.verifier['output']:
                if os.path.exists(res_file):
                    os.remove(res_file)

            fits.append(fitness)
            errs.append(err)

        if not fits:
            # max()/average below are undefined without at least one result.
            print("No test case available to evaluate")
            sys.exit(1)

        max_err = max(errs)
        avg_fitness = sum(fits)/len(fits)
        # Report average fitness, worst error, and the original fitness as a
        # percentage of the tested program's average fitness.
        rprint("Average Fitness of tested llvm program: {}, {}, {}".format(avg_fitness, max_err,
            self.origin.fitness.values[0]/avg_fitness*100))
        rprint("Individual test cases:")
        for fit, err, ofit in zip(fits, errs, self.ofits):
            rprint(f"({fit:.2f}, {err:.2f}), Improvement: {ofit/fit*100:.2f}")

        rprint("Test LLVM Opt optimized src: ", end="")
        individual.optimize_src()
        individual.ptx(self.cudaPTX)
        with open('gevo.ll', 'w') as f:
            f.write(individual.srcEnc.decode())
        for tc in self.testcase:
            fitness, _ = self.execCuprofRetrive(tc)
            if fitness is None:
                rprint("Failed!")
                sys.exit(1)
            rprint("Pass!")

if __name__ == '__main__':
    # Command-line entry point: load the profile, print a summary of the
    # settings, then evaluate the given LLVM-IR file.
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-l', '--llvm_file', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Default ot execution time. Can be changed to power")
    parser.add_argument('--err_rate', type=str, default='0.01',
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    parser.add_argument('--version', action='version', version='gevo-' + __version__)
    args = parser.parse_args()

    try:
        # Context manager so the profile file is closed even if parsing fails.
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except (OSError, ValueError):
        # OSError: file missing/unreadable; ValueError: malformed JSON or
        # undecodable bytes (JSONDecodeError and UnicodeDecodeError subclass it).
        print(sys.exc_info())
        sys.exit(-1)

    alyz = program(
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    table = Table.grid()
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    # One line of space-separated arguments per test case.
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program: ", tc_args)
    table.add_row("Target kernels: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout: ", str(args.timeout))
    table.add_row("Fitness function: ", args.fitness_function)
    table.add_row("Tolerate Error Rate: ", str(args.err_rate))
    table.add_row("Target LLVM-IR file: ", args.llvm_file)
    rprint(table)

    try:
        alyz.evaluate(args.llvm_file)
    except KeyboardInterrupt:
        # The binary name comes from the profile; argparse defines no
        # `binary` option, so `args.binary` would raise AttributeError here.
        subprocess.run(['killall', profile['binary']])
