#!/usr/bin/python3

import argparse
import shlex
from itertools import repeat
from multiprocessing.pool import ThreadPool

from colorama import Fore, init

from dfs.df_client import *


def info() -> None:
    """Print the coloured welcome banner shown when the REPL starts."""
    print()
    print(f"{Fore.GREEN}Welcome to DFS REPL")
    print(f"{Fore.BLUE}author: Brian Guarraci")
    print(f"{Fore.BLUE}repo  : https://github.com/briangu/dfs")
    # Typo fix: the banner previously read "crtl-c".
    print(f"{Fore.YELLOW}ctrl-c to quit")
    print()


# Initialise colorama. autoreset=True asks colorama to reset colours after
# each print, so the Fore.* prefixes used in this file need no explicit reset.
init(autoreset=True)


# NOTE: the parameter is still called ``input`` (shadowing the builtin inside
# these two helpers only) so that any keyword callers keep working.
def success(input):
    """Return ``input`` formatted as text prefixed with the green colour code."""
    return f"{Fore.GREEN}{input}"


def failure(input):
    """Return ``input`` formatted as text prefixed with the red colour code."""
    return f"{Fore.RED}{input}"


def stats_argparser():
    """Build the argument parser for the ``stats`` REPL command.

    Returns:
        An ``argparse.ArgumentParser`` (prog ``stats``) accepting an integer
        ``--level`` option that defaults to 0.
    """
    stats_parser = argparse.ArgumentParser(prog="stats", description='Run stats command')
    stats_parser.add_argument(
        '--level',
        default=0,
        type=int,
        help='specify alternate stats level (default: 0)',
    )
    return stats_parser


def execute_stats_cmd(pool, args):
    """Fetch server statistics and return them as indented JSON text.

    Args:
        pool: connection pool; ``get_connection()`` is used as a context manager.
        args: parsed ``stats`` arguments; ``args.level`` is forwarded to
            ``get_stats``.

    Returns:
        The stats object serialised with ``json.dumps(..., indent=2)``.
    """
    with pool.get_connection() as conn:
        server_stats = conn.get_stats(level=args.level)
    # Serialise only after the connection has been handed back to the pool.
    rendered = json.dumps(server_stats, indent=2)
    return rendered


def scan_file(pool, *args):
    """Load one file through a pooled connection and time the call.

    Args:
        pool: connection pool; ``get_connection()`` is used as a context manager.
        *args: ``args[0]`` is the file key; it is unpacked into ``c.load``.

    Returns:
        A ``(elapsed_seconds, data['length'] / 1024)`` tuple on success, or
        ``(0, 0)`` on any failure so that one bad file does not abort a batch
        scan.
    """
    try:
        with pool.get_connection() as c:
            start_t = time.time_ns()
            tinfo(f"scanning: {args[0]}")
            data = c.load(*args[0])
            stop_t = time.time_ns()
            # time_ns() is in nanoseconds; convert the delta to seconds.
            return (stop_t - start_t) / (10**9), data['length'] / 1024
    except ConnectionError:
        print("failed to connect")
        return (0, 0)
    except Exception as e:
        import traceback
        print(e)
        # print_exc() reads the active exception itself; its first positional
        # parameter is `limit`. Passing the exception there made this handler
        # raise TypeError instead of returning the (0, 0) fallback.
        traceback.print_exc()
        return (0, 0)


def execute_scan_cmd(pool, args):
    """Scan every file the server reports and summarise the timings.

    The file list comes from ``get_stats(level=2)['all_files']``; each entry
    is passed to ``scan_file`` on a ``ThreadPool``. ``args`` is not used.

    Returns:
        A summary string with a throughput figure and the wall-clock seconds
        per file, or ``"no files"`` when the summed scan time is not positive.
    """
    with pool.get_connection() as conn:
        server_stats = conn.get_stats(level=2)
    all_files = server_stats['all_files']

    print(f"scanning {len(all_files)} files...")

    with ThreadPool() as workers:
        wall_start = time.time()
        timings = workers.starmap(scan_file, zip(repeat(pool), all_files))
        wall_stop = time.time()

    # Each timing is (seconds spent loading, data amount) as returned by scan_file.
    total_time = sum(entry[0] for entry in timings)
    total_data = sum(entry[1] for entry in timings)
    if not total_time > 0:
        return "no files"

    throughput = int((total_data / total_time) / 1000)
    per_file = (wall_stop - wall_start) / len(all_files)
    return f"{throughput} MB/s  {per_file} s/file"


def load_dataframe(file_path):
    """Read and return the pickled object stored at ``file_path``.

    NOTE(review): ``pd.read_pickle`` unpickles the file, which can execute
    arbitrary code; only point this at files from a trusted source.
    """
    print(f"load_dataframe: {file_path}")
    frame = pd.read_pickle(file_path)
    return frame


def send_file(pool, file_path, *args):
    """Load the pickle at ``file_path`` and insert it into the service.

    Args:
        pool: connection pool; ``get_connection()`` is used as a context manager.
        file_path: path of the pickle file to load.
        *args: ``args[0]`` holds the key components, which are unpacked into
            ``insert_data`` after the loaded data.
    """
    print(f"send_file: {file_path} {args}")
    with pool.get_connection() as conn:
        frame = load_dataframe(file_path)
        key_parts = args[0]
        conn.insert_data(frame, *key_parts)


def load_argparser():
    """Build the argument parser for the ``load`` REPL command.

    Returns:
        An ``argparse.ArgumentParser`` (prog ``load``) accepting ``--dir``,
        whose default is the working directory at the time the parser is built.
    """
    load_parser = argparse.ArgumentParser(prog="load", description='Run load command')
    load_parser.add_argument(
        '--dir',
        default=os.getcwd(),
        type=str,
        help='specify alternate files directory (default: current dir)',
    )
    return load_parser


def execute_load_cmd(pool, args):
    """Walk ``args.dir`` and send every ``.pkl`` file to the service.

    Each file's key is its directory path relative to ``args.dir``, split into
    components, followed by the file name without its extension. Files are
    sent in parallel through ``send_file`` on a ``ThreadPool``.

    Args:
        pool: connection pool handed to each ``send_file`` call.
        args: parsed ``load`` arguments; ``args.dir`` is the root to walk.

    Returns:
        A ``"loaded N files"`` summary string.
    """
    files_arr = []
    arg_arr = []
    for path, _, files in os.walk(args.dir):
        # relpath gives the same components whether or not args.dir ends with
        # a separator. Slicing the prefix off by length left a leading
        # separator for nested directories, producing keys starting with ''.
        rel_dir = os.path.relpath(path, args.dir)
        prefix = [] if rel_dir == os.curdir else rel_dir.split(os.sep)
        for f in files:
            sp, se = os.path.splitext(f)
            j = os.path.join(path, f)
            if not (se == '.pkl' and os.path.isfile(j)):
                continue
            files_arr.append(j)
            arg_arr.append(prefix + [sp])
    print(files_arr)
    print(arg_arr)
    with ThreadPool() as p:
        p.starmap(send_file, zip(repeat(pool), files_arr, arg_arr))

    return f"loaded {len(files_arr)} files"


def compare_dfs(stats, df, dfs_root_path, *args):
    """Check that the server's reported memory use matches the DataFrame's.

    Compares ``stats['memory']['used']`` with ``df_memory_usage(df)`` as
    integers. On a mismatch the stored file for the key is read back and the
    lengths and equality of both frames are printed before the check fails.

    Args:
        stats: stats mapping containing ``['memory']['used']``.
        df: the DataFrame fetched from the service.
        dfs_root_path: root directory under which the key's file is stored.
        *args: key components identifying the file.

    Raises:
        AssertionError: if the two memory figures differ.
    """
    m = df_memory_usage(df)
    print(f"{args} stats: {stats['memory']['used']} df: {m}")
    used = int(stats['memory']['used'])
    expected = int(m)
    if used != expected:
        gzip_fname = os.path.join(dfs_root_path, get_file_path_from_key_path(*args))
        fdf = read_df(gzip_fname)
        print(len(df), len(fdf), df.equals(fdf))
        # Raised explicitly rather than via `assert` so the verification still
        # runs under `python -O` and the failure carries a usable message.
        raise AssertionError(
            f"memory usage mismatch for {args}: stats={used} df={expected}"
        )


def verify_file(pool, dfs_root_path, *args):
    """Load one file on the server, verify it with ``compare_dfs``, unload it.

    Args:
        pool: connection pool; ``get_connection()`` is used as a context manager.
        dfs_root_path: root directory passed through to ``compare_dfs``.
        *args: key components identifying the file.

    Raises:
        Whatever the connection calls or ``compare_dfs`` raise; the file is
        unloaded again before the exception propagates.
    """
    with pool.get_connection() as c:
        c.load(*args)
        try:
            stats = c.get_stats()
            df = c.get_data(*args)
            compare_dfs(stats, df, dfs_root_path, *args)
        finally:
            # Always unload what we loaded, so a failed check does not leave
            # the file resident on the server.
            c.unload(*args)


def unload_all_files(pool):
    """Unload every loaded file and confirm the server reports zero memory used.

    Each entry of ``get_stats(level=2)['loaded_files']`` is unpacked as a
    ``(key, _)`` pair and the key is unpacked into ``unload``.

    Raises:
        AssertionError: if ``stats['memory']['used']`` is non-zero afterwards.
    """
    with pool.get_connection() as c:
        stats = c.get_stats(level=2)
        for key, _ in stats['loaded_files']:
            c.unload(*key)
        stats = c.get_stats(level=0)
        used = int(stats['memory']['used'])
        if used != 0:
            # Raised explicitly rather than via `assert` so the check still
            # runs under `python -O` and reports what went wrong.
            raise AssertionError(
                f"expected 0 bytes in use after unloading all files, server reports {used}"
            )


def verify_argparser():
    """Build the argument parser for the ``verify`` REPL command.

    Returns:
        An ``argparse.ArgumentParser`` (prog ``verify``) accepting ``--dir``,
        whose default is the working directory at the time the parser is built.
    """
    verify_parser = argparse.ArgumentParser(prog="verify", description='Run verify command')
    verify_parser.add_argument(
        '--dir',
        default=os.getcwd(),
        type=str,
        help='specify alternate files directory (default: current dir)',
    )
    return verify_parser


def execute_verify_cmd(pool, args):
    """Verify every file the server knows about, one at a time.

    Unloads everything first, then runs ``verify_file`` for each key in
    ``get_stats(level=2)['all_files']`` with ``args.dir`` as the root path.

    Returns:
        A ``"verified N files"`` summary string.
    """
    # Start from an empty server before checking files one by one.
    unload_all_files(pool)

    with pool.get_connection() as conn:
        server_stats = conn.get_stats(level=2)

    all_keys = server_stats['all_files']
    for file_key in all_keys:
        verify_file(pool, args.dir, *file_key)

    return f"verified {len(all_keys)} files"


# Registry of REPL commands: name -> (argument parser or None, handler).
# A None parser means the handler is called with None in place of parsed args.
commands = {
    "stats": (
        stats_argparser(),
        execute_stats_cmd,
    ),
    "verify": (
        verify_argparser(),
        execute_verify_cmd,
    ),
    "scan": (
        None,
        execute_scan_cmd,
    ),
    "load": (
        load_argparser(),
        execute_load_cmd,
    ),
}


class UnknownCommandError(Exception):
    """Raised when a requested REPL command name is not registered."""


def execute_cmd(pool, cmd_arr):
    """Dispatch one tokenised REPL command to its handler.

    Args:
        pool: connection pool passed through to the handler.
        cmd_arr: non-empty token list; ``cmd_arr[0]`` is the command name and
            the remaining tokens are parsed by that command's parser, if any.

    Returns:
        Whatever the handler returns, or an empty string when the user only
        asked for the command's help text.

    Raises:
        UnknownCommandError: if the command name is not in ``commands``.
        ValueError: if the command's arguments fail to parse.
    """
    entry = commands.get(cmd_arr[0])
    if entry is None:
        raise UnknownCommandError(f"Available commands: {list(commands.keys())}")

    cmd_parser, handler = entry
    if cmd_parser is None:
        return handler(pool, None)

    try:
        parsed = cmd_parser.parse_args(cmd_arr[1:])
    except SystemExit as exc:
        # argparse prints its help/usage text and then calls sys.exit().
        # SystemExit is not an Exception subclass, so letting it propagate
        # would terminate the whole REPL on a typo or on `-h`.
        if exc.code in (0, None):
            return ""
        raise ValueError(f"invalid arguments for '{cmd_arr[0]}'") from None
    return handler(pool, parsed)


# https://dev.to/amal/building-the-python-repl-3468
def repl(pool):
    """Run the interactive read-eval-print loop until the user quits.

    Each line read at the ``?> `` prompt is split with ``shlex`` and passed
    to ``execute_cmd``; the result is printed in green. Any exception raised
    while handling a line is printed in red and the loop continues. The loop
    ends on the ``exit`` command, on ctrl-c, or when stdin reaches
    end-of-file.

    Args:
        pool: connection pool passed through to ``execute_cmd``.
    """
    info()
    try:
        while True:
            try:
                p = input("?> ")
            except EOFError:
                # EOFError is an Exception subclass: left to the generic
                # handler below, the loop would print an error, hit EOF again
                # on the next input() and spin forever.
                print("\nExiting...")
                break
            try:
                if not p:
                    continue
                cmd_arr = shlex.split(p)
                if not cmd_arr:
                    continue
                if cmd_arr[0] == 'exit':
                    break
                r = execute_cmd(pool, cmd_arr)
                print(success(r))
            except Exception as e:
                print(failure(f"Error: {e.args}"))
    except KeyboardInterrupt:
        print("\nExiting...")


# Command-line options for the connection to the DataFrame service.
parser = argparse.ArgumentParser(description='Run DataFrame Service command')
parser.add_argument('--port', type=int, help='specify alternate port (default: 8000)', default=8000)
parser.add_argument('--host', type=str, help='specify alternate host address (default: 127.0.0.1)', default="127.0.0.1")
# Help text corrected: it previously described a "source data directory".
parser.add_argument('--max_connections', type=int, help='specify alternate maximum number of pooled connections (default: 8)', default=8)
parser.add_argument('--log', type=str, help='specify alternate logging level (default: WARN)', default="WARN")
args = parser.parse_args()

# Resolve the level name (e.g. "warn") to the integer constant on the logging
# module; anything that does not resolve to an int is rejected.
log_level = getattr(logging, args.log.upper(), None)
if not isinstance(log_level, int):
    raise ValueError('Invalid log level: %s' % args.log)
logging.basicConfig(level=log_level)

# The pool is used as a context manager for the whole REPL session.
with DataFrameConnectionPool(args.host, args.port, max_connections=args.max_connections) as pool:
    repl(pool)

