#!/usr/bin/python3

import argparse
import os
import shlex
from itertools import repeat
from multiprocessing.pool import ThreadPool

from colorama import Fore, init

from dfs.df_client import *


def exec_stats_cmd(pool, level=0):
    """Fetch server statistics and render them as indented JSON.

    Args:
        pool: connection pool; ``pool.get_connection()`` must yield a client
            exposing ``get_stats(level=...)``.
        level: stats detail level. The REPL dispatches commands with the
            raw string tokens typed by the user (e.g. ``stats 2``), so the
            value is coerced to ``int`` before it is sent.

    Returns:
        The stats structure serialized with ``json.dumps(..., indent=2)``.

    Raises:
        ValueError: if ``level`` is a string that is not an integer literal.
    """
    level = int(level)
    with pool.get_connection() as c:
        stats = c.get_stats(level=level)
    return json.dumps(stats, indent=2)


def scan_file(pool, *args):
    """Load one key on the server and time the operation.

    Written as a ``ThreadPool.starmap`` worker: ``args[0]`` is the key (a
    sequence of components) and is unpacked into ``c.load``.

    Returns:
        ``(elapsed_seconds, data['length'] / 1024)`` on success, or
        ``(0, 0)`` when the load fails. Failures are reported on stdout and
        deliberately not re-raised so one bad key does not abort a scan.
    """
    try:
        with pool.get_connection() as c:
            # perf_counter_ns is monotonic, so the measured interval cannot
            # be skewed by wall-clock adjustments during the load.
            start_t = time.perf_counter_ns()
            tinfo(f"scanning: {args[0]}")
            data = c.load(*args[0])
            stop_t = time.perf_counter_ns()
            return (stop_t - start_t) / (10**9), data['length'] / 1024
    except ConnectionError:
        print("failed to connect")
        return (0, 0)
    except Exception as e:
        # Best-effort worker: report the failure and contribute nothing
        # to the totals.
        print(e, args)
        return (0, 0)


def exec_scan_cmd(pool, args):
    """Load every key reported by the server in parallel and summarize speed.

    The key list is taken from ``get_stats(level=2)['all_keys']`` and each
    key is handed to ``scan_file`` on a ``ThreadPool``.

    Args:
        pool: connection pool shared by all worker threads.
        args: parsed CLI namespace (unused by this command).

    Returns:
        A summary string. The rate divides the summed per-key data figure
        by the summed per-key load time reported by ``scan_file`` (so it is
        not a wall-clock rate); the ``s/file`` figure is wall-clock time
        divided by the number of keys. ``"no keys"`` is returned when the
        server lists no keys, and a failure message when keys exist but no
        load produced a timing.
    """
    with pool.get_connection() as c:
        stats = c.get_stats(level=2)
    keys = stats['all_keys']

    print(f"scanning {len(keys)} keys...")

    if len(keys) == 0:
        # Nothing to scan: skip spinning up worker threads.
        return "no keys"

    with ThreadPool() as p:
        start_t = time.time()
        timings = p.starmap(scan_file, zip(repeat(pool), keys))
        stop_t = time.time()

    time_s = sum(elapsed for elapsed, _ in timings)
    data_s = sum(amount for _, amount in timings)
    if time_s > 0:
        return f"{int((data_s / time_s) / 1000)} MB/s  {(stop_t - start_t)/len(keys)} s/file"
    # scan_file yields (0, 0) for each failed load, so a zero total with a
    # non-empty key list means nothing was loaded successfully.
    return f"scan failed: no data loaded for {len(keys)} keys"


def exec_ls_cmd(pool, args):
    """Print every key known to the server, one per line.

    Keys come from ``get_stats(level=2)['all_keys']``; each key's
    components are joined with ``os.sep`` before printing. ``args`` is
    unused and nothing is returned.
    """
    with pool.get_connection() as conn:
        snapshot = conn.get_stats(level=2)
    all_keys = snapshot['all_keys']
    for components in all_keys:
        line = os.sep.join(components)
        print(line)


def read_dataframe_file(file_path):
    """Announce the import on stdout, then unpickle ``file_path`` via pandas.

    NOTE(review): ``pd.read_pickle`` unpickles the file, so this should only
    be pointed at trusted files.
    """
    print(f"import_dataframe: {file_path}")
    frame = pd.read_pickle(file_path)
    return frame


def read_file(file_path):
    """Return the entire contents of ``file_path`` as bytes."""
    with open(file_path, mode="rb") as stream:
        payload = stream.read()
    return payload


def set_file(pool, file_path, *args):
    """Upload a local file to the server under the key given by ``args``.

    Files with a ``.pkl`` extension are read as pickled DataFrames and
    passed through ``serialize_df``; every other file is sent as raw bytes.
    The detected extension is printed to stdout.
    """
    with pool.get_connection(FileClient) as client:
        extension = os.path.splitext(file_path)[1]
        print(extension)
        if extension != ".pkl":
            payload = read_file(file_path)
        else:
            payload = serialize_df(read_dataframe_file(file_path))
        client.set(payload, *args)


def get_file(pool, file_path, *args):
    """Download the value stored under key ``args`` and write it to disk.

    The data is fetched before the destination is opened: opening with
    ``"wb"`` truncates the file, so fetching first means a failed ``get``
    (for example an unknown key) neither wipes an existing file nor leaves
    an empty one behind.

    Args:
        pool: connection pool; a ``FileClient`` connection is requested.
        file_path: destination path, written in binary mode.
        *args: key components forwarded to the client's ``get``.
    """
    with pool.get_connection(FileClient) as c:
        data = c.get(*args)
    with open(file_path, "wb") as f:
        f.write(data)


def exec_set_cmd(pool, args):
    """CLI handler for ``set``: upload ``args.path`` under key ``args.key``.

    The key string is split on ``os.sep`` into components before being
    handed to ``set_file``. Returns a confirmation message.
    """
    key_components = args.key.split(os.sep)
    set_file(pool, args.path, *key_components)
    return f"set {args.path} as {args.key}"


def exec_import_cmd(pool, args):
    """CLI handler for ``import``: upload every file found under ``args.dir``.

    Each file is sent through ``set_file`` on a ``ThreadPool``. The key for a
    file is its path relative to ``args.dir`` with the extension removed,
    split on ``os.sep``; the extension is carried alongside it.

    The relative path is computed with ``os.path.relpath`` rather than by
    slicing off ``len(args.dir)`` characters: slicing leaves a leading
    separator whenever ``args.dir`` has no trailing slash (as with the
    ``os.getcwd()`` default), which produced keys with an empty first
    component for files in sub-directories.

    Returns:
        A message stating how many files were submitted.
    """
    files_arr = []
    arg_arr = []
    for path, _, files in os.walk(args.dir):
        for f in files:
            j = os.path.join(path, f)
            # os.walk can list entries that are not regular files
            # (e.g. dangling symlinks); skip them.
            if not os.path.isfile(j):
                continue
            rel_stem, se = os.path.splitext(os.path.relpath(j, args.dir))
            files_arr.append(j)
            arg_arr.append((rel_stem.split(os.sep), se))
    # NOTE(review): starmap only unpacks the outer (pool, file, arg) tuple,
    # so set_file receives the whole (components, extension) pair as a
    # single key argument - confirm this is what FileClient.set expects.
    with ThreadPool() as p:
        p.starmap(set_file, zip(repeat(pool), files_arr, arg_arr))

    return f"imported {len(files_arr)} files"


def exec_get_cmd(pool, args):
    """CLI handler for ``get``: download key ``args.key`` into ``args.path``.

    The key string is split on ``os.sep`` into components before being
    handed to ``get_file``. Returns a confirmation message.
    """
    key_components = args.key.split(os.sep)
    get_file(pool, args.path, *key_components)
    return f"get {args.key} as {args.path}"


def compare_dfs(stats, df, dfs_root_path, *args):
    """Check that the server's reported memory use matches ``df``'s footprint.

    Compares ``int(stats['memory']['used'])`` with
    ``int(df_memory_usage(df))``. On a mismatch the file at
    ``dfs_root_path`` joined with the key components ``args`` is read with
    ``read_df`` and a short diagnostic (both lengths and ``df.equals``) is
    printed before failing.

    Raises:
        AssertionError: if the two figures differ. It is raised explicitly
            rather than with an ``assert`` statement so the check still
            runs under ``python -O``.
    """
    m = df_memory_usage(df)
    used = stats['memory']['used']
    print(f"{args} stats: {used} df: {m}")
    if int(used) != int(m):
        gzip_fname = os.path.join(dfs_root_path, os.path.join(*args))
        fdf = read_df(gzip_fname)
        print(len(df), len(fdf), df.equals(fdf))
        raise AssertionError(
            f"memory usage mismatch for {args}: stats report {used}, df uses {m}"
        )


def verify_file(pool, dfs_root_path, *args):
    """Load one key, verify it with ``compare_dfs``, then unload it.

    Args:
        pool: connection pool used for the whole load/verify/unload cycle.
        dfs_root_path: root directory forwarded to ``compare_dfs``.
        *args: key components identifying the entry to verify.

    The unload runs in a ``finally`` clause so the key does not stay loaded
    on the server when fetching the stats or data, or the comparison
    itself, raises.
    """
    with pool.get_connection() as c:
        c.load(*args)
        try:
            stats = c.get_stats()
            df = c.get_data(*args)
            compare_dfs(stats, df, dfs_root_path, *args)
        finally:
            c.unload(*args)


def unload_all_keys(pool):
    """Unload every key the server currently reports as loaded.

    Iterates ``get_stats(level=2)['loaded_keys']`` (pairs whose first item
    is the key), unloads each key, and then asserts that the server's
    reported ``memory.used`` is zero.

    Returns:
        The number of entries that were listed as loaded.
    """
    with pool.get_connection() as conn:
        loaded_entries = conn.get_stats(level=2)['loaded_keys']
        unloaded_cnt = len(loaded_entries)
        for entry in loaded_entries:
            key_components, _ = entry
            conn.unload(*key_components)
        remaining = conn.get_stats(level=0)['memory']['used']
        assert int(remaining) == 0
    return unloaded_cnt


def exec_verify_cmd(pool, args):
    """CLI handler for ``verify``: check every server key against ``args.dir``.

    All loaded keys are unloaded first, then each key from
    ``get_stats(level=2)['all_keys']`` is run through ``verify_file`` one at
    a time. Returns a message with the number of keys processed.
    """
    unload_all_keys(pool)

    with pool.get_connection() as conn:
        snapshot = conn.get_stats(level=2)

    all_keys = snapshot['all_keys']
    for key_components in all_keys:
        verify_file(pool, args.dir, *key_components)

    return f"verified {len(all_keys)} keys"


# REPL dispatch table: maps the first token typed at the prompt to the
# handler that exec_cmd calls with the pool and the remaining tokens.
commands = {
    "stats": exec_stats_cmd,
    "unload": unload_all_keys,
}


class UnknownCommandError(Exception):
    """Raised when a REPL command name is not present in ``commands``."""


def exec_cmd(pool, cmd_arr):
    """Dispatch one tokenized REPL command.

    ``cmd_arr[0]`` selects the handler from ``commands``; the remaining
    tokens are forwarded to it as positional arguments after ``pool``.

    Raises:
        UnknownCommandError: if the command name is not registered; the
            message lists the available commands.
    """
    handler = commands.get(cmd_arr[0])
    if handler is not None:
        extra_tokens = cmd_arr[1:]
        return handler(pool, *extra_tokens)
    raise UnknownCommandError(f"Available commands: {list(commands.keys())}")


# https://dev.to/amal/building-the-python-repl-3468
def exec_repl_cmd(pool, args):
    """Run an interactive prompt that dispatches commands through ``exec_cmd``.

    Each input line is tokenized with ``shlex.split``. Empty lines are
    ignored, ``exit`` leaves the loop, and any exception raised while
    parsing or running a command is printed in red without ending the
    session. Ctrl-C and end-of-input (ctrl-d or exhausted piped stdin)
    both end the session.

    Args:
        pool: connection pool handed to every command.
        args: parsed CLI namespace (unused by the REPL).
    """
    def success(text):
        return f"{Fore.GREEN}{text}"

    def failure(text):
        return f"{Fore.RED}{text}"

    init(autoreset=True)

    print()
    print(f"{Fore.GREEN}Welcome to DFS REPL")
    print(f"{Fore.BLUE}author: Brian Guarraci")
    print(f"{Fore.BLUE}repo  : https://github.com/briangu/dfs")
    print(f"{Fore.YELLOW}ctrl-c to quit")
    print()

    try:
        while True:
            try:
                p = input("?> ")
            except EOFError:
                # EOFError is an Exception subclass: if it reached the
                # generic handler below, a closed stdin would make the loop
                # print errors forever instead of exiting.
                print("\nExiting...")
                break
            try:
                if len(p) == 0:
                    continue
                cmd_arr = shlex.split(p)
                if len(cmd_arr) == 0:
                    continue
                if cmd_arr[0] == 'exit':
                    break
                r = exec_cmd(pool, cmd_arr)
                print(success(r))
            except Exception as e:
                # Keep the session alive after a bad command.
                print(failure(f"Error: {e.args}"))
    except KeyboardInterrupt:
        print("\nExiting...")


# Command-line definition: global connection/logging options plus one
# sub-command per exec_*_cmd handler. Each sub-parser stores its handler in
# ``func``, which is invoked with the connection pool and the parsed args.
parser = argparse.ArgumentParser(description='Run DataFrame Service command')
parser.add_argument('--port', type=int, help='specify alternate port (default: 8000)', default=8000)
parser.add_argument('--host', type=str, help='specify alternate host address (default: 127.0.0.1)', default="127.0.0.1")
parser.add_argument('--max_connections', type=int, help='specify maximum number of pooled connections (default: 8)', default=8)
parser.add_argument('--log', type=str, help='specify alternate logging level (default: WARN)', default="WARN")

# A sub-command is mandatory: without one no ``func`` default is set and the
# script used to die with an AttributeError instead of a usage message.
subparsers = parser.add_subparsers(dest='command', required=True, help='sub-command help')
sp = subparsers.add_parser('repl', help='DFS REPL')
sp.set_defaults(func=exec_repl_cmd)

sp = subparsers.add_parser('scan', help='Load every DFS key and report throughput')
sp.set_defaults(func=exec_scan_cmd)

sp = subparsers.add_parser('ls', help='List DFS keys')
sp.set_defaults(func=exec_ls_cmd)

sp = subparsers.add_parser('import', help='Import a directory structure')
sp.add_argument('--dir', type=str, help='specify alternate files directory (default: current dir)', default=os.getcwd())
sp.set_defaults(func=exec_import_cmd)

sp = subparsers.add_parser('set', help='Import a file')
sp.add_argument('key', type=str, help='specify DFS key path')
sp.add_argument('path', type=str, help='specify file path')
sp.set_defaults(func=exec_set_cmd)

sp = subparsers.add_parser('get', help='Export a file')
sp.add_argument('key', type=str, help='specify DFS key path')
sp.add_argument('path', type=str, help='specify file path')
sp.set_defaults(func=exec_get_cmd)

sp = subparsers.add_parser('verify', help='Verify directory')
sp.add_argument('--dir', type=str, help='specify alternate files directory (default: current dir)', default=os.getcwd())
sp.set_defaults(func=exec_verify_cmd)

args = parser.parse_args()

# Translate the --log name (case-insensitive) into a logging level constant.
log_level = getattr(logging, args.log.upper(), None)
if not isinstance(log_level, int):
    raise ValueError('Invalid log level: %s' % args.log)
logging.basicConfig(level=log_level)

with DataFrameConnectionPool(args.host, args.port, max_connections=args.max_connections) as pool:
    result = args.func(pool, args)
    # Handlers such as ``ls`` and ``repl`` print their own output and return
    # None; avoid echoing a stray "None" after them.
    if result is not None:
        print(result)
