#!/usr/bin/env python3
import argparse
import subprocess
import json
import sys
import logging
import ast
import itertools

from rich.logging import RichHandler
from rich.table import Table
from rich import print as rprint

from gevo import __version__
from gevo.evolve import evolution
from gevo.irind import edits_as_key

# Send every record (level "NOTSET") through rich's handler, showing only the
# message text, and expose the logger named "main" to the rest of the file.
logging.basicConfig(
    level="NOTSET",
    format="%(message)s",
    handlers=[RichHandler()],
)
log = logging.getLogger("main")

class program(evolution):
    """Evaluate the edits stored in an edit file, built on top of ``evolution``.

    The constructor parses the edit file into ``self.edits``. The methods then
    evaluate all edits together (``evaluate_full_edits``), every edit on its
    own (``edittest``) and combinations of edits (``group_test``).
    """

    def __init__(self, editf, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Initialise the base class and load the edit file.

        Args:
            editf: Path of the edit file; its content is parsed with
                ``ast.literal_eval``.
            kernel, bin, profile, timeout, fitness, err_rate: Forwarded
                unchanged to ``evolution.__init__``.
            llvm_src_filename: Accepted for interface compatibility; it is not
                used by this constructor.

        Raises:
            SystemExit: With status 1 when the edit file cannot be found or
                its content is not a valid Python literal.
            Exception: When ``self.origin.update_from_edits()`` returns False
                after the edits were assigned to ``self.origin``.
        """
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='',
            use_fitness_map=False )

        try:
            with open(editf, 'r') as f:
                self.edits = ast.literal_eval(f.read())
        except FileNotFoundError:
            log.error(f"Edit File:{editf} cannot be found")
            sys.exit(1)
        except (ValueError, SyntaxError) as err:
            # literal_eval raises these for content that is not a literal;
            # report it the same way as a missing file instead of a traceback.
            log.error(f"Edit File:{editf} cannot be parsed: {err}")
            sys.exit(1)

        self.origin.edits = self.edits
        if self.origin.update_from_edits() is False:
            raise Exception("Edit file cannot be compiled")

    def _min_of_runs(self, ind, runs=3):
        """Evaluate ``ind`` ``runs`` times and keep the smallest results.

        Fitness and error are taken from the same set of runs.

        Returns:
            ``(fit, err)`` where ``fit`` is the minimum of the first element
            and ``err`` the minimum of the second element of the tuples
            returned by ``self.evaluate``; ``None`` when any run reported
            ``None`` as its first element.
        """
        results = [self.evaluate(ind) for _ in range(runs)]
        fits = [result[0] for result in results]
        if None in fits:
            return None
        return min(fits), min(result[1] for result in results)

    def evaluate_full_edits(self):
        """Evaluate an individual carrying every edit and store its fitness.

        The individual is kept in ``self.fullEditsInd`` and its fitness values
        are set to the ``(fit, err)`` minimum over three evaluations.

        Raises:
            Exception: When the individual cannot be updated from the edits,
                or when one of the evaluations reports no fitness.
        """
        print("Evaluate edit file", end="", flush=True)
        self.fullEditsInd = self.toolbox.individual()
        self.fullEditsInd.edits = self.edits
        if self.fullEditsInd.update_from_edits() is False:
            raise Exception("Edit file cannot be compiled")
        best = self._min_of_runs(self.fullEditsInd)
        if best is None:
            # min() over a None fitness would otherwise fail with an opaque
            # TypeError; fail with an explicit message instead.
            raise Exception("Edit file cannot be executed")
        self.fullEditsInd.fitness.values = best
        log.info(f"Fitness of the program with all edits: {self.fullEditsInd.fitness}")

    def edittest(self):
        """Evaluate every edit on its own and print one result line per edit.

        A population with one individual per edit is created and stored in
        ``self.pop``. For each edit the line shows the minimum fitness over
        three evaluations, the improvement (first fitness value of
        ``self.origin`` divided by that fitness) and the minimum error, or a
        note that the edit could not be compiled or executed.
        """
        self.pop = self.toolbox.population(n=len(self.edits))
        for ind, edit in zip(self.pop, self.edits):
            ind.edits = [edit]
            print(f"{edit}", end='', flush=True)
            if ind.update_from_edits() is False:
                print(": cannot compile")
                continue
            best = self._min_of_runs(ind)
            if best is None:
                print(": execution failed")
                continue
            fit, err = best
            improvement = self.origin.fitness.values[0]/fit
            print(f": {fit:.2f}. Improvement: {improvement:.2f}. Error:{err:.2f}")

    def group_test(self):
        """Evaluate combinations of edits and print the ones that pay off.

        First prints a numbered list of all edits. Then every combination of
        2 up to ``len(self.edits) - 1`` edits is built (the full set is not
        part of this range), evaluated three times, and printed with the
        indices of its edits when its improvement over ``self.origin`` is at
        least 1.01. Combinations that cannot be updated from their edits or
        report no fitness are skipped silently.
        """
        grid = Table.grid()
        grid.add_column(justify="right", style="bold blue")
        grid.add_column()
        for cnt, edit in enumerate(self.edits):
            grid.add_row(str(cnt)+': ', str(edit))
        rprint(grid)

        for size in range(2, len(self.edits)):
            # Combine (index, edit) pairs so the printed indices always match
            # the numbered list above, even if two edits are identical.
            for combo in itertools.combinations(enumerate(self.edits), size):
                subEdits = tuple(edit for _, edit in combo)
                subEditsInd = self.toolbox.individual()
                subEditsInd.edits = subEdits
                if subEditsInd.update_from_edits() is False:
                    continue
                best = self._min_of_runs(subEditsInd)
                if best is None:
                    continue
                fit, err = best
                subEditsInd.fitness.values = (fit, err)
                improvement = self.origin.fitness.values[0] / fit
                if improvement < 1.01:
                    continue

                print('[ ', end='')
                for idx, _ in combo:
                    print(f"{idx} ", end='')
                print(f']: ({fit:.2f}, {err:.2f}), Imp:{improvement:.2f}')

if __name__ == '__main__':
    # Command-line entry point: parse the options, load the profile, print a
    # summary of the configuration and run the three analyses in order.
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-e', '--edit', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Default ot execution time. Can be changed to power")
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    parser.add_argument('--version', action='version', version='gevo-' + __version__)
    args = parser.parse_args()

    try:
        # Context manager so the profile file is closed after parsing.
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except FileNotFoundError:
        log.error(f"The profile:'{args.profile_file}' cannot be found")
        # Without a profile nothing below can run; stop here instead of
        # failing later with a NameError on `profile`.
        sys.exit(1)
    except (OSError, ValueError):
        # Other I/O problems, or content that is not valid JSON
        # (json.JSONDecodeError is a ValueError).
        print(sys.exc_info())
        sys.exit(-1)

    alyz = program(
        editf=args.edit,
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    # Two-column summary of the run configuration.
    table = Table.grid(expand=True)
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    # One line of space-separated arguments per test case.
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program:: ", tc_args)
    table.add_row("Target kernels:: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout:: ", str(args.timeout))
    table.add_row("Fitness function:: ", args.fitness_function)
    table.add_row("Edit file:: ", args.edit)
    table.add_row("Tolerate Error Rate:: ", str(args.err_rate))
    rprint(table)

    try:
        alyz.evaluate_full_edits()
        print("Evaluating each individual edit")
        alyz.edittest()

        print("Evaluating all edit combinations iteratively")
        alyz.group_test()
    except KeyboardInterrupt:
        # On Ctrl-C, run killall on the profiled binary's name.
        subprocess.run(['killall', profile['binary']])
