#!/usr/bin/env python3
import argparse
import subprocess
import json
import sys
import logging
import ast

from rich.logging import RichHandler

from gevo.evolve import evolution

# Route every record (no level filtering) through Rich for readable console output.
logging.basicConfig(
    level="NOTSET",
    format="%(message)s",
    handlers=[RichHandler()],
)
log = logging.getLogger("main")

class program(evolution):
    """Analyze the fitness impact of each edit in an edit file, one edit at a time.

    The edit file is parsed with ``ast.literal_eval``. It presumably holds a
    sequence of edits in the format accepted by ``update_from_edits`` (TODO
    confirm against gevo). On construction, all edits are applied to
    ``self.origin`` together and the resulting fitness is printed.
    ``edittest`` then evaluates every edit in isolation.
    """

    def __init__(self, editf, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Set up the evolution base and load/apply the edit file.

        Args:
            editf: Path to the edit file (a Python literal).
            kernel, bin, profile, timeout, fitness, err_rate: Forwarded to
                ``evolution.__init__``.
            llvm_src_filename: Accepted for interface compatibility; not used
                by this class.

        Raises:
            SystemExit: If the edit file is missing or is not a valid literal.
            Exception: If the combined edits cannot be applied to the origin.
        """
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='')

        try:
            with open(editf, 'r') as f:
                self.edits = ast.literal_eval(f.read())
        except FileNotFoundError:
            log.error(f"Edit File:{editf} cannot be found")
            sys.exit(1)
        except (ValueError, SyntaxError) as exc:
            # Report a malformed edit file the same way as a missing one
            # instead of dumping a traceback.
            log.error(f"Edit File:{editf} cannot be parsed: {exc}")
            sys.exit(1)

        self.origin.edits = self.edits
        if self.origin.update_from_edits() is False:
            raise Exception("Edit file cannot be compiled")
        print("Fitness of the program with all edits: {}".format(self.evaluate(self.origin)))

    def edittest(self, trials=3):
        """Evaluate every edit on its own and print fitness, improvement and error.

        Each edit is applied alone to a fresh individual from the toolbox
        population. Edits that fail to compile, or whose execution yields no
        result, are reported and skipped.

        Args:
            trials: Number of evaluations averaged per edit (must be >= 1).

        Raises:
            ValueError: If ``trials`` is less than 1.
        """
        if trials < 1:
            raise ValueError("trials must be at least 1")
        self.pop = self.toolbox.population(n=len(self.edits))
        for ind, edit in zip(self.pop, self.edits):
            ind.edits = [edit]
            print("Evaluate edit: {}".format(edit), end='', flush=True)
            if ind.update_from_edits() is False:
                print(": cannot compile")
                continue
            # One evaluation per trial: fitness and error must come from the
            # same run, and each evaluation executes the program.
            results = [self.evaluate(ind) for _ in range(trials)]
            fits = [res[0] for res in results]
            errs = [res[1] for res in results]
            if None in fits or None in errs:
                print(": execution failed")
                continue
            fit = float(sum(fits)) / len(fits)
            err = float(sum(errs)) / len(errs)
            # Ratio of baseline fitness to this edit's fitness; guard a zero
            # measurement rather than crashing the whole analysis.
            improvement = self.origin.fitness.values[0] / fit if fit else float('inf')
            print(": {}. Improvement: {}. Error:{}".format(fit, improvement, err))

def main():
    """Parse CLI arguments, load the profile, and run the per-edit analysis."""
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-e', '--edit', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Defaults to execution time. Can be changed to power")
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    args = parser.parse_args()

    try:
        with open(args.profile_file) as f:
            profile = json.load(f)
    except FileNotFoundError:
        log.error(f"The profile:'{args.profile_file}' cannot be found")
        # Previously execution continued here and crashed on the undefined
        # `profile`; stop cleanly instead.
        sys.exit(1)
    except Exception:
        log.exception(f"Failed to load the profile:'{args.profile_file}'")
        sys.exit(-1)

    try:
        # Construction already evaluates the program, so Ctrl-C must be
        # handled here as well as during edittest().
        alyz = program(
            editf=args.edit,
            kernel=profile['kernels'],
            bin=profile['binary'],
            profile=profile,
            timeout=args.timeout,
            fitness=args.fitness_function,
            err_rate=args.err_rate)

        print("      Target CUDA program: {}".format(profile['binary']))
        print("Args for the CUDA program:")
        for tc in alyz.testcase:
            print("\t{}".format(" ".join(tc.args)))
        print("           Target kernels: {}".format(" ".join(profile['kernels'])))
        print("       Evaluation Timeout: {}".format(args.timeout))
        print("         Fitness function: {}".format(args.fitness_function))
        print("                Edit file: {}".format(args.edit))
        print("      Tolerate Error Rate: {}".format(args.err_rate))

        alyz.edittest()
    except KeyboardInterrupt:
        # An interrupted evaluation may leave the CUDA binary running.
        try:
            subprocess.run(['killall', profile['binary']])
        except FileNotFoundError:
            log.warning("'killall' is not available; the CUDA program may still be running")


if __name__ == '__main__':
    main()
