import base64
import ctypes
import dis
import inspect
import logging
import subprocess
import sys
from collections import namedtuple
from functools import partial
from importlib._bootstrap_external import _code_to_timestamp_pyc
from shlex import quote
from types import CodeType, FunctionType

import dill

from .mem_view import Mem, ptr_frame_stack_bottom, ptr_frame_stack_top
from .minias import _dis, Bytecode, long2bytes

# Expose every opcode name (LOAD_CONST, CALL_FUNCTION, ...) as a module-level
# constant holding its numeric opcode. At module scope locals() *is* globals(),
# but globals() states the intent explicitly.
globals().update(dis.opmap)


def _overlapping(s1, l1, s2, l2):
    e1 = s1 + l1
    e2 = s2 + l2
    return s1 < e2 and s2 < e1


class CodePatcher(dict):
    """Collects and applies patches to bytecodes.

    The mapping holds ``{offset: bytes}`` entries for ``code.co_code``.
    Patches are validated and checked for overlap when added, and are
    written into the code object's bytecode buffer in place (through
    ``Mem.view``) by :meth:`commit`.
    """
    def __init__(self, code):
        super().__init__()
        self._code = code

    def __str__(self):
        return f"CodePatcher(code={self._code})"

    def _check_bounds(self, pos, patch):
        """Raise ValueError unless `patch` fits into ``co_code`` at `pos`.

        These checks guard a raw in-place memory write in `commit`, so they
        are explicit raises: ``assert`` statements disappear under ``python -O``.
        """
        size = len(self._code.co_code)
        if len(patch) > size:
            raise ValueError(f"len(patch) = {len(patch)} > len(code) = {size}")
        if not 0 <= pos <= size - len(patch):
            raise ValueError(f"Index {pos:d} out of range [0, {size - len(patch)}]")

    def _diff(self):
        """Disassemble the code as it would look with all pending patches applied."""
        _new = list(self._code.co_code)
        for pos, patch in self.items():
            _new[pos:pos + len(patch)] = patch
        return _dis(self._code, alt=_new)

    def commit(self):
        """Write all pending patches into the code object, then clear them.

        Every patch is validated before anything is written, so a bad entry
        (e.g. one inserted through ``dict.update``, which bypasses
        ``__setitem__``) cannot leave the bytecode half-patched.

        Raises
        ------
        ValueError
            If a pending patch does not fit into the bytecode.
        """
        logging.debug(f"Commit patch to <{self._code.co_name}>")
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            # disassembling is only worth it when the diff is actually logged
            for i in self._diff():
                logging.debug(''.join(i))
        for pos, patch in self.items():
            self._check_bounds(pos, patch)
        code_view = Mem.view(self._code.co_code)
        for pos, patch in self.items():
            code_view[pos:pos + len(patch)] = patch
        self.clear()

    @property
    def last_opcode(self):
        """Opcode of the last instruction (``co_code[-2]``; instructions are 2 bytes)."""
        return self._code.co_code[-2]

    def __setitem__(self, pos, patch):
        """Register `patch` (converted to bytes) at offset `pos`.

        Raises
        ------
        ValueError
            If the patch does not fit into the bytecode or overlaps a
            patch that is already registered.
        """
        patch = bytes(patch)
        self._check_bounds(pos, patch)
        for _pos, _other in self.items():
            if _overlapping(pos, len(patch), _pos, len(_other)):
                raise ValueError("Patches overlap")
        super().__setitem__(pos, patch)

    def patch(self, patch, pos):
        """Same as ``self[pos] = patch``."""
        self[pos] = patch

    def __len__(self):
        # NOTE: this is the bytecode length, not the number of pending
        # patches (so an empty patcher is still truthy).
        return len(self._code.co_code)


class FramePatcher(CodePatcher):
    """Bytecode patcher bound to a live frame.

    Works like `CodePatcher` on ``frame.f_code`` and additionally knows
    the frame's current position, so patches can be placed relative to
    the last executed instruction.
    """
    def __init__(self, frame):
        self._frame = frame
        super().__init__(frame.f_code)

    def __str__(self):
        return f"FramePatcher(frame={self._frame})"

    @property
    def pos(self):
        """Offset of the frame's last executed instruction (``f_lasti``)."""
        return self._frame.f_lasti

    def _diff(self):
        """Patched disassembly with the row at the frame position highlighted green."""
        highlighted = []
        for index, row in enumerate(super()._diff()):
            if 2 * index != self.pos:
                highlighted.append(row)
            elif row[0].startswith('\033'):
                # row already starts with a color code: swap it for green
                highlighted.append(('\033[92m',) + tuple(row[1:]))
            else:
                highlighted.append(('\033[92m',) + tuple(row) + ('\033[0m',))
        return highlighted

    @property
    def current_opcode(self):
        """Byte of ``co_code`` at the frame position."""
        code_bytes = self._code.co_code
        return code_bytes[self.pos]

    def patch_current(self, patch, pos):
        """Register `patch` at offset `pos` relative to the frame position."""
        absolute = self.pos + pos
        return self.patch(patch, absolute)


def expand_long(c):
    """Expand opcode arguments that do not fit into a single byte.

    Parameters
    ----------
    c : bytes, list
        Alternating opcode / argument items. A trailing unpaired item
        is ignored.

    Returns
    -------
    result : bytes
        The same instructions where every non-zero argument is split by
        `long2bytes`: all but its last byte become ``EXTENDED_ARG``
        prefixes, the last byte stays with the opcode. Zero arguments
        are copied unchanged.
    """
    expanded = []
    for op, arg in zip(c[::2], c[1::2]):
        if arg:
            chunks = long2bytes(arg)
            for high in chunks[:-1]:
                expanded += [EXTENDED_ARG, high]
            expanded += [op, chunks[-1]]
        else:
            expanded += [op, arg]
    return bytes(expanded)


def get_value_stack_from_beacon(frame, beacon, expand=0):
    """
    Collects frame stack using beacon as
    an indicator of stack top.

    The value stack is scanned slot by slot from its bottom; every slot
    holds an object pointer of the native size and byte order.

    Parameters
    ----------
    frame : FrameObject
        Frame to process.
    beacon : int
        Address (``id``) of the object expected on top of the stack.
        It is not included in the result.
    expand : int
        Number of extra slots to scan beyond ``co_stacksize``.

    Returns
    -------
    stack : list
        Stack contents, bottom first.

    Raises
    ------
    RuntimeError
        If the beacon is not found within the scanned slots.
    """
    ptr_size = ctypes.sizeof(ctypes.c_void_p)  # was hard-coded as 8
    stack_bot = ptr_frame_stack_bottom(frame)
    stack_view = Mem(stack_bot, (frame.f_code.co_stacksize + expand) * ptr_size)[:]
    result = []
    for offset in range(0, len(stack_view), ptr_size):
        obj_ref = int.from_bytes(stack_view[offset:offset + ptr_size], sys.byteorder)
        if obj_ref == beacon:
            return result
        result.append(ctypes.cast(obj_ref, ctypes.py_object).value)
    raise RuntimeError("Failed to determine stack top")


def get_value_stack(frame):
    """
    Collects frame stack for generator objects.

    Reads every object pointer (native size and byte order) between the
    frame's value-stack bottom and top.

    Parameters
    ----------
    frame : FrameObject
        Frame to process.

    Returns
    -------
    stack : list
        Stack contents, bottom first.
    """
    ptr_size = ctypes.sizeof(ctypes.c_void_p)  # was hard-coded as 8
    stack_bot = ptr_frame_stack_bottom(frame)
    stack_top = ptr_frame_stack_top(frame)
    data = Mem(stack_bot, stack_top - stack_bot)[:]
    return [
        ctypes.cast(int.from_bytes(data[i:i + ptr_size], sys.byteorder), ctypes.py_object).value
        for i in range(0, len(data), ptr_size)
    ]


class FrameSnapshot(namedtuple("FrameSnapshot", ("code", "pos", "v_stack", "v_locals", "v_globals", "v_builtins"))):
    """A snapshot of python frame.

    Fields
    ------
    code : CodeType
        The frame's code object.
    pos : int
        The frame's ``f_lasti`` (offset of the last executed instruction).
    v_stack : list, None
        Value stack contents, bottom first; None while not collected yet.
    v_locals : dict
        A copy of the frame's locals.
    v_globals : dict
        The frame's globals.
    v_builtins : dict
        The frame's builtins.
    """
    # `__slots__` (not `slots`) is what suppresses the per-instance __dict__
    # a namedtuple subclass would otherwise get.
    __slots__ = ()

    def __repr__(self):
        code = self.code
        contents = []
        for i in "v_stack", "v_locals", "v_globals", "v_builtins":
            v = getattr(self, i)
            if v is None:
                contents.append(f"{i}: not set")
            else:
                contents.append(f"{i}: {len(v):d}")
        return f'FrameSnapshot {code.co_name} at "{code.co_filename}"+{code.co_firstlineno} @{self.pos:d} {" ".join(contents)}'


def p_jump_to(pos, patcher, f_next):
    """
    Patch: jump to position.

    If the frame's last executed instruction is the one right before
    `pos`, nothing is patched and `f_next` (if any) is called directly.
    Otherwise a jump to `pos` is placed right after the current
    instruction, ``CALL_FUNCTION 0`` is placed at `pos`, and the patches
    are committed.

    Parameters
    ----------
    pos : int
        Position to set.
    patcher : FramePatcher
    f_next : Callable, None

    Returns
    -------
    f_next : Callable
        Next function to call (or the result of calling it when no
        patching was needed; None if `f_next` is None in that case).
    """
    if patcher.pos == pos - 2:
        # already right before the target: execute next
        return f_next() if f_next is not None else None

    logging.debug(f"jump_to {pos:d}: patching ...")
    patcher.patch_current(expand_long([JUMP_ABSOLUTE, pos]), 2)  # jump to the original bytecode position
    patcher.patch([CALL_FUNCTION, 0], pos)  # call next
    patcher.commit()
    logging.debug(f"jump_to {pos:d}: ⏎ {f_next}")
    return f_next


def p_set_bytecode(bytecode, post, patcher, f_next):
    """
    Patch: overwrite the bytecode from offset 0.

    Parameters
    ----------
    bytecode : bytearray
        Bytecode written over the frame's code, starting at offset 0.
    post : Callable, None
        Called without arguments after the commit, if given.
    patcher : FramePatcher
    f_next : Callable

    Returns
    -------
    f_next : Callable
        Next function to call (passed through unchanged).
    """
    logging.debug(f"set_bytecode: patching ...")
    patcher.patch(bytecode, 0)
    patcher.commit()
    callback = post
    if callback is not None:
        callback()
    logging.debug(f"set_bytecode: ⏎ {f_next}")
    return f_next


def p_place_beacon(beacon, patcher, f_next):
    """
    Patch: places the beacon.

    Right after the current instruction, injects code that unpacks the
    returned ``(f_next, beacon)`` pair and makes two calls with no
    arguments.

    Parameters
    ----------
    beacon
        Beacon to place.
    patcher : FramePatcher
    f_next : Callable

    Returns
    -------
    result : tuple
        ``(f_next, beacon)``, to be unpacked by the injected code.
    """
    logging.debug(f"place_beacon {beacon}: patching ...")
    injected = [
        UNPACK_SEQUENCE, 2,  # first item (f_next) ends up on top, beacon below it
        CALL_FUNCTION, 0,  # call f_next
        CALL_FUNCTION, 0,  # call whatever it returned
    ]
    patcher.patch_current(injected, 2)
    patcher.commit()
    logging.debug(f"place_beacon {beacon}: ⏎ ({f_next}, {beacon})")
    return f_next, beacon


def snapshot(frame, finalize, method="inject"):
    """
    Snapshot the stack starting from the given frame.

    Parameters
    ----------
    frame : FrameObject, int, None
        Top of the stack frame. An int ``n`` selects the frame ``n``
        levels up from this function's own frame (``1`` is the direct
        caller of `snapshot`); None is treated as ``1``.
    finalize : Callable, None
        Where to return the result. With ``method="inject"`` it is
        called with the list of collected `FrameSnapshot` objects after
        the outermost frame gets its original bytecode back. Required
        for ``inject``; never called by ``direct``.
    method : {"inject", "direct"}
        Method to use for the stack:
        * `inject`: makes a snapshot of an active stack by
          patching stack frames and running bytecode snippets
          inside. The stack is destroyed and the result is
          returned into `finalize` function (required).
        * `direct`: makes a snapshot of an inactive stack
          by reading FrameObject structure fields. Can only
          be used with generator frames.

    Returns
    -------
    rtn : object
        Depending on the method, this is either the snapshot
        itself (``direct``: a list of `FrameSnapshot`, innermost
        frame first) or an object that has to be returned to the
        subject frame to initiate invasive frame collection
        (``inject``: the head of the chained patch callables).

    Raises
    ------
    ValueError
        If ``method="inject"`` and `finalize` is None.
    """
    assert method in {"inject", "direct"}
    if method == "inject" and finalize is None:
        raise ValueError("For method='inject' finalize has to set")
    # determine the frame to start with
    # NOTE(review): the mode label below is derived from `finalize`, not `method`
    logging.debug(f"Start frame serialization; mode: {'active' if finalize is not None else 'inactive'}")
    if frame is None:
        logging.info("  no frame specified")
        frame = 1
    if isinstance(frame, int):
        # walk `frame` levels up, starting from this function's own frame
        logging.info(f"  taking frame #{frame:d}")
        _frame = inspect.currentframe()
        for i in range(frame):
            _frame = _frame.f_back
        frame = _frame

    logging.info(f"  frame: {frame}")

    result = []  # one FrameSnapshot per frame, innermost first
    if method == "inject":  # prepare to receive data from patched frames
        beacon = object()  # sentinel: its id() marks the top of a frame's value stack

        notify_current = 0  # index of the next snapshot waiting for its value stack
        def notify(frame, f_next):
            """Chain step: store the value stack of `frame` into the next pending snapshot, then hand over `f_next`."""
            nonlocal notify_current, beacon
            logging.debug(f"Identify/collect object stack ...")
            result[notify_current] = result[notify_current]._replace(
                v_stack=get_value_stack_from_beacon(frame, id(beacon), expand=1))  # this might corrupt memory
            logging.info(f"  received {len(result[notify_current].v_stack):d} items")
            notify_current += 1
            return f_next

        chain = []  # holds a chain of patches and callbacks

    prev_globals = None
    prev_builtins = None

    while frame is not None:  # iterate over frame stack
        logging.info(f"Frame: {frame}")

        # check globals and builtins: every frame must share the same dict objects
        if prev_globals is None:
            prev_globals = frame.f_globals
        else:
            assert prev_globals is frame.f_globals
        if prev_builtins is None:
            prev_builtins = frame.f_builtins
        else:
            assert prev_builtins is frame.f_builtins

        # save locals, globals, etc.
        # (for "inject" the value stack is filled in later by `notify`)
        logging.info("  saving locals ...")
        result.append(FrameSnapshot(
            code=frame.f_code,
            pos=frame.f_lasti,
            v_stack=None if method == "inject" else get_value_stack(frame),
            v_locals=frame.f_locals.copy(),
            v_globals=prev_globals,
            v_builtins=prev_builtins,
        ))

        if method == "inject":  # prepare patchers
            logging.info(f"  patching the bytecode ...")
            original_code = bytearray(frame.f_code.co_code)  # store the original bytecode
            rtn_pos = original_code[::2].index(RETURN_VALUE) * 2  # figure out where it returns (first RETURN_VALUE)
            # note that bytearray is intentional to guarantee the copy
            patcher = FramePatcher(frame)

            # make room for patches immediately: the instruction after f_lasti
            # becomes a jump to offset 0, where CALL_FUNCTION 0 is placed
            p_jump_to(0, patcher, None)
            chain.append(partial(p_place_beacon, beacon, patcher))  # place the beacon
            chain.append(partial(notify, frame))  # collect value stack
            chain.append(partial(p_jump_to, rtn_pos - 2, patcher))  # jump 1 opcode before return
            chain.append(partial(
                p_set_bytecode,
                original_code,
                None
                if frame.f_back is not None
                else partial(finalize, result),  # outermost frame: deliver the result
                patcher
            ))  # restore the bytecode

        frame = frame.f_back  # next frame

    if method == "inject":  # chain patches
        # fold from the end so that each step receives the following one
        # as its last argument (`f_next`)
        prev = None
        for i in chain[::-1]:
            prev = partial(i, prev)
        logging.info("Ready to collect frames")
        return prev

    else:
        logging.info("Snapshot ready")
        return result


def unpickle_generator(code):
    """
    Rebuild a generator from its morph code.

    Parameters
    ----------
    code : CodeType
        The morph code object (expected to carry generator flags).

    Returns
    -------
    result
        The object produced by calling a function made from `code`
        with this module's globals, i.e. the generator.
    """
    generator_function = FunctionType(code, globals())
    return generator_function()


def _():
    yield None


@dill.register(type(_()))
def pickle_generator(pickler, obj):
    """
    Reduce a generator to its morph code for dill.

    The generator frame is snapshotted directly, turned into a morph
    code object flagged as a generator, and saved so that
    `unpickle_generator` rebuilds it on load.

    Parameters
    ----------
    pickler
        The dill pickler.
    obj
        The generator.
    """
    frame_state = snapshot(obj.gi_frame, None, method="direct")
    morph = morph_stack(frame_state, globals=False, flags=inspect.CO_GENERATOR)  # 0x20
    pickler.save_reduce(unpickle_generator, (morph,), obj=obj)


def morph_execpoint(p, nxt, pack=None, unpack=None, globals=False, fake_return=True,
        flags=0):
    """
    Prepares a code object which morphs into the desired state
    and continues the execution afterwards.

    The morph is the original bytecode of ``p.code`` with a header
    emitted in front of it. The header, in this order:
    * restores locals (and optionally globals);
    * pushes the saved value stack;
    * calls `nxt` (or pushes a fake None return value);
    * jumps to the instruction that follows ``p.pos``.

    Parameters
    ----------
    p : execpoint
        The execution point to morph into. It is read through
        ``code``, ``pos``, ``v_locals``, ``v_globals`` and ``v_stack``
        (like `FrameSnapshot`); ``p.v_stack`` must be set (not None).
    nxt : CodeType
        The code object which develops the stack further.
    pack : Callable, None
        A method turning objects into bytes (serializer)
        locally.
    unpack : tuple, None
        A 2-tuple `(module_name, method_name)` specifying
        the method that morph uses to unpack the data.
    globals : bool
        If True, unpacks globals.
    fake_return : bool
        If set, fakes returning None by putting None on top
        of the stack. This will be ignored if nxt is not
        None.
    flags : int
        Code object flags.

    Returns
    -------
    result : CodeType
        The resulting morph.

    Raises
    ------
    AssertionError
        If only one of `pack` / `unpack` is given.
    RuntimeError
        If no instruction starts at ``p.pos + 2``.
    """
    assert pack is None and unpack is None or pack is not None and unpack is not None,\
        "Either both or none pack and unpack arguments should be specified"
    logging.info(f"Preparing a morph into execpoint {p} pack={pack is not None} ...")
    code = Bytecode.disassemble(p.code)
    # presumably moves the insertion point to offset 0 so that the header
    # below is emitted before the original instructions
    code.pos = 0
    code.c("Header")
    code.nop(b'mrph')  # signature
    f_code = p.code
    new_stacksize = f_code.co_stacksize

    if pack:
        # import the unpacker inside the morph and keep it in a local variable
        unpack_mod, unpack_method = unpack
        code.c(f"from {unpack_mod} import {unpack_method}")
        unpack = code.varnames('.:unpack:.')  # non-alphanumeric = unlikely to exist as a proper variable
        code.I(LOAD_CONST, 0)
        code.I(LOAD_CONST, (unpack_method,))
        code.I(IMPORT_NAME, unpack_mod)
        code.I(IMPORT_FROM, unpack_method)
        code.i(STORE_FAST, unpack)

        def _LOAD(_what):
            # push unpack(pack(_what)): the object travels serialized in co_consts
            code.i(LOAD_FAST, unpack)
            code.I(LOAD_CONST, pack(_what))
            code.i(CALL_FUNCTION, 1)
    else:
        def _LOAD(_what):
            # push _what directly as a constant
            code.I(LOAD_CONST, _what)

    scopes = [(p.v_locals, STORE_FAST, "locals")]
    if globals:
        scopes.append((p.v_globals, STORE_GLOBAL, "globals"))
    for _dict, _STORE, log_name in scopes:
        logging.info(f"  {log_name} ...")
        if len(_dict) > 0:
            code.c(f"{log_name} = ...")
            klist, vlist = zip(*_dict.items())
            # UNPACK_SEQUENCE leaves the first value on top, so the stores
            # below pair klist[i] with vlist[i]
            _LOAD(vlist)
            code.i(UNPACK_SEQUENCE, len(vlist))
            for k in klist:
                # k = v
                code.I(_STORE, k)
            new_stacksize = max(new_stacksize, len(vlist))

    # stack
    if len(p.v_stack) > 0:
        code.c(f"*stack")
        # v_stack is bottom-first; reversed, UNPACK_SEQUENCE rebuilds it
        # with the original top item on top
        v_stack = p.v_stack[::-1]
        _LOAD(v_stack)
        code.i(UNPACK_SEQUENCE, len(v_stack))

    if nxt is not None:
        # call nxt which is a code object
        code.c(f"nxt()")

        # load code object
        _LOAD(nxt)
        code.I(LOAD_CONST, None)  # function name
        code.i(MAKE_FUNCTION, 0)  # turn code object into a function
        code.i(CALL_FUNCTION, 0)  # call it
    elif fake_return:
        code.c(f"fake return None")
        code.I(LOAD_CONST, None)  # fake nxt returning None

    # now jump to the previously saved position
    target_pos = p.pos + 2  # p.pos points to the last executed opcode
    # find the instruction ...
    for jump_target in code.iter_opcodes():
        if jump_target.pos == target_pos:
            break
    else:
        raise RuntimeError
    # ... and jump to it (the argument will be determined after re-assemblling the bytecode)
    code.c(f"goto saved pos")
    code.i(JUMP_ABSOLUTE, 0, jump_to=jump_target)

    code.c(f"---------------------")
    code.c(f"The original bytecode")
    code.c(f"---------------------")
    # NOTE(review): positional CodeType arguments follow the CPython 3.8
    # layout (argcount, posonlyargcount, kwonlyargcount, nlocals,
    # stacksize, flags, code, consts, names, varnames, filename, name,
    # firstlineno, lnotab); the layout is version-specific.
    # NOTE(review): the "+ 1" presumably leaves room for the value pushed
    # by nxt() / the fake return; co_lnotab is reused unchanged although
    # the header shifts the original offsets.
    result = CodeType(
        0,
        0,
        0,
        len(code.varnames),
        new_stacksize + 1,
        flags,
        code.get_bytecode(),
        tuple(code.consts),
        tuple(code.names),
        tuple(code.varnames),
        f_code.co_filename,  # TODO: something smarter should be here
        f_code.co_name,
        f_code.co_firstlineno,
        f_code.co_lnotab,
        )
    logging.info(f"resulting morph:\n{str(code)}")
    return result


def morph_stack(frame_data, globals=True, **kwargs):
    """
    Morphs the stack.

    Frames are processed innermost first; each morph receives the
    previous one as its ``nxt`` code object, so the last one built
    (for the outermost frame) drives the whole stack.

    Parameters
    ----------
    frame_data : list
        States of all individual frames, innermost first.
    globals : bool
        If True, the morph of the last (outermost) frame also restores
        globals.
    kwargs
        Arguments to morph_execpoint.

    Returns
    -------
    result : CodeType, None
        The resulting morph for the root frame (None if `frame_data`
        is empty).
    """
    prev = None
    last = len(frame_data) - 1
    for i, frame in enumerate(frame_data):
        logging.info(f"Preparing morph #{i:d}")
        # select the last frame by position, not identity: the same snapshot
        # object appearing earlier in the list must not restore globals
        prev = morph_execpoint(frame, prev, globals=globals and i == last, **kwargs)
    return prev


def dump(file, **kwargs):
    """
    Snapshot the caller's stack and serialize it into a file.

    Uses the inject-mode `snapshot` (which destroys the collected
    stack). Once collection finishes, the frames are morphed into a
    single function and written to `file` with `dill.dump`.

    Parameters
    ----------
    file : File
        The file to write to.
    kwargs
        Arguments to `dill.dump`.

    Returns
    -------
    rtn : object
        Whatever `snapshot` returns; it has to flow back into the
        caller frame.
    """
    caller = inspect.currentframe().f_back

    def serializer(stack_data):
        morph = FunctionType(morph_stack(stack_data), globals())
        dill.dump(morph, file, **kwargs)

    return snapshot(caller, finalize=serializer)


load = dill.load


def bash_inline_create_file(name, contents):
    """
    Build a shell command that recreates a file from inline data.

    The contents are base64-encoded and piped through ``base64 -d``
    into `name`; both the payload and the name are shell-quoted.

    Parameters
    ----------
    name : str
        File name.
    contents : bytes
        File contents.

    Returns
    -------
    result : str
        The resulting command that creates this file.
    """
    encoded = base64.b64encode(contents).decode()
    return "echo {} | base64 -d > {}".format(quote(encoded), quote(name))


def shell_teleport(*shell_args, python="python", before="cd $(mktemp -d)",
        pyc_fn="payload.pyc", shell_delimeter="; ", pack_file=bash_inline_create_file,
        pack_object=dill.dumps, unpack_object=("dill", "loads"),
        _frame=None, **kwargs):
    """
    Teleport into another shell.

    Snapshots the caller's stack, compiles it into a morph ``.pyc``,
    and runs a shell command that recreates the file and executes it.
    The current process then exits with that command's return code.

    Parameters
    ----------
    shell_args
        Arguments to a shell where python is found.
    python : str
        Python executable in the shell (inserted verbatim, so it may
        contain extra interpreter arguments).
    before : str, list
        Shell commands to be run before anything else.
    pyc_fn : str
        Temporary filename to save the bytecode to.
    shell_delimeter : str
        Shell delimiter used to chain multiple commands.
    pack_file : Callable
        A function `f(name, contents)` turning a file
        into shell-friendly assembly.
    pack_object : Callable, None
        A method turning objects into bytes (serializer)
        locally.
    unpack_object : tuple, None
        A 2-tuple `(module_name, method_name)` specifying
        the method that morph uses to unpack the data.
    _frame
        The frame to collect (defaults to the caller's frame).
    kwargs
        Other arguments to `subprocess.run`.

    Returns
    -------
    rtn : object
        Whatever `snapshot` returns; it has to flow back into the
        collected frame to start the collection.
    """
    payload = []
    if not isinstance(before, (list, tuple)):
        payload.append(before)
    else:
        payload.extend(before)

    def _teleport(stack_data):
        """Will be executed after the snapshot is done."""
        logging.info("Snapshot done, composing morph ...")
        code = morph_stack(stack_data, pack=pack_object, unpack=unpack_object)  # compose the code object
        logging.info("Creating pyc ...")
        files = {pyc_fn: _code_to_timestamp_pyc(code)}  # turn it into pyc
        for k, v in files.items():
            payload.append(pack_file(k, v))  # turn files into shell commands
        payload.append(f"{python} {quote(pyc_fn)}")  # execute python (file name quoted like in pack_file)

        # pipe the output and exit
        logging.info("Executing the payload ...")
        p = subprocess.run([*shell_args, shell_delimeter.join(payload)], text=True, **kwargs)
        # SystemExit instead of the site-module `exit()` helper, which is meant
        # for the interactive shell and is missing under `python -S`
        raise SystemExit(p.returncode)

    # proceed to snapshotting
    return snapshot(
        inspect.currentframe().f_back if _frame is None else _frame,
        finalize=_teleport,
    )
bash_teleport = shell_teleport


def dummy_teleport(**kwargs):
    """Teleport into a fresh python process through ``bash -c`` in the current environment.

    The caller's frame is the one collected; `kwargs` go to `bash_teleport`.
    """
    caller_frame = inspect.currentframe().f_back
    return bash_teleport("bash", "-c", _frame=caller_frame, **kwargs)


