#!/usr/bin/env python3
# pylint: disable=logging-fstring-interpolation
# Experimental / under development. Do not use unless you know what it does.


import argparse
import subprocess
import json
import sys
import logging
import ast
from copy import deepcopy

from rich.logging import RichHandler
from rich.progress import Progress
from rich.table import Table
from rich import print as rprint

from gevo import __version__
from gevo.evolve import evolution

# Route every record through rich's handler and print only the message text.
# Root level NOTSET means the root logger itself filters nothing out.
logging.basicConfig(
    handlers=[RichHandler()],
    level="NOTSET",
    format="%(message)s",
)
log = logging.getLogger(name="main")

class program(evolution):
    """Analyze a fixed list of program edits loaded from an edit file.

    Builds on the ``evolution`` base class, which supplies ``toolbox``,
    ``evaluate`` and (presumably) ``origin``, the unedited program's individual.
    The available analyses are:

    * evaluate_full_edits: fitness of the program with every edit applied;
    * remove_useless_edits: drop edits whose removal does not hurt fitness;
    * search_indepedent_edits: find edits whose effect does not depend on the rest;
    * group_epistasis: group edits that only matter together.

    Each analysis writes its result to files in the current working directory.
    """

    def __init__(self, editf, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Initialize the evolution base class and load the edit list.

        Args:
            editf: path to a file holding a Python literal (parsed with
                ``ast.literal_eval``) that describes the list of edits.
            kernel, bin, profile, timeout, fitness, err_rate: forwarded to
                ``evolution.__init__``.
            llvm_src_filename: accepted for interface compatibility.
                NOTE(review): this constructor neither uses nor forwards it.

        Exits the process with status 1 if the edit file is missing or does
        not contain a valid Python literal.
        """
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='',
            use_fitness_map=False )

        try:
            with open(editf, 'r') as f:
                self.edits = ast.literal_eval(f.read())
        except FileNotFoundError:
            log.error(f"Edit File:{editf} cannot be found")
            sys.exit(1)
        except (ValueError, SyntaxError):
            log.error(f"Edit File:{editf} is not a valid Python literal")
            sys.exit(1)

    def _best_of_runs(self, ind, repeat=3):
        """Evaluate ``ind`` ``repeat`` times and return the best (fitness, error).

        Each ``self.evaluate`` result is indexed as (fitness, error); the
        minimum of each component is taken independently. Returns None if any
        run reports a None fitness (i.e. execution failed).
        """
        runs = [self.evaluate(ind) for _ in range(repeat)]
        if any(run[0] is None for run in runs):
            return None
        fit = min(run[0] for run in runs)
        err = min(run[1] for run in runs)
        return fit, err

    def evaluate_full_edits(self):
        """Build an individual carrying every edit and record its best fitness.

        The individual is stored in ``self.fullEditsInd``.

        Raises:
            Exception: if the edited program cannot be compiled or fails to run.
        """
        print("Evaluate edit file", end="", flush=True)
        self.fullEditsInd = self.toolbox.individual()
        # NOTE: this assigns the same list object, not a copy, so in-place
        # removals on self.fullEditsInd.edits (see remove_useless_edits) also
        # shrink self.edits for any analysis that runs afterwards.
        self.fullEditsInd.edits = self.edits
        if self.fullEditsInd.update_from_edits() is False:
            raise Exception("Edit file cannot be compiled")
        result = self._best_of_runs(self.fullEditsInd)
        if result is None:
            raise Exception("Edit file execution failed")
        self.fullEditsInd.fitness.values = result
        log.info(f"Fitness of the program with all edits: {self.fullEditsInd.fitness}")

    def _is_useless_edit(self, ind, edit):
        """Evaluate ``ind`` holding every edit except ``edit``; report uselessness.

        Returns True when the ratio (full-edit fitness / fitness without
        ``edit``) exceeds 0.99. Returns False, so the edit is kept, when the
        reduced program fails to compile or run. Each outcome is logged.
        """
        ind.edits = self.edits.copy()
        ind.edits.remove(edit)
        if ind.update_from_edits() is False:
            log.info(f"{edit[0]} removed: cannot compile")
            return False
        result = self._best_of_runs(ind)
        if result is None:
            log.info(f"{edit[0]} removed: execution failed")
            return False
        fit, err = result
        improvement = self.fullEditsInd.fitness.values[0]/fit
        log.info(f"{edit[0]} removed: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err}")
        return improvement > 0.99

    def remove_useless_edits(self):
        """Drop edits whose individual removal does not make fitness worse.

        Requires evaluate_full_edits() to have run first. Each edit is tested
        on its own against the full edit set. Every edit judged useless is then
        removed from ``self.fullEditsInd.edits`` and the reduced program is
        re-evaluated. On success, the reduced edit list is written to
        ``reduced.edit`` and the reduced source to ``reduced.ll``. If the
        reduced program fails to compile or run, an error is logged and nothing
        is written.
        """
        total = len(self.edits)
        self.pop = self.toolbox.population(n=total)
        with Progress(auto_refresh=False) as pbar:
            task1 = pbar.add_task("", total=total)
            pbar.update(task1, completed=0, refresh=True,
                        description=f"(0/{total})")

            removal_list = []
            for cnt, (ind, edit) in enumerate(zip(self.pop, self.edits)):
                if self._is_useless_edit(ind, edit):
                    removal_list.append(edit)
                # Advance for every edit, including those whose removal failed.
                pbar.update(task1, completed=cnt+1, refresh=True,
                            description=f"({cnt+1}/{total})")

            for edit in removal_list:
                self.fullEditsInd.edits.remove(edit)
            if self.fullEditsInd.update_from_edits() is False:
                log.error("Final reduced: cannot be compiled")
                return
            result = self._best_of_runs(self.fullEditsInd)
            if result is None:
                log.error("Final reduced: execution failed")
                return
            self.fullEditsInd.fitness.values = result
            log.info(f"Fitness of the edit-reduced program: {self.fullEditsInd.fitness}")

            with open("reduced.edit", "w") as f:
                rprint(self.fullEditsInd.edits, file=f)
            with open("reduced.ll", "w") as f:
                f.write(self.fullEditsInd.srcEnc.decode())

    def search_indepedent_edits(self):
        """Find edits whose effect does not depend on the other edits.

        Requires evaluate_full_edits() to have run first. For each edit, two
        quantities are compared:

        * the fitness gain of applying the edit alone to the original program;
        * the fitness loss of removing it from the full edit set.

        If they differ by less than 1% of the original fitness, the edit is
        recorded in ``self.independentEdits``. Edits whose evaluations fail to
        compile or run are skipped. The results are written to
        ``reduced_no_independent.edit`` and ``reduced_independent.edit``.
        """
        self.independentEdits = []
        for edit in self.edits:
            editOnlyInd = self.toolbox.individual()
            editOnlyInd.edits = [edit]
            if editOnlyInd.update_from_edits() is False:
                continue
            result = self._best_of_runs(editOnlyInd)
            if result is None:
                continue
            editOnlyInd.fitness.values = result
            runtimeDiffDown = self.origin.fitness.values[0] - editOnlyInd.fitness.values[0]

            fullExceptEditInd = self.toolbox.individual()
            fullExceptEditInd.edits = deepcopy(self.edits)
            fullExceptEditInd.edits.remove(edit)
            if fullExceptEditInd.update_from_edits() is False:
                continue
            result = self._best_of_runs(fullExceptEditInd)
            if result is None:
                continue
            fullExceptEditInd.fitness.values = result
            runtimeDiffTop = fullExceptEditInd.fitness.values[0] - self.fullEditsInd.fitness.values[0]

            # The edit contributes (nearly) the same amount whether it is added
            # to the original program or removed from the full edit set.
            if abs(runtimeDiffDown - runtimeDiffTop) < self.origin.fitness.values[0]*0.01:
                log.info(f"{edit} can be independently applied")
                self.independentEdits.append(edit)

        reducedNoIndependent = [edit for edit in self.edits if edit not in self.independentEdits]
        with open("reduced_no_independent.edit", 'w') as f:
            rprint(reducedNoIndependent, file=f)
        with open("reduced_independent.edit", 'w') as f:
            rprint(self.independentEdits, file=f)

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    search_independent_edits = search_indepedent_edits

    def group_epistasis(self):
        """Group edits that only matter in the presence of a "lead" edit.

        Assumes ``self.edits`` contains only performance-relevant edits (for
        example, after remove_useless_edits()). For each edit not yet grouped
        (the lead), the program without the lead is evaluated. Each other
        remaining edit is then removed in turn. If that second removal changes
        fitness by less than about 1% (ratio in (0.99, 1.01)), the other edit
        joins the lead's group and is no longer used as a lead. Each group is
        written to ``epistasis_<n>.edit``.
        """
        allEdits = deepcopy(self.edits)
        groups = []
        for editLead in self.edits:
            if editLead not in allEdits:
                continue
            # Only reachable with duplicate edits; checked before the costly
            # evaluation below.
            if any(editLead in group for group in groups):
                continue

            indWoEditLead = self.toolbox.individual()
            indWoEditLead.edits = deepcopy(allEdits)
            indWoEditLead.edits.remove(editLead)
            # NOTE(review): a lead whose removal fails to compile or run stays in
            # allEdits without starting its own group; it may still join a later
            # lead's group.
            if indWoEditLead.update_from_edits() is False:
                continue
            result = self._best_of_runs(indWoEditLead)
            if result is None:
                continue
            groups.append([editLead])
            indWoEditLead.fitness.values = result

            for otherEdit in indWoEditLead.edits:
                indWoEditLeadOtherEdit = self.toolbox.individual()
                indWoEditLeadOtherEdit.edits = deepcopy(indWoEditLead.edits)
                indWoEditLeadOtherEdit.edits.remove(otherEdit)
                if indWoEditLeadOtherEdit.update_from_edits() is False:
                    continue
                result = self._best_of_runs(indWoEditLeadOtherEdit)
                if result is None:
                    continue
                indWoEditLeadOtherEdit.fitness.values = result

                # Without the lead, otherEdit no longer changes fitness, so
                # otherEdit is treated as dependent on the lead.
                runtimeRatio = indWoEditLead.fitness.values[0] / indWoEditLeadOtherEdit.fitness.values[0]
                if 0.99 < runtimeRatio < 1.01:
                    groups[-1].append(otherEdit)
                    allEdits.remove(otherEdit)

            allEdits.remove(editLead)

        for cnt, group in enumerate(groups):
            with open(f"epistasis_{cnt}.edit", 'w') as f:
                rprint(group, file=f)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Find and remove the edits that do nothing from the given edits")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-e', '--edit', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Defaults to execution time. Can be changed to power")
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    parser.add_argument('--version', action='version', version='gevo-' + __version__)
    args = parser.parse_args()

    # Everything below needs the profile, so abort if it cannot be loaded.
    try:
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except FileNotFoundError:
        log.error(f"The profile:'{args.profile_file}' cannot be found")
        sys.exit(1)
    except Exception:  # e.g. malformed JSON or an unreadable file
        log.exception(f"The profile:'{args.profile_file}' cannot be loaded")
        sys.exit(-1)

    alyz = program(
        editf=args.edit,
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    table = Table.grid(expand=True)
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    # One line per test case: its arguments separated by spaces.
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program:: ", tc_args)
    table.add_row("Target kernels:: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout:: ", str(args.timeout))
    table.add_row("Fitness function:: ", args.fitness_function)
    table.add_row("Edit file:: ", args.edit)
    table.add_row("Tolerate Error Rate:: ", str(args.err_rate))
    rprint(table)

    try:
        alyz.evaluate_full_edits()

        # Other available analyses; enable as needed:
        # alyz.remove_useless_edits()
        # alyz.search_indepedent_edits()
        alyz.group_epistasis()
    except KeyboardInterrupt:
        # Terminate any instances of the target binary that may still be running.
        subprocess.run(['killall', profile['binary']])
