import codecs
import copy
import functools
import hashlib
import json
from typing import Any, Dict, List, Optional, Pattern, Set, TextIO, Tuple, Type

from keylime import config, keylime_logging, signing
from keylime.agentstates import AgentAttestState
from keylime.common import algorithms, validators
from keylime.common.algorithms import Hash
from keylime.failure import Component, Failure
from keylime.ima import ast, file_signatures, ima_dm
from keylime.ima.file_signatures import ImaKeyrings

# Module logger for the IMA processing code below.
logger = keylime_logging.init_logging("ima")


# The version of the IMA runtime policy format that is supported by this
# Keylime release; used as the "meta" -> "version" value of
# EMPTY_RUNTIME_POLICY.
RUNTIME_POLICY_CURRENT_VERSION = 1


class RUNTIME_POLICY_GENERATOR:
    """Integer codes stored under a runtime policy's "meta" -> "generator" key.

    EMPTY_RUNTIME_POLICY uses ``EmptyAllowList``; runtime_policy_db_contents()
    falls back to ``Unknown`` when a policy's "meta" section has no generator.
    NOTE(review): the exact producers behind ``CompatibleAllowList`` and
    ``LegacyAllowList`` are not visible in this module -- confirm before relying
    on them.
    """

    Unknown = 0
    EmptyAllowList = 1
    CompatibleAllowList = 2
    LegacyAllowList = 3


# A correctly formatted, empty IMA runtime policy.
#
# Older Keylime allowlists did not have the "log_hash_alg" parameter, so it is
# hard-coded to "sha1" here. That is fine even when a different algorithm is
# selected on the kernel command line (e.g. ima_hash=sha256): it does not
# affect normal Keylime operation, since Keylime does not validate the hash
# algorithm received with the agent's IMA runtime measurements. This
# hard-coding would only become a problem if the kernel maintainers switched
# to a different algorithm for the template hash.
#
# read_runtime_policy() serializes a deep copy of this dict; treat the dict
# itself as read-only.
EMPTY_RUNTIME_POLICY = {
    "meta": {
        "version": RUNTIME_POLICY_CURRENT_VERSION,
        "generator": RUNTIME_POLICY_GENERATOR.EmptyAllowList,
    },
    "release": 0,
    "digests": {},
    "excludes": [],
    "keyrings": {},
    "ima": {"ignored_keyrings": [], "log_hash_alg": "sha1", "dm_policy": None},
    "ima-buf": {},
    "verification-keys": "",
}


class IMAMeasurementList:
    """Shared cache of checkpoints into the IMA measurement list file.

    Each checkpoint is an ``(entry_number, file_offset)`` pair recording the
    byte offset at which entry number ``entry_number`` starts, so a reader can
    seek close to a requested entry instead of re-scanning the whole log.
    At most 257 checkpoints are kept; when the cache is full, one arbitrary
    checkpoint is evicted before a new one is added.
    """

    # Lazily created shared instance handed out by get_instance().
    instance = None
    entries: Set[Tuple[int, int]]

    @staticmethod
    def get_instance() -> "IMAMeasurementList":
        """Return the shared instance, creating it on first use."""
        if IMAMeasurementList.instance is None:
            IMAMeasurementList.instance = IMAMeasurementList()
        return IMAMeasurementList.instance

    def __init__(self) -> None:
        """Create an empty checkpoint cache."""
        self.reset()

    def reset(self) -> None:
        """Drop every stored checkpoint."""
        self.entries = set()

    def update(self, num_entries: int, filesize: int) -> None:
        """Record that entry ``num_entries`` starts at byte offset ``filesize``."""
        # Keep the cache bounded: past 256 checkpoints, evict whichever one
        # set iteration yields first (effectively arbitrary).
        if len(self.entries) > 256:
            victim = next(iter(self.entries))
            self.entries.discard(victim)
        self.entries.add((num_entries, filesize))

    def find(self, nth_entry: int) -> Tuple[int, int]:
        """Return the checkpoint with the highest positive entry number <= ``nth_entry``.

        Falls back to ``(0, 0)`` -- the start of the file -- when no such
        checkpoint exists.
        """
        candidates = (entry for entry in self.entries if 0 < entry[0] <= nth_entry)
        # max() returns the first maximal item on ties, like a strict ">" scan.
        return max(candidates, key=lambda entry: entry[0], default=(0, 0))


def read_measurement_list(ima_log_file: TextIO, nth_entry: int) -> Tuple[Optional[str], int, int]:
    """Read the IMA measurement list starting at entry ``nth_entry``.

    Entries are numbered from 0. ``nth_entry`` may be any value in
    ``0 <= nth_entry <= entries_in_log``. ``nth_entry == entries_in_log`` means
    the caller already has every entry and wants the next one once it is
    appended; in that case the returned text is empty. A request beyond that
    point restarts from entry 0, and the returned start entry is then 0.

    The shared IMAMeasurementList checkpoint cache is used to seek close to
    the requested entry instead of scanning from the start of the file, and
    it is updated with the end of the last complete entry seen.

    Returns:
        ``(measurement_list, start_entry, total_entries)``:

        - ``measurement_list`` is the text from the start entry to the end of
          the file, or ``None`` when no log file is available.
        - ``start_entry`` is the entry number that text starts at.
        - ``total_entries`` is the number of newline-terminated entries
          counted. When no log file is available, it is the entry number of
          the checkpoint found before the cache was reset.
    """
    cache = IMAMeasurementList.get_instance()

    # Closest checkpoint at or before the requested entry; (0, 0) if none.
    entry_count, start_offset = cache.find(nth_entry)

    if not ima_log_file:
        cache.reset()
        logger.warning("IMA measurement list not available: %s", config.IMA_ML)
        return None, 0, entry_count

    ima_log_file.seek(start_offset)
    content = ima_log_file.read()

    # `content` begins at entry number `entry_count`. Walk the remaining lines,
    # counting entries and remembering where the requested entry starts.
    measurement_list: Optional[str] = None
    pos = 0
    while True:
        if entry_count == nth_entry:
            measurement_list = content[pos:]
        newline_at = content.find("\n", pos)
        if newline_at == -1:
            break
        pos = newline_at + 1
        entry_count += 1

    # The byte offset just past the last newline is where entry `entry_count` starts.
    cache.update(entry_count, start_offset + pos)

    if measurement_list is None:
        # Requested entry lies beyond the next expected one: start over at 0.
        # For nth_entry == 0 the checkpoint lookup always yields (0, 0), so the
        # requested entry is matched immediately and this recursion ends there.
        return read_measurement_list(ima_log_file, 0)

    return measurement_list, nth_entry, entry_count


def _validate_ima_ng(
    exclude_regex: Optional[Pattern[str]],
    allowlist: Optional[Dict[str, Any]],
    digest: ast.Digest,
    path: ast.Name,
    hash_types: str = "digests",
) -> Failure:
    """Check an entry's digest against the accepted hashes in a runtime policy section.

    Args:
        exclude_regex: compiled exclude pattern; matching paths are skipped.
        allowlist: the runtime policy, or ``None`` to skip validation entirely.
        digest: the entry's digest; ``digest.hash`` is hex-encoded for lookup.
        path: the entry's name; ``path.name`` is the lookup key.
        hash_types: policy section to look in ("digests", "keyrings" or
            "ima-buf" in this module).

    Returns:
        A Failure with context ``["validation", "ima-ng"]``. It holds a
        "not_in_allowlist" event when the name is not in the section (including
        when the section itself is missing or null), or an "allowlist_hash"
        event when the digest is not among the accepted values. Otherwise it
        is empty.
    """
    failure = Failure(Component.IMA, ["validation", "ima-ng"])
    if allowlist is None:
        return failure

    if exclude_regex is not None and exclude_regex.match(path.name):
        logger.debug("IMA: ignoring excluded path %s", path.name)
        return failure

    # A section may be absent (or null) in older or hand-written policies;
    # treat that like an empty section instead of crashing with KeyError.
    section = allowlist.get(hash_types) or {}
    accept_list = section.get(path.name, None)
    if accept_list is None:
        logger.warning("File not found in allowlist: %s", path.name)
        failure.add_event("not_in_allowlist", f"File not found in allowlist: {path.name}", True)
        return failure

    hex_digest = codecs.encode(digest.hash, "hex").decode("utf-8")
    if hex_digest not in accept_list:
        logger.warning("Hashes for file %s don't match %s not in %s", path.name, hex_digest, accept_list)
        failure.add_event(
            "allowlist_hash",
            {
                "message": "Hash not in allowlist found",
                "got": hex_digest,
                "expected": accept_list,
            },
            True,
        )

    return failure


def _validate_ima_sig(
    exclude_regex: Optional[Pattern[str]],
    ima_keyrings: Optional[file_signatures.ImaKeyrings],
    allowlist: Optional[Dict[str, Any]],
    digest: ast.Digest,
    path: ast.Name,
    signature: ast.Signature,
) -> Failure:
    """Validate an ima-sig entry by file signature and/or by the policy's digests.

    Decision order:

    1. If keyrings and a signature are both present, excluded paths pass
       immediately. A signature that ``integrity_digsig_verify`` rejects fails
       with "invalid_signature".
    2. If the policy has a "digests" section, the entry is hash-checked with
       _validate_ima_ng when either no signature was verified, or the file
       is also listed in "digests".
    3. Otherwise, without keyrings the entry passes. With keyrings but no
       verified signature, it fails with "invalid_signature".

    Returns:
        A Failure with context ``["validator", "ima-sig"]``, or the Failure
        returned by _validate_ima_ng in case 2.
    """
    failure = Failure(Component.IMA, ["validator", "ima-sig"])

    signature_checked = False
    if ima_keyrings and signature:
        if exclude_regex is not None and exclude_regex.match(path.name):
            logger.debug("IMA: ignoring excluded path %s", path.name)
            return failure

        if not ima_keyrings.integrity_digsig_verify(signature.data, digest.hash, digest.algorithm):
            logger.warning("signature for file %s is not valid", path.name)
            failure.add_event("invalid_signature", f"signature for file {path.name} is not valid", True)
            return failure

        signature_checked = True
        logger.debug("signature for file %s is good", path)

    # Hash-check against the policy when the signature was not evaluated, or
    # when it was valid and the file is additionally listed in "digests".
    policy_digests = allowlist.get("digests") if allowlist is not None else None
    if policy_digests is not None and (not signature_checked or policy_digests.get(path.name, None) is not None):
        return _validate_ima_ng(exclude_regex, allowlist, digest, path)

    # Neither a usable allowlist nor keyrings: nothing to validate against.
    if ima_keyrings is None:
        return failure

    if not signature_checked:
        failure.add_event("invalid_signature", f"signature for file {path.name} could not be validated", True)
    return failure


def _validate_ima_buf(
    exclude_regex: Optional[Pattern[str]],
    allowlist: Optional[Dict[str, Any]],
    ima_keyrings: Optional[file_signatures.ImaKeyrings],
    dm_validator: Optional[ima_dm.DmIMAValidator],
    digest: ast.Digest,
    path: ast.Name,
    data: ast.Buffer,
) -> Failure:
    """Validate an ima-buf entry.

    - If ``data.data`` holds a public key (per ``file_signatures.get_pubkey``),
      it is skipped when its keyring name is ignored: "*" or ``path.name`` is
      in the policy's "ima" -> "ignored_keyrings". Otherwise the digest is
      checked against the policy's "keyrings" section. Only on success is the
      key added to ``ima_keyrings``.
    - Else, if ``dm_validator`` handles ``path.name``, it validates the entry.
    - Else, the digest is checked against the policy's "ima-buf" section.

    A ValueError from ``get_pubkey`` yields an "invalid_key" failure.
    """
    failure = Failure(Component.IMA)
    # Is data.data a key?
    try:
        pubkey, keyidv2 = file_signatures.get_pubkey(data.data)
    except ValueError as ve:
        failure.add_event("invalid_key", f"key from {path.name} does not have a supported key: {ve}", True)
        return failure

    if not pubkey:
        # Device mapper entries are only checked when a validator is configured.
        if dm_validator is not None and path.name in dm_validator.valid_names:
            return dm_validator.validate(digest, path, data)
        # Generic ima-buf entries, e.g. ones carrying a hash in the buf field.
        return _validate_ima_ng(exclude_regex, allowlist, digest, path, hash_types="ima-buf")

    ignored_keyrings = allowlist.get("ima", {}).get("ignored_keyrings", []) if allowlist else []
    if "*" in ignored_keyrings or path.name in ignored_keyrings:
        return failure

    key_failure = _validate_ima_ng(exclude_regex, allowlist, digest, path, hash_types="keyrings")
    # Trust the key only once its digest has been accepted.
    if not key_failure and ima_keyrings is not None:
        ima_keyrings.add_pubkey_to_keyring(pubkey, path.name, keyidv2=keyidv2)
    return key_failure


def _with_boot_aggregates(
    runtime_policy: Optional[Dict[str, Any]],
    boot_aggregates: Optional[Dict[str, List[str]]],
) -> Optional[Dict[str, Any]]:
    """Return *runtime_policy* whose "boot_aggregate" digests also accept *boot_aggregates*.

    The caller's policy is left untouched: the top-level dict, its "digests"
    dict and the "boot_aggregate" list are shallow-copied before new values are
    appended. When either argument is empty/None, *runtime_policy* is returned
    as is.
    """
    if not (boot_aggregates and runtime_policy):
        return runtime_policy

    digests = dict(runtime_policy["digests"])
    accepted = list(digests.get("boot_aggregate", []))
    for values in boot_aggregates.values():
        for value in values:
            if value not in accepted:
                accepted.append(value)
    digests["boot_aggregate"] = accepted

    merged = dict(runtime_policy)
    merged["digests"] = digests
    return merged


def _process_measurement_list(
    agentAttestState: AgentAttestState,
    lines: List[str],
    hash_alg: Hash,
    runtime_policy: Optional[Dict[str, Any]] = None,
    pcrval: Optional[str] = None,
    ima_keyrings: Optional[ImaKeyrings] = None,
    boot_aggregates: Optional[Dict[str, List[str]]] = None,
) -> Tuple[str, Failure]:
    """Replay IMA log lines into a running PCR hash and validate every entry.

    Starting from the PCR state stored in ``agentAttestState`` for
    ``config.IMA_PCR``, each non-empty line is parsed into an ``ast.Entry``.
    Each entry is validated by the per-template validators built below and
    extended into the running hash with ``hash_alg``. The list matches the
    quote if the running hash equals ``pcrval`` before any line or after some
    line (or if ``pcrval`` is None). At the first in-loop match, the
    attestation state (and device-mapper state, if any) is saved.

    Args:
        agentAttestState: per-agent attestation state (PCR state, IMA DM state).
        lines: measurement-list lines; only the trailing newline is stripped.
        hash_alg: hash algorithm of the PCR bank being replayed.
        runtime_policy: runtime policy dict, or ``None`` to skip policy checks.
            It is not modified.
        pcrval: hex-encoded expected PCR value from the quote, or ``None``.
        ima_keyrings: keyrings used for ima-sig verification and filled with
            keys measured in ima-buf entries.
        boot_aggregates: extra accepted "boot_aggregate" digests; keys are
            presumably algorithm names (only the values are used).

    Returns:
        ``(hex running hash after all lines, Failure collecting all events)``.
    """
    failure = Failure(Component.IMA)
    running_hash = agentAttestState.get_pcr_state(config.IMA_PCR, hash_alg)
    # NOTE(review): ``assert`` is stripped under ``python -O``; kept as an
    # internal invariant and for type narrowing.
    assert running_hash

    found_pcr = pcrval is None
    errors: Dict[Type[ast.Mode], int] = {}
    pcrval_bytes = b""
    if pcrval is not None:
        pcrval_bytes = codecs.decode(pcrval.encode("utf-8"), "hex")

    if runtime_policy is not None:
        exclude_list = runtime_policy.get("excludes")
    else:
        exclude_list = None

    ima_log_hash_alg = algorithms.Hash.SHA1
    if runtime_policy is not None:
        try:
            ima_log_hash_alg = algorithms.Hash(runtime_policy["ima"]["log_hash_alg"])
        except ValueError:
            logger.warning(
                "Specified IMA log hash algorithm %s is not a valid algorithm! Defaulting to SHA1.",
                runtime_policy["ima"]["log_hash_alg"],
            )

    # Also accept the given boot aggregates, without mutating the caller's policy.
    runtime_policy = _with_boot_aggregates(runtime_policy, boot_aggregates)

    exclude_list_compiled_regex, err_msg = validators.valid_exclude_list(exclude_list)
    if err_msg:
        # This should not happen as the exclude list has already been validated
        # by the verifier before accepting it. This is a safety net just in case.
        err_msg += " Exclude list will be ignored."
        logger.error(err_msg)

    # Setup device mapper validation
    dm_validator = None
    if runtime_policy is not None:
        dm_policy = runtime_policy["ima"]["dm_policy"]

        if dm_policy is not None:
            dm_validator = ima_dm.DmIMAValidator(dm_policy)
            dm_state = agentAttestState.get_ima_dm_state()
            # Only load state when using incremental attestation
            if agentAttestState.get_next_ima_ml_entry() != 0:
                dm_validator.state_load(dm_state)

    ima_validator = ast.Validator(
        {
            ast.ImaSig: functools.partial(_validate_ima_sig, exclude_list_compiled_regex, ima_keyrings, runtime_policy),
            ast.ImaNg: functools.partial(_validate_ima_ng, exclude_list_compiled_regex, runtime_policy),
            ast.Ima: functools.partial(_validate_ima_ng, exclude_list_compiled_regex, runtime_policy),
            ast.ImaBuf: functools.partial(
                _validate_ima_buf, exclude_list_compiled_regex, runtime_policy, ima_keyrings, dm_validator
            ),
        }
    )

    # Iterative attestation may send us no log [len(lines) == 1]; compare last known PCR 10 state
    # against current PCR state.
    # Since IMA's append to the log and PCR extend as well as Keylime's retrieval of the quote, reading
    # of PCR 10 and retrieval of the log are not atomic, we may get a quote that does not yet take into
    # account the next-appended measurements' [len(lines) >= 2] PCR extension(s). In fact, the value of
    # the PCR may lag the log by several entries.
    if not found_pcr:
        found_pcr = running_hash == pcrval_bytes

    for linenum, line in enumerate(lines):
        # remove only the newline character, as there can be the space
        # as the delimiter character followed by an empty field at the
        # end
        line = line.strip("\n")
        if line == "":
            continue

        try:
            entry = ast.Entry(line, ima_validator, ima_hash_alg=ima_log_hash_alg, pcr_hash_alg=hash_alg)

            # update hash
            running_hash = hash_alg.hash(running_hash + entry.pcr_template_hash)

            validation_failure = entry.invalid()

            if validation_failure:
                failure.merge(validation_failure)
                errors[type(entry.mode)] = errors.get(type(entry.mode), 0) + 1

            if not found_pcr:
                # End of list should equal pcr value
                found_pcr = running_hash == pcrval_bytes
                if found_pcr:
                    logger.debug("Found match at linenum %s", linenum + 1)
                    # Save the PCR state and the index of the line after the
                    # first matching entry; presumably this is where the next
                    # incremental attestation resumes.
                    agentAttestState.update_ima_attestation(int(entry.pcr), running_hash, linenum + 1)
                    if dm_validator:
                        agentAttestState.set_ima_dm_state(dm_validator.state_dump())

        except ast.ParserError:
            failure.add_event("entry", f"Line was not parsable into a valid IMA entry: {line}", True, ["parser"])
            logger.error("Line was not parsable into a valid IMA entry: %s", line)

    # check PCR value has been found
    if not found_pcr:
        logger.error("IMA measurement list does not match TPM PCR %s", pcrval)
        failure.add_event("pcr_mismatch", f"IMA measurement list does not match TPM PCR {pcrval}", True)

    # Check if any validators failed
    if sum(errors.values()) > 0:
        error_msg = "IMA ERRORS: Some entries couldn't be validated. Number of failures in modes: "
        error_msg += ", ".join([f"{k.__name__} {v}" for k, v in errors.items()])
        logger.error("%s.", error_msg)

    return codecs.encode(running_hash, "hex").decode("utf-8"), failure


def process_measurement_list(
    agentAttestState: AgentAttestState,
    lines: List[str],
    runtime_policy: Optional[Dict[str, Any]] = None,
    pcrval: Optional[str] = None,
    ima_keyrings: Optional[ImaKeyrings] = None,
    boot_aggregates: Optional[Dict[str, List[str]]] = None,
    hash_alg: algorithms.Hash = algorithms.Hash.SHA1,
) -> Tuple[str, Failure]:
    """Validate an IMA measurement list and reset incremental IMA state on failure.

    Delegates to _process_measurement_list. Whenever the resulting Failure is
    truthy, ``agentAttestState.reset_ima_attestation()`` is called. Any
    exception propagates to the caller unchanged.

    Returns:
        ``(hex running hash, Failure)`` as produced by _process_measurement_list.
    """
    # Placeholder so the ``finally`` clause has something to test if the call
    # below raises before `failure` is reassigned.
    failure = Failure(Component.IMA)
    try:
        running_hash, failure = _process_measurement_list(
            agentAttestState,
            lines,
            hash_alg,
            runtime_policy=runtime_policy,
            pcrval=pcrval,
            ima_keyrings=ima_keyrings,
            boot_aggregates=boot_aggregates,
        )
    finally:
        if failure:
            # TODO currently reset on any failure which might be an issue
            agentAttestState.reset_ima_attestation()

    return running_hash, failure


# Read IMA policy files from disk, validate signatures and checksums, and prepare for sending.
def read_runtime_policy(
    alist: Optional[str] = None,
    checksum: Optional[str] = "",
    al_sig_file: Optional[str] = None,
    al_key_file: Optional[str] = None,
) -> Tuple[bytes, bytes, bytes]:
    """Load a runtime policy (and optionally its detached signature) from disk.

    Args:
        alist: path to the runtime policy file. When ``None`` or ``""``, a
            JSON-serialized copy of EMPTY_RUNTIME_POLICY is used instead.
        checksum: optional expected SHA-256 hex digest of the policy bytes.
            It is compared case-insensitively, ignoring surrounding whitespace.
        al_sig_file: path to the policy's detached signature.
        al_key_file: path to the public key for ``al_sig_file``. Signature
            verification is only performed when both files are given; if only
            one is given, a warning is logged and verification is skipped.

    Returns:
        ``(policy_bytes, key_bytes, signature_bytes)``. The key and signature
        are ``b""`` when signature verification was not requested.

    Raises:
        Exception: if ``alist`` is not a string, the policy fails
            verify_runtime_policy, or the checksum does not match.
        OSError: if one of the files cannot be opened or read.
    """
    al_key = b""
    al_sig = b""
    verify_signature = False

    # If user only wants signatures then a runtime policy is not required
    if alist is None or alist == "":
        alist_bytes = json.dumps(copy.deepcopy(EMPTY_RUNTIME_POLICY)).encode()

    elif isinstance(alist, str):
        with open(alist, "rb") as alist_f:
            logger.debug("Loading runtime policy from %s", alist)
            alist_bytes = alist_f.read()

    else:
        raise Exception("Invalid runtime policy provided")

    # Load signatures/keys if needed
    if al_sig_file and al_key_file:
        logger.debug(
            "Loading key (%s) and signature (%s) checking against runtime policy (%s)", al_key_file, al_sig_file, alist
        )
        verify_signature = True
        with open(al_key_file, "rb") as key_f:
            al_key = key_f.read()
        with open(al_sig_file, "rb") as sig_f:
            al_sig = sig_f.read()
    elif al_sig_file or al_key_file:
        # Verification needs both files; make the skipped check visible rather
        # than silently accepting an unverified policy.
        logger.warning(
            "Runtime policy signature verification skipped: both a signature file (%s) and a key file (%s) are required",
            al_sig_file,
            al_key_file,
        )

    # Verify runtime policy. This function checks for correct JSON formatting, and
    # will also verify signatures if provided.
    try:
        verify_runtime_policy(alist_bytes, al_key, al_sig, verify_sig=verify_signature)
    except ImaValidationError as error:
        message = f"Validation for runtime policy {alist} failed! Error: {error.message}"
        raise Exception(message) from error

    calculated_checksum = hashlib.sha256(alist_bytes).hexdigest()
    logger.debug("Loaded runtime policy from %s with checksum %s", alist, calculated_checksum)

    if checksum:
        # hexdigest() is lowercase; accept user input in any case.
        if checksum.strip().lower() == calculated_checksum:
            logger.debug("Runtime policy passed checksum validation")
        else:
            raise Exception(
                f"Checksum of runtime policy does not match! Expected {checksum}, Calculated {calculated_checksum}"
            )

    return alist_bytes, al_key, al_sig


def runtime_policy_db_contents(runtime_policy_name: str, runtime_policy: str, tpm_policy: str = "") -> Dict[str, Any]:
    """Build the column values used to store a runtime policy in the database.

    Args:
        runtime_policy_name: name the policy is stored under.
        runtime_policy: the policy as JSON text.
        tpm_policy: optional TPM policy; only stored when non-empty.

    Returns:
        A dict with these keys:

        - "name"
        - "ima_policy": ``None`` for the literal text ``"{}"``, otherwise the
          text itself.
        - "generator": present only when the policy has a "meta" section;
          ``RUNTIME_POLICY_GENERATOR.Unknown`` if that section has no
          generator.
        - "tpm_policy": present only when ``tpm_policy`` is non-empty.
        - "checksum": SHA-256 hex digest of the UTF-8 encoded policy text.
    """
    encoded_policy = runtime_policy.encode()
    parsed_policy = deserialize_runtime_policy(runtime_policy)

    # TODO: This was required to ensure e2e CI tests pass
    stored_policy = None if runtime_policy == "{}" else runtime_policy
    contents: Dict[str, Any] = {"name": runtime_policy_name, "ima_policy": stored_policy}

    if "meta" in parsed_policy:
        meta = parsed_policy["meta"]
        contents["generator"] = meta["generator"] if "generator" in meta else RUNTIME_POLICY_GENERATOR.Unknown

    if tpm_policy:
        contents["tpm_policy"] = tpm_policy

    contents["checksum"] = hashlib.sha256(encoded_policy).hexdigest()
    return contents


class ImaValidationError(Exception):
    """Raised when a runtime policy fails validation.

    ``message`` holds the human-readable reason and is also the exception
    text. ``code`` is a numeric status chosen by the raiser; the values used
    in this module (400, 401, 405) mirror HTTP status codes.
    """

    def __init__(self, message: str, code: int):
        super().__init__(message)
        self.message = message
        self.code = code


def verify_runtime_policy(
    runtime_policy: bytes,
    runtime_policy_key: Optional[bytes] = None,
    runtime_policy_sig: Optional[bytes] = None,
    verify_sig: Optional[bool] = True,
) -> None:
    """
    Verify that a runtime policy is valid. If provided runtime policy has a detached signature, verify the signature.

    Checks, in order:

    1. A policy was given.
    2. When ``verify_sig`` is true, a signature and a key were provided and
       the detached signature verifies.
    3. The policy parses as a JSON object.
    4. Its "excludes" list (the key the IMA processing code reads) contains
       valid regular expressions.

    Raises:
        ImaValidationError: code 400 for a missing/malformed policy or a bad
            exclude regex; 405 when signature checking is requested but the
            signature or key is missing; 401 when the signature does not verify.
    """

    if runtime_policy is None:
        raise ImaValidationError(
            message="No IMA policy provided!",
            code=400,
        )

    if verify_sig:
        if not runtime_policy_sig:
            raise ImaValidationError(
                message="Verifier is enforcing signature validation, but no signature was provided.", code=405
            )
        if not runtime_policy_key:
            raise ImaValidationError(
                message="Verifier is enforcing signature validation, but no public key was provided.", code=405
            )

        if not signing.verify_signature(runtime_policy_key, runtime_policy_sig, runtime_policy):
            raise ImaValidationError(message="Runtime policy failed detached signature verification!", code=401)
        logger.info("Runtime policy passed detached signature verification")

    # validate that the allowlist is proper JSON
    try:
        lists = json.loads(runtime_policy)
    except Exception as error:
        raise ImaValidationError(message="Runtime policy is not valid JSON!", code=400) from error

    # Valid JSON can still be an array/number/string, which has no policy keys.
    if not isinstance(lists, dict):
        raise ImaValidationError(message="Runtime policy must be a JSON object!", code=400)

    # Validate exclude list contains valid regular expressions. The policy
    # stores it under "excludes"; checking any other key would skip validation.
    _, excl_err_msg = validators.valid_exclude_list(lists.get("excludes"))
    if excl_err_msg:
        raise ImaValidationError(
            message=f"{excl_err_msg} Exclude list regex is misformatted. Please correct the issue and try again.",
            code=400,
        )


def deserialize_runtime_policy(runtime_policy: str) -> Dict[str, Any]:
    """
    Turn a runtime policy as stored in the database into a Python object for use in code.

    Policies are currently stored as plain JSON text, so this is a direct
    ``json.loads``; invalid JSON raises ``json.JSONDecodeError``.
    """
    # TODO: Extract IMA policy JSON from DSSE envelope, if applicable.
    decoded_policy: Dict[str, Any]
    decoded_policy = json.loads(runtime_policy)
    return decoded_policy
