#!/usr/bin/env python3
# pylint: disable=logging-fstring-interpolation
import argparse
import subprocess
import json
import sys
import logging
import ast

from rich.logging import RichHandler
from rich.progress import Progress
from rich.table import Table
from rich import print as rprint

from gevo.evolve import evolution

# Route all log records (every level) through rich, printing only the message.
logging.basicConfig(
    level="NOTSET",
    format="%(message)s",
    handlers=[RichHandler()],
)
log = logging.getLogger("main")

class program(evolution):
    """Reduce an edit list by dropping edits that do not contribute to fitness.

    The edits are loaded from a file and applied to the base program; that
    fully-edited program becomes ``self.origin``. ``edittest`` then evaluates
    the program with each edit left out in turn. It drops every edit whose
    removal keeps the fitness ratio above a threshold.
    """

    def __init__(self, editf, kernel, bin, profile, timeout=30, fitness='time',
                 llvm_src_filename='cuda-device-only-kernel.ll', err_rate=0.01):
        """Load the edit file and evaluate the program with every edit applied.

        Args:
            editf: Path to a file holding a Python literal listing the edits,
                parsed with ``ast.literal_eval``.
            kernel, bin, profile, timeout, fitness, err_rate: Forwarded
                unchanged to ``evolution.__init__``.
            llvm_src_filename: Accepted for signature compatibility; it is not
                used by this class.

        Raises:
            Exception: If the fully-edited program cannot be compiled, or if
                any of its evaluation runs reports a failed execution.

        Exits the process with status 1 when *editf* does not exist.
        """
        super().__init__(
            kernel=kernel,
            bin=bin,
            profile=profile,
            timeout=timeout,
            fitness=fitness,
            err_rate=err_rate,
            mutop='',
            use_fitness_map=False)

        try:
            with open(editf, 'r') as f:
                self.edits = ast.literal_eval(f.read())
        except FileNotFoundError:
            log.error(f"Edit File:{editf} cannot be found")
            sys.exit(1)

        # Redefine origin as the base LLVM program with full edits.
        # NOTE: at this point origin.edits is the same list object as
        # self.edits.
        self.origin.edits = self.edits
        if self.origin.update_from_edits() is False:
            raise Exception("Edit file cannot be compiled")
        best = self._evaluate_best(self.origin)
        if best is None:
            # Previously this fell through to min() over None values and
            # crashed with an unhelpful TypeError.
            raise Exception("Edit file cannot be executed")
        self.origin.fitness.values = best
        log.info(f"Fitness of the program with all edits: {self.origin.fitness}")

    def _evaluate_best(self, ind, runs=3):
        """Evaluate *ind* ``runs`` times and return ``(min fitness, min error)``.

        Repeating the run and keeping the minimum reduces measurement noise.
        The two minimums are taken independently, so they may come from
        different runs.

        Returns:
            The ``(fit, err)`` tuple, or ``None`` if any run returned a
            ``None`` fitness (execution failure).
        """
        results = [self.evaluate(ind) for _ in range(runs)]
        if any(value[0] is None for value in results):
            return None
        fit = min(value[0] for value in results)
        err = min(value[1] for value in results)
        return (fit, err)

    def _try_removal(self, ind, edit, threshold):
        """Evaluate *ind* with all edits except *edit*; return True if it can go.

        *edit* is removable when the remaining edits still compile and run, and
        ``origin_fitness / fitness_without_edit`` is greater than *threshold*.
        """
        ind.edits = self.edits.copy()
        ind.edits.remove(edit)
        if ind.update_from_edits() is False:
            log.info(f"{edit[0]} removed: cannot compile")
            return False
        best = self._evaluate_best(ind)
        if best is None:
            log.info(f"{edit[0]} removed: execution failed")
            return False

        fit, err = best
        improvement = self.origin.fitness.values[0] / fit
        log.info(f"{edit[0]} removed: {fit:.2f}. Improvement: {improvement:.2f}. Error:{err}")
        return improvement > threshold

    def edittest(self, threshold=0.99):
        """Drop every edit whose individual removal does not hurt fitness.

        Each edit is removed on its own and the result evaluated. Every edit
        that passes ``_try_removal`` is then removed from ``self.origin`` at
        once. If the combined reduced program compiles and runs, its edits are
        written to ``reduced.edit`` and its source to ``reduced.ll``.
        Otherwise an error is logged and no files are written.

        Args:
            threshold: Minimum ``origin_fitness / fitness_without_edit`` ratio
                for an edit to count as removable (default 0.99).
        """
        total = len(self.edits)
        self.pop = self.toolbox.population(n=total)
        removal_list = []
        with Progress(auto_refresh=False) as pbar:
            task1 = pbar.add_task("", total=total)
            pbar.update(task1, completed=0, refresh=True,
                        description=f"(0/{total})")

            for cnt, (ind, edit) in enumerate(zip(self.pop, self.edits), start=1):
                if self._try_removal(ind, edit, threshold):
                    removal_list.append(edit)
                # Advance for every edit, including ones whose removal failed;
                # the old `continue` paths left the bar stalled.
                pbar.update(task1, completed=cnt, refresh=True,
                            description=f"({cnt}/{total})")

            for edit in removal_list:
                self.origin.edits.remove(edit)
            if self.origin.update_from_edits() is False:
                log.error("Final reduced: cannot be compiled")
                return
            best = self._evaluate_best(self.origin)
            if best is None:
                # Stop here instead of crashing in min() on None values.
                log.error("Final reduced: execution failed")
                return

            self.origin.fitness.values = best
            log.info(f"Fitness of the edit reduced program: {self.origin.fitness}")

            # Write origin.edits explicitly rather than relying on it aliasing
            # self.edits.
            with open("reduced.edit", "w") as f:
                rprint(self.origin.edits, file=f)
            with open("reduced.ll", "w") as f:
                f.write(self.origin.srcEnc.decode())

if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the JSON profile, build
    # the reducer, print a summary table, then run the edit reduction.
    parser = argparse.ArgumentParser(description="Find and remove the edit that do nothing from the given edits")
    parser.add_argument('-P', '--profile_file', type=str, required=True,
        help="Specify the profile file that contains all application execution and testing information")
    parser.add_argument('-e', '--edit', type=str, required=True,
        help="The edit file")
    parser.add_argument('-t', '--timeout', type=int, default=30,
        help="The timeout period to evaluate the CUDA application")
    parser.add_argument('-fitf', '--fitness_function', type=str, default='time',
        help="What is the target fitness for the evolution. Defaults to execution time. Can be changed to power")
    parser.add_argument('--err_rate', type=float, default=0.01,
        help="Allowed maximum relative error generate from mutant comparing to the origin")
    args = parser.parse_args()

    try:
        with open(args.profile_file) as profile_fp:
            profile = json.load(profile_fp)
    except FileNotFoundError:
        log.error(f"The profile:'{args.profile_file}' cannot be found")
        # Everything below needs `profile`; stop instead of hitting NameError.
        sys.exit(1)
    except Exception:
        print(sys.exc_info())
        sys.exit(-1)

    alyz = program(
        editf=args.edit,
        kernel=profile['kernels'],
        bin=profile['binary'],
        profile=profile,
        timeout=args.timeout,
        fitness=args.fitness_function,
        err_rate=args.err_rate)

    table = Table.grid()
    table.add_column(justify="right", style="bold blue")
    table.add_column()
    table.add_row("Target CUDA program: ", profile['binary'])
    # One line per test case, each terminated by a newline.
    tc_args = "".join(" ".join(tc.args) + '\n' for tc in alyz.testcase)
    table.add_row("Args for the CUDA program: ", tc_args)
    table.add_row("Target kernels: ", " ".join(profile['kernels']))
    table.add_row("Evaluation Timeout: ", str(args.timeout))
    table.add_row("Fitness function: ", args.fitness_function)
    table.add_row("Edit file: ", args.edit)
    table.add_row("Tolerate Error Rate: ", str(args.err_rate))
    rprint(table)

    try:
        alyz.edittest()
    except KeyboardInterrupt:
        # Clean up any application processes still running from an evaluation.
        subprocess.run(['killall', profile['binary']])
