#!/usr/bin/env python3
# pylint: disable=logging-fstring-interpolation

import argparse
import subprocess
import json
import sys
import logging
import ast
import itertools
import pickle
from copy import deepcopy

from rich.logging import RichHandler
from rich.table import Table
from rich.progress import Progress
from rich import print as rprint

from gevo import __version__
from gevo.evolve import evolution
import gevo.irind as irind
from gevo.irind import edits_as_key

# Route every log record through rich's handler; only the bare message is
# formatted here, and the root level is left at NOTSET.
logging.basicConfig(
    format="%(message)s",
    level="NOTSET",
    handlers=[RichHandler()],
)
log = logging.getLogger("main")

class program(evolution):
    def __init__(self, editf, kernel, bin, profile, timeout=30, fitness='time',
                 editmap=None,
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        '''
        Load the edit file and measure the program with every edit applied.

        The measured individual is stored in ``self.fullEditsInd``; its
        fitness is the minimum of each component over three evaluation runs.
        The process exits with status 1 when ``editmap`` or ``editf`` does
        not exist.

        :param editf: path of the edit file; its text is parsed with
            ``ast.literal_eval`` and then encoded by
            ``irind.encode_edits_from_list``
        :param kernel: forwarded to ``evolution.__init__``
        :param bin: forwarded to ``evolution.__init__``
        :param profile: forwarded to ``evolution.__init__``
        :param timeout: forwarded to ``evolution.__init__``
        :param fitness: forwarded to ``evolution.__init__``
        :param editmap: optional path of a pickled edit-fitness map that is
            loaded into ``self.editFitMap``
        :param llvm_src_filename: accepted but not referenced by this
            constructor
        :param err_rate: forwarded to ``evolution.__init__``
        :raises Exception: when the full edit set cannot be compiled, or when
            any of the three runs reports ``None`` as its fitness
        '''
        super().__init__(kernel=kernel, bin=bin, profile=profile,
                         timeout=timeout, fitness=fitness, err_rate=err_rate,
                         mutop='', use_fitness_map=False)

        if editmap is not None:
            try:
                with open(editmap, 'rb') as map_file:
                    print('[resuming GEVO] Previous EditFitMap found ... ', end='')
                    # SECURITY: unpickling runs arbitrary code from the file;
                    # only pass edit maps that come from a trusted source.
                    self.editFitMap = pickle.load(map_file)
                    print(f'{len(self.editFitMap)} entries loaded')
            except FileNotFoundError:
                log.error(f"Editmapfile:{editmap} cannot be found")
                sys.exit(1)

        try:
            with open(editf, 'r') as edit_file:
                self.edits = irind.encode_edits_from_list(
                    ast.literal_eval(edit_file.read()))
        except FileNotFoundError:
            log.error(f"Edit File:{editf} cannot be found")
            sys.exit(1)

        rprint("Evaluate edit file ", end="", flush=True)
        try:
            self.fullEditsInd = self.toolbox.individual(edits=self.edits)
        except irind.llvmIRrepRuntimeError as compile_err:
            raise Exception("Edit file cannot be compiled") from compile_err

        # Three runs; a None fitness in any of them rejects the edit file.
        runs = [self.evaluate(self.fullEditsInd) for _ in range(3)]
        run_fitness = [run[0] for run in runs]
        if None in run_fitness:
            raise Exception("Edit file fails the verification")
        best_fitness = min(run_fitness)
        lowest_error = min(run[1] for run in runs)
        self.fullEditsInd.fitness.values = (best_fitness, lowest_error)
        rprint("")
        log.info(f"Fitness of the program with all edits: {self.fullEditsInd.fitness}")

    def remove_weak_edits(self, edits, threshold=0.01):
        '''
        Remove, one at a time, the edits that can be dropped without a
        noticeable fitness loss.

        The search starts from a copy of ``self.fullEditsInd`` (not from an
        individual built out of ``edits``). For each edit in ``edits`` a
        candidate without that edit is built and evaluated three times; the
        removal is kept when ``current_fitness / candidate_fitness`` is
        greater than ``1 - threshold``. Candidates that fail to compile or to
        execute leave the edit in place.

        The surviving edits are written to ``reduced.edit`` and the decoded
        ``srcEnc`` of the reduced individual to ``reduced.ll`` (paths relative
        to the working directory).

        :param edits: the edits to try removing, in the order they are tested
        :param threshold: tolerated relative fitness loss for a removal
        :returns: the edits of the reduced individual
        '''
        log.info("Start removing weak edits ...")
        total = len(edits)
        with Progress(auto_refresh=False) as pbar:
            task1 = pbar.add_task("", total=total)
            pbar.update(task1, completed=0, refresh=True,
                        description=f"(0/{total})")

            indPrior = deepcopy(self.fullEditsInd)
            for done, edit in enumerate(edits, start=1):
                try:
                    indPriorwoEdit = self.toolbox.individual(
                        edits=[e for e in indPrior.edits if e != edit])
                except irind.llvmIRrepRuntimeError:
                    indPriorwoEdit = None
                    log.info(f"{edit[0]} not removed: cannot compile")

                if indPriorwoEdit is not None:
                    fitness_values = [self.evaluate(indPriorwoEdit) for _ in range(3)]
                    fits = [value[0] for value in fitness_values]
                    if None in fits:
                        log.info(f"{edit[0]} not removed: execution failed")
                    else:
                        fit = min(fits)
                        err = max(value[1] for value in fitness_values)
                        indPriorwoEdit.fitness.values = (fit, err)
                        improvement = indPrior.fitness.values[0] / fit
                        if improvement > 1 - threshold:
                            log.info(f"{edit[0]} removed: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err}")
                            # Later removals are judged against the reduced program.
                            indPrior = indPriorwoEdit
                        else:
                            log.info(f"{edit[0]} not removed: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err}")

                # Advance the bar for every edit, including the skipped ones,
                # so that it always reaches 100%.
                pbar.update(task1, completed=done, refresh=True,
                            description=f"({done}/{total})")

        with open("reduced.edit", "w") as f:
            rprint(indPrior.edits, file=f)
        with open("reduced.ll", "w") as f:
            f.write(indPrior.srcEnc.decode())
        log.info("Done writing reduced.edit and reduced.ll")
        log.info(f"Fitness of the edit-reduced program: {indPrior.fitness}")
        return indPrior.edits
    
    def edittest(self, edits):
        '''
        Evaluate every edit on its own and print one result line per edit.

        Each edit is applied alone, evaluated three times, and reported with
        its minimum fitness, the improvement relative to
        ``self.origin.fitness.values[0]`` and its maximum error. Edits that do
        not compile or whose runs report ``None`` are reported as failures.

        :param edits: the edits to test individually
        '''
        for edit in edits:
            try:
                ind = self.toolbox.individual(edits=[edit])
            except irind.llvmIRrepRuntimeError:
                rprint(f"{edit}: cannot compile")
                continue

            fitness_values = [self.evaluate(ind) for _ in range(3)]
            fits = [value[0] for value in fitness_values]
            errs = [value[1] for value in fitness_values]
            # A None in either component means at least one run did not
            # produce a usable measurement.
            if None in fits or None in errs:
                rprint(f"{edit}: execution failed")
                continue

            fit = min(fits)
            err = max(errs)
            improvement = self.origin.fitness.values[0] / fit
            rprint(f"{edit}: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err:.2f}")

    def search_indepedent_edits(self, edits):
        '''
        Split ``edits`` into an independent group and an epistasis group.

        For each edit two individuals are measured (three runs each): one
        holding the independent edits found so far plus the edit, and one
        holding every edit of ``edits`` that is not in the first individual.
        The edit is accepted as independent when the fitness gained over
        ``self.origin`` and the fitness lost against ``self.fullEditsInd``
        differ by less than 1% of the origin fitness. Edits whose individuals
        fail to compile or to execute are not accepted.

        The epistasis edits are written to ``reduced_no_independent.edit`` and
        the independent ones to ``reduced_independent.edit``.

        :param edits: input edits that will be divided into the two groups
        :returns: independent edits and epistasis edits
        '''
        def measure(individual):
            # Run three times and store the minimum of each component as the
            # individual's fitness; report False when any run yields None.
            runs = [self.evaluate(individual) for _ in range(3)]
            if None in [run[0] for run in runs]:
                return False
            lowest_fit = min(run[0] for run in runs)
            lowest_err = min(run[1] for run in runs)
            individual.fitness.values = (lowest_fit, lowest_err)
            return True

        log.info("Start searching for indepedent/epistasis edits ...")
        accepted = []
        for edit in edits:
            try:
                with_edit = self.toolbox.individual(edits=accepted + [edit])
            except irind.llvmIRrepRuntimeError:
                continue
            if not measure(with_edit):
                continue
            gain_over_origin = (self.origin.fitness.values[0]
                                - with_edit.fitness.values[0])

            try:
                without_edit = self.toolbox.individual(
                    edits=[e for e in edits if e not in with_edit.edits])
            except irind.llvmIRrepRuntimeError:
                continue
            if not measure(without_edit):
                continue
            loss_against_full = (without_edit.fitness.values[0]
                                 - self.fullEditsInd.fitness.values[0])

            if abs(gain_over_origin - loss_against_full) < self.origin.fitness.values[0]*0.01:
                log.info(f"{edit} can be independently applied: {with_edit.fitness.values[0]:.2f}.")
                accepted.append(edit)

        remaining = [edit for edit in edits if edit not in accepted]
        with open("reduced_no_independent.edit", 'w') as out:
            rprint(remaining, file=out)
        with open("reduced_independent.edit", 'w') as out:
            rprint(accepted, file=out)
        return accepted, remaining

    def evolve_path_searching(self, edits):
        '''
        Greedily search for a growing combination of edits that improves the
        fitness the most.

        First every pair of edits is measured and the best improving pair is
        kept. Then, for each size from 3 up to ``len(edits) - 1``, every
        remaining edit is appended to the current best combination and the
        extension that beats the best improvement seen so far replaces it.
        Combinations that do not compile, do not execute, or improve by less
        than 1% over ``self.origin`` are ignored. Results are printed and
        written to ``path_find.txt``.

        The search stops after the pair stage when no pair reaches the 1%
        threshold, since there is nothing to extend.

        :param edits: the edits to combine
        '''
        editIdxMap = {edits_as_key([edit]): cnt for cnt, edit in enumerate(edits)}

        def index_str(sub_edits):
            # Space-separated positions of sub_edits within ``edits``.
            return ' '.join(str(editIdxMap[edits_as_key([e])]) for e in sub_edits)

        def measure(sub_edits):
            # Best-of-three measurement; returns (fit, err, improvement), or
            # None when the combination cannot be compiled or executed.
            try:
                ind = self.toolbox.individual(edits=sub_edits)
            except irind.llvmIRrepRuntimeError:
                return None
            runs = [self.evaluate(ind) for _ in range(3)]
            if None in [run[0] for run in runs]:
                return None
            fit = min(run[0] for run in runs)
            err = min(run[1] for run in runs)
            ind.fitness.values = (fit, err)
            return fit, err, self.origin.fitness.values[0] / fit

        grid = Table.grid()
        grid.add_column(justify="right", style="bold blue")
        grid.add_column()
        for cnt, edit in enumerate(edits):
            grid.add_row(str(cnt)+': ', str(edit))
        rprint(grid)

        # The context manager closes the report even if a measurement raises.
        with open("path_find.txt", 'w') as fcomb:
            rprint(grid, file=fcomb)

            best_comb = None
            best_imp = 0

            # Materialise the pairs once instead of re-enumerating them for
            # every progress update.
            pairs = list(itertools.combinations(edits, 2))
            with Progress(auto_refresh=False) as pbar:
                task1 = pbar.add_task("", total=len(pairs))
                for cnt, subEdits in enumerate(pairs):
                    fcomb.flush()
                    pbar.update(task1, completed=cnt, refresh=True,
                                description=f"{2} combination: ({cnt}/{len(pairs)})")

                    result = measure(subEdits)
                    if result is None:
                        continue
                    fit, err, improvement = result
                    if improvement < 1.01:
                        continue

                    edit_str = index_str(subEdits)
                    rprint(f'[ {edit_str}]: ({fit:.2f}, {err:.2f}), Imp:{improvement:.2f}')
                    fcomb.write(f'{edit_str},{fit:.2f},{err:.2f},{improvement:.2f}\n')

                    if improvement > best_imp:
                        best_imp = improvement
                        best_comb = subEdits
                pbar.update(task1, completed=len(pairs), refresh=True,
                            description=f"{2} combination: ({len(pairs)}/{len(pairs)})")
                rprint(f'Best comb in 2: {best_comb}: {best_imp:.2f}')
                fcomb.write(f'Best comb in 2: {best_comb}: {best_imp:.2f}\n')

            if best_comb is None:
                # Without a starting pair the membership test below would
                # fail on None.
                log.warning("No pair of edits improves the program; path search stopped")
                return

            # NOTE(review): when a round finds no better extension, best_comb
            # is unchanged and the next round repeats the same candidates.
            for size in range(3, len(edits)):
                rest_edits = [e for e in edits if e not in best_comb]
                fcomb.write(f"{size} combinations\n")
                cur_best_comb = list(best_comb)
                with Progress(auto_refresh=False) as pbar:
                    task1 = pbar.add_task("", total=len(rest_edits))
                    for cnt, edit in enumerate(rest_edits):
                        subEdits = list(best_comb)
                        subEdits.append(edit)
                        fcomb.flush()
                        best_edit_str = index_str(cur_best_comb)
                        pbar.update(task1, completed=cnt, refresh=True,
                                    description=f"{size} combination(current best:[{best_edit_str}]): ({cnt}/{len(rest_edits)})")

                        result = measure(subEdits)
                        if result is None:
                            continue
                        fit, err, improvement = result
                        if improvement < 1.01:
                            continue

                        edit_str = index_str(subEdits)
                        rprint(f'[ {edit_str}]: ({fit:.2f}, {err:.2f}), Imp:{improvement:.2f}')
                        fcomb.write(f'{edit_str},{fit:.2f},{err:.2f},{improvement:.2f}\n')

                        if improvement > best_imp:
                            best_imp = improvement
                            cur_best_comb = subEdits

                    best_comb = cur_best_comb

                rprint(f'Best comb in {size}: {best_comb}: {best_imp:.2f}')
                # Report the actual combination size (was hard-coded to 2).
                fcomb.write(f'Best comb in {size}: {best_comb}: {best_imp:.2f}\n')
        
    
    def group_test(self, edits):
        '''
        Evaluate every combination of 2 up to ``len(edits) - 1`` edits.

        Each combination is measured three times and one CSV-like line
        ``indices,fitness,error,improvement`` is written to
        ``group_test.txt``. A combination that cannot be compiled is recorded
        as ``c,c,c`` and one that fails to execute as ``x,x,x``. Combinations
        improving on ``self.origin`` by at least 1% are also printed.

        :param edits: the edits to combine
        '''
        log.info("Start evaluating all edit combinations iteratively ...")
        editIdxMap = {edits_as_key([edit]): cnt for cnt, edit in enumerate(edits)}

        grid = Table.grid()
        grid.add_column(justify="right", style="bold blue")
        grid.add_column()
        for cnt, edit in enumerate(edits):
            grid.add_row(str(cnt)+': ', str(edit))
        rprint(grid)

        # The context manager closes the report even if a measurement raises.
        with open("group_test.txt", 'w') as fcomb:
            rprint(grid, file=fcomb)

            for size in range(2, len(edits)):
                fcomb.write(f"{size} combinations\n")
                # Count the combinations once per size; recounting on every
                # progress update would enumerate all of them again each time.
                total = sum(1 for _ in itertools.combinations(edits, size))
                with Progress(auto_refresh=False) as pbar:
                    task1 = pbar.add_task("", total=total)
                    for cnt, subEdits in enumerate(itertools.combinations(edits, size)):
                        fcomb.flush()
                        pbar.update(task1, completed=cnt, refresh=True,
                                    description=f"{size} combination: ({cnt}/{total})")

                        edit_str = ' '.join(str(editIdxMap[edits_as_key([edit])]) for edit in subEdits)
                        try:
                            subEditsInd = self.toolbox.individual(edits=subEdits)
                        except irind.llvmIRrepRuntimeError:
                            fcomb.write(f'{edit_str},c,c,c\n')
                            continue

                        fitness_values = [self.evaluate(subEditsInd) for _ in range(3)]
                        fits = [value[0] for value in fitness_values]
                        if None in fits:
                            fcomb.write(f'{edit_str},x,x,x\n')
                            continue

                        fit = min(fits)
                        err = min(value[1] for value in fitness_values)
                        subEditsInd.fitness.values = (fit, err)
                        improvement = self.origin.fitness.values[0] / fit

                        # Every measured combination is recorded; only the
                        # improving ones are echoed to the terminal.
                        if improvement >= 1.01:
                            rprint(f'[ {edit_str}]: ({fit:.2f}, {err:.2f}), Imp:{improvement:.2f}')
                        fcomb.write(f'{edit_str},{fit:.2f},{err:.2f},{improvement:.2f}\n')
                    pbar.update(task1, completed=total, refresh=True,
                                description=f"{size} combination: ({total}/{total})")

    def uniedit_depedence(self,edits):
        '''
        Find which other edits make a chosen edit compile and execute.

        The edits are listed with their indices and the user is asked for one
        index. When the chosen edit works on its own nothing more is done.
        Otherwise it is paired with every other edit in turn, and for each
        pair either the failure kind or the improvement over ``self.origin``
        is printed.

        Returns without testing anything when the entered index is not an
        integer or is out of range.

        :param edits: the edits to choose from and to pair with
        '''
        grid = Table.grid()
        grid.add_column(justify="right", style="bold blue")
        grid.add_column()
        for cnt, edit in enumerate(edits):
            grid.add_row(str(cnt)+': ', str(edit))
        rprint(grid)

        uniedit_str = input("uniedit index: ")
        try:
            uniedit = edits[int(uniedit_str)]
        except (ValueError, IndexError):
            # Without a target edit there is nothing to test.
            print (f'{uniedit_str} is not a valid input')
            return
        rprint(f'Target - {uniedit_str}: {uniedit}')

        # Catch only compilation failures here, so Ctrl-C and genuine
        # programming errors are not silently swallowed.
        try:
            testInd = self.toolbox.individual(edits=[uniedit])
        except irind.llvmIRrepRuntimeError:
            unitest_status = False
        else:
            fitness_values = [self.evaluate(testInd) for _ in range(3)]
            unitest_status = None not in [value[0] for value in fitness_values]

        # Only fail in applying uniedit alone need to do the dependency test
        if unitest_status:
            rprint(f'{uniedit_str}: {uniedit} is independently passable!')
            return

        with Progress(auto_refresh=False) as pbar:
            task1 = pbar.add_task('', total=len(edits)-1)
            for cnt, edit in enumerate(edits):
                if edit == uniedit:
                    continue
                pbar.update(task1, completed=cnt, refresh=True,
                            description=f"{cnt} testing ...")
                try:
                    testInd = self.toolbox.individual(edits=[edit, uniedit])
                except irind.llvmIRrepRuntimeError:
                    rprint(f'{cnt}:{edit} - Complilation failed')
                    continue

                fitness_values = [self.evaluate(testInd) for _ in range(3)]
                fits = [value[0] for value in fitness_values]
                if None in fits:
                    rprint(f'{cnt}:{edit} - Verification failed')
                    continue

                improvement = self.origin.fitness.values[0] / min(fits)
                rprint(f'{cnt}:{edit} - {improvement}')
                
            

# Interactive entry point: load the profile, build the analyzer, print a
# summary of the run configuration and then serve the analysis menu until the
# user interrupts it.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Analyze the performance of mutation edits for CUDA kernel")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-e', '--edit', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Default ot execution time. Can be changed to power")
    parser.add_argument('--editmap', type=str,
        help="Loading the pre-generated edit-fit map to avoid previously tested edit set")
    parser.add_argument('--err_rate', type=str, default='0.01',
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    parser.add_argument('--version', action='version', version='gevo-' + __version__)
    args = parser.parse_args()

    # Every failure to load the profile ends the program here; continuing
    # would only fail later on the undefined ``profile``.
    try:
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except FileNotFoundError:
        log.error(f"The profile:'{args.profile_file}' cannot be found")
        sys.exit(1)
    except (OSError, ValueError) as load_err:
        # ValueError covers malformed JSON and undecodable file content.
        log.error(f"The profile:'{args.profile_file}' cannot be loaded: {load_err}")
        sys.exit(1)

    alyz = program(
        editf=args.edit,
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        editmap=args.editmap,
        err_rate=args.err_rate)

    table = Table.grid(expand=True)
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    # One line of arguments per test case, each terminated by a newline.
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program:: ", tc_args)
    table.add_row("Target kernels:: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout:: ", str(args.timeout))
    table.add_row("Fitness function:: ", args.fitness_function)
    table.add_row("Edit file:: ", args.edit)
    table.add_row("Tolerate Error Rate:: ", str(args.err_rate))
    rprint(table)

    try:
        curEdits = alyz.edits
        while True:
            rprint("0) list current edits")
            rprint("1) test each edit individualy")
            rprint("2) remove weak edits")
            rprint("3) group independent/epistasis edits")
            rprint("4) test edit combinations exhaustively")
            rprint("5) test uniedit's dependency")
            rprint("6) search the evolving path from given edits")
            op = input("Chose the analysis operation you want to go over: ")

            if op == '0':
                rprint(curEdits)
            elif op == '1':
                alyz.edittest(curEdits)
            elif op == '2':
                curEdits = alyz.remove_weak_edits(curEdits)
            elif op == '3':
                groupOp = input("return (1)independent or (2)epistasis as current edits after grouping?")
                indEdits, epistasisEdits = alyz.search_indepedent_edits(curEdits)
                # input() returns a string, so compare against '1'; any other
                # answer selects the epistasis group.
                curEdits = indEdits if groupOp.strip() == '1' else epistasisEdits
            elif op == '4':
                alyz.group_test(curEdits)
            elif op == '5':
                alyz.uniedit_depedence(curEdits)
            elif op == '6':
                alyz.evolve_path_searching(curEdits)
            else:
                log.warning(f"Invalid selection: {op}")
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C or a closed stdin ends the session; kill any instance of the
        # target binary that is still running.
        subprocess.run(['killall', profile['binary']])
