#! /usr/bin/env python3

import os
import cloudpickle
class GraphKeyResult:
    """Wrapper that pairs a task's result with optional padding bytes.

    The padding is a testing aid: it makes the stored object larger so that
    storage consumption and peer transfer performance can be evaluated
    across all workers.
    """

    def __init__(self, result, extra_size_mb=None):
        """Keep `result`; allocate padding only for a positive `extra_size_mb`.

        Args:
            result: the value produced for a graph key.
            extra_size_mb: size of the zero-filled padding, in units of
                1024 * 1024 bytes (fractions are truncated to whole bytes).
                None, zero and negative values mean no padding.
        """
        self.result = result

        padding = None
        if extra_size_mb and extra_size_mb > 0:
            padding = bytearray(int(extra_size_mb * 1024 * 1024))
        self.extra_obj = padding

class TaskGraph:
    """Dependency graph built from a Dask-style task dictionary.

    For every key of `task_dict` the constructor derives its parents and
    children, its depth, a hashed "vine key", and the name of the file that
    holds its pickled result. Results are stored and read back as
    GraphKeyResult objects through cloudpickle.
    """

    def __init__(self, task_dict,
                 shared_file_system_dir=None,
                 staging_dir=None,
                 extra_task_output_size_mb=("uniform", 0, 0),
                 extra_task_sleep_time=("uniform", 0, 0)):
        """Build all per-key tables for `task_dict`.

        Args:
            task_dict: mapping of key -> task. A value is either an object
                from dask._task_spec or an old-style s-expression.
            shared_file_system_dir: directory for "shared-file-system"
                outputs; created if it does not exist.
            staging_dir: stored on the instance, not used by this class.
            extra_task_output_size_mb: testing parameter, a 3-item
                list/tuple (mode, low, high) handed to dist_func to pick the
                padding size of each key's output.
            extra_task_sleep_time: testing parameter, a 3-item list/tuple
                (mode, low, high) handed to dist_func to pick each key's
                extra sleep time.
        """
        self.task_dict = task_dict
        self.shared_file_system_dir = shared_file_system_dir
        self.staging_dir = staging_dir

        if self.shared_file_system_dir:
            os.makedirs(self.shared_file_system_dir, exist_ok=True)

        # Resolve the Dask task-spec module here instead of relying on a
        # module-level `dts` name, which this file never binds.
        try:
            import dask._task_spec as dts
        except ImportError:
            dts = None

        if dts:
            for k, v in self.task_dict.items():
                if isinstance(v, dts.GraphNode):
                    assert isinstance(v, (dts.Alias, dts.Task, dts.DataNode)), f"Unsupported task type for key {k}: {v.__class__}"

        self.parents_of, self.children_of = self._build_dependencies(self.task_dict)
        self.depth_of = self._calculate_depths()

        self.vine_key_of = {k: hash_name(k) for k in task_dict.keys()}
        self.key_of_vine_key = {vine_key: k for k, vine_key in self.vine_key_of.items()}

        self.outfile_remote_name = {key: f"{uuid.uuid4()}.pkl" for key in self.task_dict.keys()}
        self.outfile_type = {key: None for key in self.task_dict.keys()}

        # testing params
        self.extra_task_output_size_mb = self._calculate_extra_size_mb_of(extra_task_output_size_mb)
        self.extra_sleep_time_of = self._calculate_extra_sleep_time_of(extra_task_sleep_time)

    def set_outfile_type_of(self, k, outfile_type_str):
        """Record the output type of key `k`.

        For "shared-file-system" the key's output name is additionally
        prefixed with `shared_file_system_dir`.
        """
        assert outfile_type_str in ["local", "shared-file-system", "temp"]
        self.outfile_type[k] = outfile_type_str
        if outfile_type_str == "shared-file-system":
            self.outfile_remote_name[k] = os.path.join(self.shared_file_system_dir, self.outfile_remote_name[k])

    def _calculate_extra_size_mb_of(self, extra_task_output_size_mb):
        """Return {key: padding size} for every key.

        Keys at the deepest level and the level just above it always get 0;
        every other key gets dist_func(mode, low, high).
        """
        assert isinstance(extra_task_output_size_mb, (list, tuple)) and len(extra_task_output_size_mb) == 3
        mode, low, high = extra_task_output_size_mb
        low, high = int(low), int(high)
        assert low <= high

        # default=0 keeps an empty graph from raising on max() of nothing
        max_depth = max(self.depth_of.values(), default=0)
        extra_size_mb_of = {}
        for k in self.task_dict.keys():
            if self.depth_of[k] in (max_depth, max_depth - 1):
                extra_size_mb_of[k] = 0
                continue
            # NOTE(review): dist_func is not defined in this file; it must be
            # provided by the module this class originates from.
            extra_size_mb_of[k] = dist_func(mode, low, high)

        return extra_size_mb_of

    def _calculate_extra_sleep_time_of(self, extra_task_sleep_time):
        """Return {key: extra sleep time}, one dist_func(mode, low, high) draw per key."""
        assert isinstance(extra_task_sleep_time, (list, tuple)) and len(extra_task_sleep_time) == 3
        mode, low, high = extra_task_sleep_time
        low, high = int(low), int(high)
        assert low <= high

        # NOTE(review): dist_func is not defined in this file; see above.
        return {k: dist_func(mode, low, high) for k in self.task_dict.keys()}

    def _calculate_depths(self):
        """Return {key: depth}; a key without parents has depth 0, otherwise
        one more than its deepest parent."""
        depth_of = {key: 0 for key in self.task_dict.keys()}

        # parents always precede their children in topological order, so each
        # parent's depth is final by the time a child reads it
        for key in self.get_topological_order():
            parents = self.parents_of[key]
            depth_of[key] = max(depth_of[parent] for parent in parents) + 1 if parents else 0

        return depth_of

    def set_outfile_remote_name_of(self, key, outfile_remote_name):
        """Override the output file name recorded for `key`."""
        self.outfile_remote_name[key] = outfile_remote_name

    def is_dts_key(self, k):
        """Return True when the task stored under `k` is a dask._task_spec object."""
        if not hasattr(dask, "_task_spec"):
            return False
        import dask._task_spec as dts
        return isinstance(self.task_dict[k], (dts.Task, dts.TaskRef, dts.Alias, dts.DataNode, dts.NestedContainer))

    def _build_dependencies(self, task_dict):
        """Return (parents_of, children_of), two defaultdict(set) mappings."""
        def _find_sexpr_parents(sexpr):
            # a hashable value that is itself a graph key is a dependency
            if hashable(sexpr) and sexpr in task_dict.keys():
                return {sexpr}
            elif isinstance(sexpr, (list, tuple)):
                deps = set()
                for x in sexpr:
                    deps |= _find_sexpr_parents(x)
                return deps
            elif isinstance(sexpr, dict):
                deps = set()
                for k, v in sexpr.items():
                    deps |= _find_sexpr_parents(k)
                    deps |= _find_sexpr_parents(v)
                return deps
            else:
                return set()

        parents_of = collections.defaultdict(set)
        children_of = collections.defaultdict(set)

        for k, value in task_dict.items():
            if self.is_dts_key(k):
                # in the new Dask expression, each value is an object from dask._task_spec, could be
                # a Task, Alias, TaskRef, etc., but they all share the same base class the dependencies
                # field is of type frozenset(), without recursive ancestor dependencies involved
                parents_of[k] = value.dependencies
            else:
                # the value could be a sexpr, e.g., the old Dask representation
                parents_of[k] = _find_sexpr_parents(value)

        for k, deps in parents_of.items():
            for dep in deps:
                children_of[dep].add(k)

        return parents_of, children_of

    def save_result_of_key(self, key, result):
        """Pickle `result` (plus its testing padding) into the key's output file."""
        with open(self.outfile_remote_name[key], "wb") as f:
            result_obj = GraphKeyResult(result, extra_size_mb=self.extra_task_output_size_mb[key])
            cloudpickle.dump(result_obj, f)

    def load_result_of_key(self, key):
        """Return the result stored for `key`.

        Raises:
            FileNotFoundError: the key's output file does not exist.
        """
        try:
            with open(self.outfile_remote_name[key], "rb") as f:
                result_obj = cloudpickle.load(f)
                assert isinstance(result_obj, GraphKeyResult), "Loaded object is not of type GraphKeyResult"
                return result_obj.result
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Output file for key {key} not found at {self.outfile_remote_name[key]}") from e

    def get_topological_order(self):
        """Return all keys ordered so that every key follows its parents.

        Raises:
            ValueError: not every key could be ordered (cyclic or otherwise
                inconsistent dependencies).
        """
        in_degree = {key: len(self.parents_of[key]) for key in self.task_dict.keys()}
        # qualified through `collections`, the only form this file imports
        queue = collections.deque([key for key, degree in in_degree.items() if degree == 0])
        topo_order = []

        while queue:
            current = queue.popleft()
            topo_order.append(current)

            for child in self.children_of[current]:
                in_degree[child] -= 1
                if in_degree[child] == 0:
                    queue.append(child)

        if len(topo_order) != len(self.task_dict):
            print(f"len(topo_order): {len(topo_order)}")
            print(f"len(self.task_dict): {len(self.task_dict)}")
            raise ValueError("Failed to create topo order, the dependencies may be cyclic or problematic")

        return topo_order

    def __del__(self):
        """Best-effort removal of this graph's shared-file-system outputs."""
        # either table may be missing when __init__ did not run to completion
        remote_names = getattr(self, 'outfile_remote_name', None)
        outfile_types = getattr(self, 'outfile_type', None)
        if not remote_names or not outfile_types:
            return

        for k, path in remote_names.items():
            if outfile_types.get(k) != "shared-file-system":
                continue
            # remove directly instead of exists()+remove(): another holder of
            # the same graph may delete the file in between, and a destructor
            # has nobody to report an error to anyway
            try:
                os.remove(path)
            except OSError:
                pass

import uuid
import hashlib
import random
import types
import collections
import time
def load_variable_from_library(var_name):
    """Return the module-level object bound to `var_name`.

    Raises KeyError when this module has no global of that name.
    """
    # NOTE: further down this file, `from ndcctools.taskvine.utils import
    # load_variable_from_library` rebinds this name at module level.
    module_namespace = globals()
    return module_namespace[var_name]

def compute_dts_key(task_graph, k, v):
    """Compute the value of key `k` whose task `v` is a dask._task_spec node.

    Args:
        task_graph: graph object providing load_result_of_key().
        k: the key being computed (used in error messages).
        v: the node; its `dependencies` are loaded before it is evaluated.

    Returns:
        The single dependency's value for an Alias, the call result for a
        Task, or the stored value for a DataNode.

    Raises:
        ImportError: dask._task_spec cannot be imported.
        Exception: evaluating the node failed, or the node type is not
            supported; the original error is attached as the cause.
    """
    try:
        import dask._task_spec as dts
    except ImportError as e:
        raise ImportError("Dask is not installed") from e

    # loaded outside the try below, so a missing dependency file surfaces
    # with its own exception type rather than the generic wrapper
    input_dict = {dep: task_graph.load_result_of_key(dep) for dep in v.dependencies}

    try:
        if isinstance(v, dts.Alias):
            assert len(v.dependencies) == 1, "Expected exactly one dependency"
            # reuse the value loaded above instead of unpickling the same
            # dependency file a second time
            return next(iter(input_dict.values()))
        elif isinstance(v, dts.Task):
            return v(input_dict)
        elif isinstance(v, dts.DataNode):
            return v.value
        else:
            raise TypeError(f"unexpected node type: {type(v)} for key {k}")
    except Exception as e:
        raise Exception(f"Error while executing task {k}: {e}") from e

def compute_sexpr_key(task_graph, k, v):
    """Evaluate the s-expression `v` stored under key `k`.

    The results of all parents of `k` are loaded first. The expression is then
    walked recursively: a value equal to a parent key becomes that parent's
    result, a list is evaluated element by element, a non-empty tuple whose
    first item is callable is invoked on its evaluated arguments, and anything
    else is returned unchanged.

    Raises:
        Exception: any error raised while evaluating the expression.
    """
    parent_results = {}
    for parent in task_graph.parents_of[k]:
        parent_results[parent] = task_graph.load_result_of_key(parent)

    def _evaluate(node):
        # unhashable nodes make the membership test raise TypeError; they can
        # never be parent keys
        try:
            is_parent_key = node in parent_results
        except TypeError:
            is_parent_key = False
        if is_parent_key:
            return parent_results[node]

        if isinstance(node, list):
            return [_evaluate(item) for item in node]

        is_call = isinstance(node, tuple) and len(node) > 0 and callable(node[0])
        if not is_call:
            return node

        func, *raw_args = node
        evaluated_args = [_evaluate(arg) for arg in raw_args]
        return func(*evaluated_args)

    try:
        return _evaluate(v)
    except Exception as e:
        raise Exception(f"Failed to invoke _rec_call(): {e}")

def compute_single_key(vine_key):
    """Compute, store and verify the result of the graph key behind `vine_key`.

    The task graph is fetched from the library's variables under the name
    'task_graph'. After the result is saved, the output file must exist and be
    non-empty; the call then sleeps for the key's extra sleep time.

    Returns:
        True on success.

    Raises:
        Exception: the output file is missing or empty after writing.
    """
    graph = load_variable_from_library('task_graph')

    key = graph.key_of_vine_key[vine_key]
    node = graph.task_dict[key]

    # new-style Dask nodes and old-style s-expressions use different evaluators
    compute = compute_dts_key if graph.is_dts_key(key) else compute_sexpr_key
    value = compute(graph, key, node)

    graph.save_result_of_key(key, value)

    outfile = graph.outfile_remote_name[key]
    if not os.path.exists(outfile):
        raise Exception(f"Output file {outfile} does not exist after writing")
    if os.stat(outfile).st_size == 0:
        raise Exception(f"Output file {outfile} is empty after writing")

    time.sleep(graph.extra_sleep_time_of[key])

    return True

def hash_name(*args):
    """Return the first 32 hex digits of the SHA-256 of all arguments.

    Each argument is converted with str() and the pieces are concatenated
    without a separator before hashing.
    """
    joined = "".join(str(arg) for arg in args)
    digest = hashlib.sha256(joined.encode('utf-8')).hexdigest()
    return digest[:32]

def hashable(s):
    """Return True if hash(s) succeeds, False if it raises TypeError."""
    try:
        hash(s)
    except TypeError:
        return False
    else:
        return True

import dask
# Copyright (C) 2022 The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.


# This file serves as the template for Python Library Task.
# A Python Library Task runs on a worker as a pilot task. Upcoming Python
# Function Calls will be executed by this pilot task.

# import relevant libraries.
import json
import os
import fcntl
import sys
import argparse
import traceback
import cloudpickle
import select
import signal
import time
from datetime import datetime
import socket
from threadpoolctl import threadpool_limits
from ndcctools.taskvine.utils import load_variable_from_library

# Self-pipe: sigchld_handler writes a byte to `w` whenever a child finishes
# execution, which turns the SIGCHLD signal into an I/O event that main()
# waits for on `r` with select().
r, w = os.pipe()
# Execution mode and infile load mode of the functions in this library.
# Both are placeholders until main() sets them from the library's info.
exec_method = None
infile_load_mode = None


# This class captures how results from FunctionCalls are conveyed from
# the library to the manager.
# For now, all communication details should use this class to generate responses.
# In the future, this common protocol should be documented in a shared place
# so that library tasks written in other languages can use it too.
class LibraryResponse:
    """Result of one function call in the form sent back by the library."""

    def __init__(self, result=None, success=None, reason=None):
        """Keep the call's result, its success flag and the failure reason."""
        self.result, self.success, self.reason = result, success, reason

    def generate(self):
        """Return the response as a dict with keys Result, Success and Reason."""
        response = {}
        response["Result"] = self.result
        response["Success"] = self.success
        response["Reason"] = self.reason
        return response


# A wrapper around functions in library to extract arguments and formulate responses.
def remote_execute(func):
    """Wrap `func` so it takes a single event and returns a response dict.

    The wrapper extracts the call arguments from the event according to the
    module-level `infile_load_mode`, runs `func`, and reports the outcome as
    the dict produced by LibraryResponse.generate(). An exception raised by
    `func` is reported as a failed response carrying the formatted traceback.
    """
    def remote_wrapper(event):
        if infile_load_mode == "cloudpickle":
            positional = event.get("fn_args", [])
            kwargs = event.get("fn_kwargs", {})
        elif infile_load_mode == "text":
            # the raw text is the function's only argument
            positional, kwargs = [event], {}
        else:
            raise ValueError(f"Invalid infile load mode: {infile_load_mode}, only 'cloudpickle' and 'text' are supported")

        def _resolve(arg):
            # in case of FutureFunctionCall tasks: the argument names a file
            # holding an earlier call's response, whose result is substituted
            if not (isinstance(arg, dict) and "VineFutureFile" in arg):
                return arg
            with open(arg["VineFutureFile"], "rb") as future_file:
                return cloudpickle.load(future_file)["Result"]

        args = tuple(_resolve(arg) for arg in positional)

        try:
            outcome = (func(*args, **kwargs), True, None)
        except Exception:
            outcome = (None, False, traceback.format_exc())
        return LibraryResponse(*outcome).generate()

    return remote_wrapper


# Handler to sigchld when child exits.
def sigchld_handler(signum, frame):
    """Signal handler: make the self-pipe readable when a child exits.

    Both arguments are supplied by the signal machinery and are not used.
    """
    # the value is irrelevant; a single byte signals that there's at least 1 child
    wakeup_byte = b"a"
    os.writev(w, [wakeup_byte])


# Read data from worker, start function, and dump result to `outfile`.
def start_function(in_pipe_fd, thread_limit=1):
    """Read one invocation from the worker and run the requested function.

    The invocation is framed as "<length>\\n<payload>", where the payload is
    "<function_id> <function_name> <function_sandbox> <stdout_filename>".
    In "direct" mode the function runs in this process; otherwise it runs in
    a forked child. In both cases the result is dumped to `outfile` in the
    function's sandbox.

    Args:
        in_pipe_fd: file descriptor to read the invocation from.
        thread_limit: limit passed to threadpool_limits while the function
            is started.

    Returns:
        (pid, function_id). `pid` is the child's pid in fork mode, and -1 in
        direct mode or when the function could not be started.

    A malformed invocation, or a failed call in direct mode, terminates the
    process with exit status 1.
    """
    # read the length one byte at a time so nothing past the newline is consumed
    buffer_len = b""
    while True:
        c = os.read(in_pipe_fd, 1)
        if c == b"":
            stdout_timed_message(f"can't get length from in_pipe_fd {in_pipe_fd}")
            sys.exit(1)
        elif c == b"\n":
            break
        else:
            buffer_len += c
    buffer_len = int(buffer_len)

    # now read the buffer to get invocation details; os.read may return fewer
    # bytes than requested, so keep reading until the payload is complete
    payload = b""
    while len(payload) < buffer_len:
        chunk = os.read(in_pipe_fd, buffer_len - len(payload))
        if chunk == b"":
            break
        payload += chunk
    line = str(payload, encoding="utf-8")

    try:
        (
            function_id,
            function_name,
            function_sandbox,
            function_stdout_filename
        ) = line.split(" ", maxsplit=3)
    except Exception as e:
        stdout_timed_message(f"error: not enough values to unpack from {line} (expected 4 items), exception: {e}")
        sys.exit(1)

    try:
        function_id = int(function_id)
    except Exception as e:
        stdout_timed_message(f"error: can't turn {function_id} into an integer, exception: {e}")
        sys.exit(1)

    if not function_name:
        # malformed message from worker so we exit
        stdout_timed_message(f"error: invalid function name, malformed message {line} from worker")
        sys.exit(1)

    with threadpool_limits(limits=thread_limit):
        if exec_method == "direct":
            library_sandbox = os.getcwd()
            try:
                os.chdir(function_sandbox)

                # parameters are represented as infile.
                event = _load_infile_event("infile")

                # output of execution should be dumped to outfile.
                result = globals()[function_name](event)
                try:
                    with open("outfile", "wb") as f:
                        cloudpickle.dump(result, f)
                except Exception:
                    if os.path.exists("outfile"):
                        os.remove("outfile")
                    raise

                if not result["Success"]:
                    raise Exception(result["Reason"])

            except Exception:
                # logged through the one-argument helper: passing `file=` to
                # it raised TypeError and hid the real failure
                stdout_timed_message(f"Library code: Function call failed due to {traceback.format_exc()}")
                sys.exit(1)
            finally:
                os.chdir(library_sandbox)
            return -1, function_id

        arg_infile = os.path.join(function_sandbox, "infile")
        try:
            event = _load_infile_event(arg_infile)
        except Exception:
            stdout_timed_message(f"TASK {function_id} error: can't load the arguments from {arg_infile} due to {traceback.format_exc()}")
            return -1, function_id

        # os.fork() reports failure by raising OSError, not by returning a
        # negative pid
        try:
            p = os.fork()
        except OSError:
            stdout_timed_message(f"TASK {function_id} error: unable to fork to execute {function_name} due to {traceback.format_exc()}")
            return -1, function_id

        if p == 0:
            # child process: never returns
            _run_forked_function(function_id, function_name, function_sandbox, function_stdout_filename, event)

        # return pid and function id of child process to parent.
        return p, function_id


def _load_infile_event(path):
    """Load a function call's arguments from `path` according to `infile_load_mode`."""
    with open(path, "rb") as f:
        if infile_load_mode == "cloudpickle":
            return cloudpickle.load(f)
        elif infile_load_mode == "text":
            return f.read().decode("utf-8")
        else:
            raise ValueError(f"Invalid infile load mode: {infile_load_mode}, only 'cloudpickle' and 'text' are supported")


def _run_forked_function(function_id, function_name, function_sandbox, function_stdout_filename, event):
    """Run one function call inside the forked child and terminate the child.

    Exit status: 0 on success, 3 when the function could not be executed,
    4 when the result could not be written to `outfile`, 5 when the result is
    invalid or reports failure, 1 for any failure before a status was chosen.
    """
    exit_status = None
    try:
        # change the working directory to the function's sandbox
        os.chdir(function_sandbox)

        stdout_timed_message(f"TASK {function_id} {function_name} arrives, starting to run in process {os.getpid()}")

        # bound up front so the cleanup below can tell what was actually opened
        function_stdout_fd = None
        library_fd = None
        try:
            # setup stdout/err for a function call so we can capture them.
            function_stdout_fd = os.open(
                function_stdout_filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC
            )
            # store the library's stdout fd
            library_fd = os.dup(sys.stdout.fileno())

            # only redirect the stdout of a specific FunctionCall task into its own stdout fd,
            # otherwise use the library's stdout
            # os.dup2(function_stdout_fd, sys.stdout.fileno())
            # os.dup2(function_stdout_fd, sys.stderr.fileno())
            stdout_timed_message(f"TASK {function_id} {function_name} is starting")
            result = globals()[function_name](event)
            stdout_timed_message(f"TASK {function_id} {function_name} finished, result size: {sys.getsizeof(result)}")

            # restore to the library's stdout fd on completion
            os.dup2(library_fd, sys.stdout.fileno())
        except Exception:
            stdout_timed_message(f"TASK {function_id} error: can't execute this function due to {traceback.format_exc()}")
            exit_status = 3
            raise
        finally:
            if function_stdout_fd is not None:
                os.close(function_stdout_fd)
            if library_fd is not None:
                os.close(library_fd)

        try:
            with open("outfile", "wb") as f:
                cloudpickle.dump(result, f)
            stdout_timed_message(f"TASK {function_id} result: {result}")
        except Exception:
            stdout_timed_message(f"TASK {function_id} error: can't dump the result to outfile due to {traceback.format_exc()}")
            exit_status = 4
            if os.path.exists("outfile"):
                os.remove("outfile")
            raise

        try:
            if not result["Success"]:
                exit_status = 5
        except Exception:
            stdout_timed_message(f"TASK {function_id} error: the result is invalid due to {traceback.format_exc()}")
            exit_status = 5
            raise

        if exit_status is None:
            stdout_timed_message(f"TASK {function_id} finished successfully")
            exit_status = 0
    except Exception:
        stdout_timed_message(f"TASK {function_id} error: execution failed due to {traceback.format_exc()}")
    finally:
        # os._exit() needs an int: a failure before any status was chosen
        # (e.g. chdir failing) must still end the child here instead of
        # letting it fall back into the library's own code
        os._exit(exit_status if exit_status is not None else 1)


# Send result of a function execution to worker. Wake worker up to do work with SIGCHLD.
def send_result(out_pipe_fd, worker_pid, task_id, exit_code):
    """Report a finished function call to the worker.

    Writes "<length>\\n<task_id> <exit_code>" to `out_pipe_fd`, where length is
    the byte length of the part after the newline, then sends SIGCHLD to
    `worker_pid`.
    """
    body = f"{task_id} {exit_code}".encode("utf-8")
    header = str(len(body)).encode("utf-8") + b"\n"
    os.writev(out_pipe_fd, [header + body])
    os.kill(worker_pid, signal.SIGCHLD)


# Self-identifying message to send back to the worker, including the name of this library.
# Send back a SIGCHLD to interrupt worker sleep and get it to work.
def send_configuration(config, out_pipe_fd, worker_pid):
    """Send this library's configuration to the worker.

    `config` is serialized as JSON and written to `out_pipe_fd` as
    "<length>\\n<json>", then SIGCHLD is sent to `worker_pid`.
    """
    serialized = json.dumps(config)
    framed = "\n".join((str(len(serialized)), serialized))
    os.writev(out_pipe_fd, [framed.encode("utf-8")])
    os.kill(worker_pid, signal.SIGCHLD)


# Use os.write to stdout instead of print for multi-processing safety
def stdout_timed_message(message, file=None):
    """Write `message` prefixed with the current local time, followed by a newline.

    The line goes straight to the target's file descriptor with os.write
    rather than through print, for multi-processing safety.

    Args:
        message: text to log.
        file: object with a fileno() method to write to; defaults to
            sys.stdout. Accepting it keeps callers that pass
            `file=sys.stderr` from failing with TypeError.
    """
    target = sys.stdout if file is None else file
    timestamp = datetime.now().strftime("%m/%d/%y %H:%M:%S.%f")
    os.write(target.fileno(), f"{timestamp} {message}\n".encode())


def main():
    """Entry point of the library task.

    Parses the command line, loads the library's info from
    'library_info.clpk', registers its functions in the global namespace,
    announces the library to the worker, and then loops: starting a function
    whenever the worker sends an invocation and reporting each finished
    function back to the worker. The loop ends, with exit status 0, once the
    parent process is gone.
    """
    ppid = os.getppid()

    parser = argparse.ArgumentParser(
        "Parse input and output file descriptors this process should use. The relevant fds should already be prepared by the vine_worker."
    )
    parser.add_argument(
        "--in-pipe-fd",
        required=True,
        type=int,
        help="input fd to receive messages from the vine_worker via a pipe",
    )
    parser.add_argument(
        "--out-pipe-fd",
        required=True,
        type=int,
        help="output fd to send messages to the vine_worker via a pipe",
    )
    parser.add_argument(
        "--task-id",
        required=False,
        type=int,
        default=-1,
        help="task id for this library.",
    )
    parser.add_argument(
        "--library-cores",
        required=False,
        type=int,
        default=1,
        help="number of cores of this library",
    )
    parser.add_argument(
        "--function-slots",
        required=False,
        type=int,
        default=1,
        help="number of function slots of this library",
    )
    parser.add_argument(
        "--worker-pid",
        required=True,
        type=int,
        help="pid of main vine worker to send sigchild to let it know theres some result.",
    )
    args = parser.parse_args()

    # check if library cores and function slots are valid
    if args.function_slots > args.library_cores:
        stdout_timed_message("error: function slots cannot be more than library cores")
        sys.exit(1)
    elif args.function_slots < 1:
        stdout_timed_message("error: function slots cannot be less than 1")
        sys.exit(1)
    elif args.library_cores < 1:
        stdout_timed_message("error: library cores cannot be less than 1")
        sys.exit(1)

    # function_slots >= 1 is guaranteed by the checks above
    thread_limit = args.library_cores // args.function_slots

    # check if the in_pipe_fd and out_pipe_fd are valid
    try:
        fcntl.fcntl(args.in_pipe_fd, fcntl.F_GETFD)
        fcntl.fcntl(args.out_pipe_fd, fcntl.F_GETFD)
    except IOError as e:
        stdout_timed_message(f"error: pipe fd closed\n{e}")
        sys.exit(1)

    stdout_timed_message(f"library task starts running in process {os.getpid()}")
    stdout_timed_message(f"hostname             {socket.gethostname()}")
    stdout_timed_message(f"task id              {args.task_id}")
    stdout_timed_message(f"worker pid           {args.worker_pid}")
    stdout_timed_message(f"library pid          {os.getpid()}")
    stdout_timed_message(f"input fd             {args.in_pipe_fd}")
    stdout_timed_message(f"output fd            {args.out_pipe_fd}")
    stdout_timed_message(f"library cores        {args.library_cores}")
    stdout_timed_message(f"function slots       {args.function_slots}")
    stdout_timed_message(f"thread limit         {thread_limit}")

    # Open communication pipes to vine_worker.
    # The file descriptors are inherited from the vine_worker parent process
    # and should already be open for reads and writes.
    in_pipe_fd = args.in_pipe_fd
    out_pipe_fd = args.out_pipe_fd

    # mapping of child pid to function id of currently running functions
    pid_to_func_id = {}

    # read in information about this library
    with open('library_info.clpk', 'rb') as f:
        library_info = cloudpickle.load(f)

    # load and execute this library's context
    library_context_info = cloudpickle.loads(library_info['context_info'])
    context_vars = None
    if library_context_info:
        context_func = library_context_info[0]
        context_args = library_context_info[1]
        context_kwargs = library_context_info[2]
        context_vars = context_func(*context_args, **context_kwargs)

    # register functions in this library to the global namespace
    for func_name in library_info['function_list']:
        func_code = remote_execute(cloudpickle.loads(library_info['function_list'][func_name]))
        globals()[func_name] = func_code

    # update library's context to the load function
    if context_vars:
        (load_variable_from_library.__globals__).update(context_vars)

    # set execution mode of functions in this library
    global exec_method
    exec_method = library_info['exec_mode']

    # set infile load mode of functions in this library
    global infile_load_mode
    infile_load_mode = library_info['infile_load_mode']

    # send configuration of library, just its name for now
    config = {
        "name": library_info['library_name'],
        "taskid": args.task_id,
        "exec_mode": exec_method,
        "infile_load_mode": infile_load_mode,
    }
    send_configuration(config, out_pipe_fd, args.worker_pid)

    # register sigchld handler to turn a sigchld signal into an I/O event
    signal.signal(signal.SIGCHLD, sigchld_handler)

    # 5 seconds to wait for select, any value long enough would probably do
    timeout = 5

    last_check_time = time.time() - 5

    # Status reported for a function that could not be started. It uses the
    # encoding os.waitpid() gives for a child that exited with code 1, i.e.
    # the same format as the statuses reported for reaped children.
    # NOTE(review): confirm how the worker interprets this value.
    start_failure_status = 1 << 8

    while True:
        # check if parent exits
        c_ppid = os.getppid()
        if c_ppid != ppid or c_ppid == 1:
            stdout_timed_message("library finished because parent exited")
            sys.exit(0)

        # periodically log the number of concurrent functions
        current_check_time = time.time()
        if current_check_time - last_check_time >= 5:
            stdout_timed_message(f"{len(pid_to_func_id)} functions running concurrently")
            last_check_time = current_check_time

        # in case of "fork" exec method, wait for messages from worker or child to return
        # in case of "direct" exec method, wait for messages from worker
        try:
            rlist, _, _ = select.select([in_pipe_fd, r], [], [], timeout)
        except Exception as e:
            stdout_timed_message(f"error unable to read from pipe {in_pipe_fd}\n{e}")
            # nothing is known to be readable: skip this round rather than
            # act on an unbound or stale list, and pause so that a
            # persistent failure does not spin
            time.sleep(1)
            continue

        for ready_fd in rlist:
            # worker has a function, run it
            if ready_fd == in_pipe_fd:
                if exec_method == 'direct':
                    _, func_id = start_function(in_pipe_fd, thread_limit)
                    send_result(
                        out_pipe_fd,
                        args.worker_pid,
                        func_id,
                        0,
                    )
                else:
                    pid, func_id = start_function(in_pipe_fd, thread_limit)
                    if pid < 0:
                        # no child exists for this call: report the failure
                        # now, since nothing would ever be reaped for it
                        stdout_timed_message(f"Task {func_id} could not be started")
                        send_result(
                            out_pipe_fd,
                            args.worker_pid,
                            func_id,
                            start_failure_status,
                        )
                    else:
                        pid_to_func_id[pid] = func_id
                        stdout_timed_message(f"Task {func_id} started in process {pid}")
            else:
                # at least 1 child exits, reap all.
                # read only once as os.read is blocking if there's nothing to read.
                # note that there might still be bytes in `r` but it's ok as they will
                # be discarded in the next iterations.
                os.read(r, 1)
                while pid_to_func_id:
                    try:
                        c_pid, c_exit_status = os.waitpid(-1, os.WNOHANG)
                    except ChildProcessError:
                        # this process has no children left at all
                        break
                    # no exited child to reap, break
                    if c_pid <= 0:
                        break
                    finished_func_id = pid_to_func_id.pop(c_pid, None)
                    if finished_func_id is None:
                        # a child this loop did not start; nothing to report
                        stdout_timed_message(f"reaped untracked child process {c_pid} with status {c_exit_status}")
                        continue
                    stdout_timed_message(f"Task {finished_func_id} exited with status {c_exit_status}")
                    send_result(
                        out_pipe_fd,
                        args.worker_pid,
                        finished_func_id,
                        c_exit_status,
                    )


# Start the library task only when this file is executed as a script.
if __name__ == '__main__':
    main()


# vim: set sts=4 sw=4 ts=4 expandtab ft=python:
