# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import struct
import builtins
from typing import Iterator

# NOTE(review): not referenced anywhere in this module; presumably the minimum
# length used by callers when deciding what counts as a stack string - confirm.
MIN_STACKSTRING_LEN = 8


def xor_static(data: bytes, i: int) -> bytes:
    """
    XOR every byte of `data` with the single value `i` and return the result.

    applying the same value twice restores the original bytes.
    """
    return bytes([byte ^ i for byte in data])


def is_aw_function(symbol: str) -> bool:
    """
    does the given function name look like an A/W function?
    on Windows, these are paired variants of a function that take either
    a narrow ("A") or a wide ("W") string.

    the check is purely lexical: the name must be at least two characters
    long and end with "A" or "W".
    """
    return len(symbol) >= 2 and symbol[-1] in {"A", "W"}


def is_ordinal(symbol: str) -> bool:
    """
    does the given symbol start with "#", i.e., is it written as an ordinal?
    an empty symbol is never an ordinal.
    """
    return bool(symbol) and symbol[0] == "#"


def generate_symbols(dll: str, symbol: str, include_dll=False) -> Iterator[str]:
    """
    yield the name variants for a symbol imported from the given dll.
    we deliberately over-generate so that matching is easier, for example:
      - CreateFileA
      - CreateFile
      - ws2_32.#1

    note that since capa v7 only `import` features and APIs called via ordinal include DLL names:
      - kernel32.CreateFileA
      - kernel32.CreateFile
      - ws2_32.#1

    for `api` features dll names are good for documentation but not used during matching

    the dll name is lowercased and a trailing ".dll", ".drv", or ".so" is removed
    before it is used as a prefix. the symbol itself is emitted verbatim.
    """
    module = dll.lower()

    # trim extensions observed in dynamic traces.
    # the suffixes are checked one after another, in this order.
    for suffix in (".dll", ".drv", ".so"):
        if module.endswith(suffix):
            module = module[: -len(suffix)]

    if is_ordinal(symbol):
        # ordinals are only meaningful together with their dll: ws2_32.#1
        yield f"{module}.{symbol}"
        return

    if include_dll:
        # kernel32.CreateFileA
        yield f"{module}.{symbol}"

    # CreateFileA
    yield symbol

    if not is_aw_function(symbol):
        return

    # same name without the trailing A/W
    stem = symbol[:-1]

    if include_dll:
        # kernel32.CreateFile
        yield f"{module}.{stem}"

    # CreateFile
    yield stem


def reformat_forwarded_export_name(forwarded_name: str) -> str:
    """
    normalize the name of a forwarded export, which is made of a DLL name/path
    and a symbol name joined by a period.
    the DLL part is lowercased; the symbol part is kept exactly as given.
    """

    # split on the *last* period rather than the first:
    # the dll part may be a full path, like in the case of
    # ef64d6d7c34250af8e21a10feb931c9b
    # so it may presumably contain periods of its own.
    dll_part, _sep, symbol_part = forwarded_name.rpartition(".")

    return ".".join((dll_part.lower(), symbol_part))


def all_zeros(bytez: bytes) -> bool:
    """
    is every byte in the given buffer zero?
    an empty buffer counts as all zeros.
    """
    # a byte is truthy exactly when it is non-zero
    return not any(builtins.bytes(bytez))


def twos_complement(val: int, bits: int) -> int:
    """
    interpret `val` as a `bits`-wide two's complement number.

    when the sign bit (bit `bits - 1`) is set, the value is shifted down by
    2**bits to yield the negative number; otherwise it is returned unchanged.

    from: https://stackoverflow.com/a/9147327/87207
    """
    # e.g., for 8 bits the mask is 0x80, which is set for 128-255
    sign_mask = 1 << (bits - 1)
    return val - (1 << bits) if val & sign_mask else val


def carve_pe(pbytes: bytes, offset: int = 0) -> Iterator[tuple[int, int]]:
    """
    Generate (offset, key) tuples of embedded PEs

    Every single-byte XOR key (0-255) is tried. For a given key, each occurrence
    of the XOR-encoded "MZ" marker found from `offset` onwards is a candidate:
    the 4-byte little-endian e_lfanew field at candidate + 0x3C is decoded with
    the same key, and the candidate is reported when the XOR-encoded "PE" marker
    sits at candidate + e_lfanew.

    Keys are visited from 255 down to 0; for each key the candidates are
    reported in ascending offset order.

    Based on the version from vivisect:
      https://github.com/vivisect/vivisect/blob/7be4037b1cecc4551b397f840405a1fc606f9b53/PE/carve.py#L19
    And its IDA adaptation:
      capa/features/extractors/ida/file.py
    """
    total = len(pbytes)

    for key in range(255, -1, -1):
        mz_magic = xor_static(b"MZ", key)
        pe_magic = xor_static(b"PE", key)

        candidate = pbytes.find(mz_magic, offset)
        while candidate != -1:
            # e_lfanew, at 0x3C, is the one MZ header field we check
            lfanew_at = candidate + 0x3C
            if total < lfanew_at + 4:
                # any later candidate is even closer to the end of the buffer,
                # so none of them can hold a complete header either.
                break

            (lfanew,) = struct.unpack("<I", xor_static(pbytes[lfanew_at : lfanew_at + 4], key))

            pe_at = candidate + lfanew
            if total >= pe_at + 2 and pbytes[pe_at : pe_at + 2] == pe_magic:
                yield (candidate, key)

            candidate = pbytes.find(mz_magic, candidate + 1)
