"""
Metadata Utilities for State Persistence (v0.2.0)
Implements encryption, differential encoding, decoy injection, and MAC
"""

import numpy as np
from typing import Dict, Any, Union, List
import time
import secrets

from core.cel.cel import ContinuousEntropyLattice
from core.phe.phe import ProbabilisticHashingEngine
from utils.tlv_format import serialize_metadata_tlv, deserialize_metadata_tlv, METADATA_VERSION_TLV


def compute_metadata_mac(metadata_bytes: bytes, password: str) -> bytes:
    """
    Derive a 128-bit integrity tag for serialized metadata (v0.2.0)

    The tag is the first 16 entries of a PHE digest taken over the
    UTF-8 password, the fixed label b'||metadata||', and the payload,
    concatenated in that order.

    Args:
        metadata_bytes: Serialized metadata to authenticate
        password: Password mixed into the digest input

    Returns:
        Leading 16 entries of the PHE digest (the 128-bit MAC)
    """
    # Password first, then a fixed label separating it from the payload
    labelled_prefix = password.encode('utf-8') + b'||metadata||'

    digest = ProbabilisticHashingEngine().digest(
        labelled_prefix + metadata_bytes,
        context={'purpose': 'metadata_authentication'},
    )

    # Truncate to 128 bits
    return digest[:16]


def verify_metadata_mac(metadata_bytes: bytes, password: str, stored_mac: bytes) -> bool:
    """
    Check a stored MAC against one recomputed from the metadata (v0.2.0)

    Args:
        metadata_bytes: Serialized metadata the MAC should cover
        password: Password the MAC was generated with
        stored_mac: MAC value that accompanied the metadata

    Returns:
        True when the recomputed MAC equals stored_mac, False otherwise
    """
    # compare_digest avoids leaking the mismatch position through timing
    return secrets.compare_digest(
        compute_metadata_mac(metadata_bytes, password),
        stored_mac,
    )


def differential_encode_cel_snapshot(
    snapshot: Dict[str, Any],
    seed: Union[str, bytes, int]
) -> Dict[str, Any]:
    """
    Express a CEL snapshot as differences from its seed state (v0.2.0)

    A reference lattice is rebuilt by initializing a fresh CEL from
    ``seed`` with the snapshot's dimensions. Only cells whose value
    differs from that reference are recorded.

    Args:
        snapshot: CEL snapshot from CEL.snapshot()
        seed: Original seed used for CEL initialization

    Returns:
        Dict flagged 'differential': True holding the (layer, row, col,
        delta) tuples, their count, and the snapshot's bookkeeping fields
    """
    size = snapshot.get('lattice_size', 256)
    layers = snapshot.get('depth', 8)

    # Rebuild the state the lattice had immediately after seeding
    baseline = ContinuousEntropyLattice(lattice_size=size, depth=layers)
    baseline.init(seed)
    ref = baseline.lattice
    cur = snapshot.get('lattice')

    if cur is None or ref is None:
        # Nothing to compare against; emit an empty delta list
        changed = []
    else:
        # Walk every cell in layer/row/col order, computing the difference
        # on Python ints so the subtraction cannot wrap
        cells = (
            (z, y, x, int(cur[z][y][x]) - int(ref[z][y][x]))
            for z in range(layers)
            for y in range(size)
            for x in range(size)
        )
        # Keep only cells that actually moved away from the reference
        changed = [cell for cell in cells if cell[3] != 0]

    return dict(
        differential=True,
        deltas=changed,
        num_deltas=len(changed),
        lattice_size=size,
        depth=layers,
        seed_fingerprint=snapshot.get('seed_fingerprint', 0),
        operation_count=snapshot.get('operation_count', 0),
        state_version=snapshot.get('state_version', 0),
    )


def differential_decode_cel_snapshot(
    encoded: Dict[str, Any],
    seed: Union[str, bytes, int]
) -> Dict[str, Any]:
    """
    Rebuild a full CEL snapshot from a differential encoding (v0.2.0)

    A fresh lattice is initialized from ``seed`` using the encoded
    dimensions, and every stored delta is added back onto its cell.

    Args:
        encoded: Differential-encoded snapshot
        seed: Original seed the CEL was initialized with

    Returns:
        Snapshot dict with the restored 'lattice' and the bookkeeping
        fields carried over from ``encoded``
    """
    size = encoded.get('lattice_size', 256)
    layers = encoded.get('depth', 8)

    # Start from the same seed-initialized state the encoder diffed against
    baseline = ContinuousEntropyLattice(lattice_size=size, depth=layers)
    baseline.init(seed)
    restored = baseline.lattice.copy()

    # Re-apply each recorded difference to its cell
    for entry in encoded.get('deltas', []):
        z, y, x, change = entry
        line = restored[z][y]
        line[x] = int(line[x]) + change

    return dict(
        lattice=restored,
        lattice_size=size,
        depth=layers,
        seed_fingerprint=encoded.get('seed_fingerprint', 0),
        operation_count=encoded.get('operation_count', 0),
        state_version=encoded.get('state_version', 0),
    )


def encrypt_metadata(
    metadata: Dict[str, Any],
    password: str,
    use_differential: bool = True,
    seed: Union[str, bytes, int, None] = None
) -> Dict[str, Any]:
    """
    Serialize, authenticate and obfuscate metadata (v0.2.0)

    Steps, in order: optionally differential-encode the CEL snapshot,
    serialize to TLV, compute a MAC over the plaintext bytes, derive a
    key via a PHE digest of the password and a timestamp-based seed,
    then add the repeating key to the plaintext byte-wise modulo 256.

    Args:
        metadata: Metadata to encrypt (not modified; a shallow copy is
            made if the snapshot is re-encoded)
        password: Password for the MAC and key derivation
        use_differential: Whether to use differential encoding for the
            CEL snapshot
        seed: Original seed (required for differential encoding; without
            it the snapshot is stored as-is)

    Returns:
        Dict with 'encrypted_metadata', 'ephemeral_seed', 'metadata_mac'
        and 'differential_encoded'. The last field echoes
        ``use_differential`` as passed in, even when encoding was skipped.
    """
    # Microsecond clock folded into 32 bits; stored with the ciphertext
    ephemeral_seed = int(time.time() * 1_000_000) % (2**32)

    # Differential encoding needs the flag, a snapshot, and the seed
    if use_differential and 'cel_snapshot' in metadata and seed is not None:
        metadata = metadata.copy()
        compact_snapshot = differential_encode_cel_snapshot(
            metadata['cel_snapshot'],
            seed
        )
        metadata['cel_snapshot'] = compact_snapshot

    plaintext = serialize_metadata_tlv(metadata, version=METADATA_VERSION_TLV)

    # The MAC covers the plaintext, so it is computed before obfuscation
    tag = compute_metadata_mac(plaintext, password)

    # Key comes from a PHE digest (not CEL) so decryption can recompute
    # it from the password and the stored seed alone
    kdf_input = (
        password.encode('utf-8')
        + b'||metadata||'
        + ephemeral_seed.to_bytes(4, 'big')
    )
    key = ProbabilisticHashingEngine().digest(
        kdf_input,
        context={'purpose': 'metadata_encryption_key'},
    )[:32]  # 256 bits
    key_len = len(key)

    # Repeating-key modular addition over the serialized bytes
    ciphertext = bytes(
        (value + int(key[pos % key_len]) % 256) % 256
        for pos, value in enumerate(plaintext)
    )

    return {
        'encrypted_metadata': ciphertext,
        'ephemeral_seed': ephemeral_seed,
        'metadata_mac': tag,
        'differential_encoded': use_differential
    }


def decrypt_metadata(
    encrypted_data: Dict[str, Any],
    password: str,
    seed: Union[str, bytes, int, None] = None
) -> Dict[str, Any]:
    """
    Reverse encrypt_metadata: de-obfuscate, verify, deserialize (v0.2.0)

    The key is recomputed via a PHE digest of the password and the
    stored 'ephemeral_seed', subtracted byte-wise modulo 256, and the
    result is MAC-checked before TLV deserialization.

    Args:
        encrypted_data: Encrypted metadata package
        password: Password for decryption
        seed: Original seed (required if differential encoding was used;
            without it an encoded snapshot is returned still encoded)

    Returns:
        Decrypted metadata

    Raises:
        ValueError: If MAC verification fails
    """
    # Recompute the key from the password and the seed stored in the package
    ephemeral_seed = encrypted_data['ephemeral_seed']
    kdf_input = (
        password.encode('utf-8')
        + b'||metadata||'
        + ephemeral_seed.to_bytes(4, 'big')
    )
    key = ProbabilisticHashingEngine().digest(
        kdf_input,
        context={'purpose': 'metadata_encryption_key'},
    )[:32]  # 256 bits
    key_len = len(key)

    # Repeating-key modular subtraction undoes the encryption step
    ciphertext = encrypted_data['encrypted_metadata']
    plaintext = bytes(
        (value - int(key[pos % key_len]) % 256) % 256
        for pos, value in enumerate(ciphertext)
    )

    # The MAC is checked only when the package carries one
    if 'metadata_mac' in encrypted_data and not verify_metadata_mac(
        plaintext, password, encrypted_data['metadata_mac']
    ):
        raise ValueError("Metadata MAC verification failed - tampering detected or wrong password")

    metadata = deserialize_metadata_tlv(plaintext)

    # Without the package flag or the seed, hand back the metadata as parsed
    if not encrypted_data.get('differential_encoded') or seed is None:
        return metadata

    # Expand the snapshot only if it is itself marked as differential
    if 'cel_snapshot' in metadata and metadata['cel_snapshot'].get('differential'):
        metadata['cel_snapshot'] = differential_decode_cel_snapshot(
            metadata['cel_snapshot'],
            seed
        )

    return metadata


def inject_decoy_vectors(
    real_metadata: Dict[str, Any],
    password: str,
    num_decoys: int = 3
) -> Dict[str, Any]:
    """
    Generate decoy snapshots and interleave with real metadata (v0.2.0)

    Each decoy is a snapshot of a freshly seeded 64x4 CEL, encrypted
    under a password derived from (but different to) the real one. The
    real package is inserted at a position computed from a PHE digest of
    the password, so the position is never stored.

    Args:
        real_metadata: Real encrypted metadata
        password: Password (used to determine real index)
        num_decoys: Number of decoy snapshots (3-5 suggested; the range
            is not enforced). A negative value yields no decoys.

    Returns:
        Dict with 'vectors' (decoys plus the real package) and
        'num_vectors' (its length)
    """
    decoys = []

    for i in range(num_decoys):
        # Random seed so every decoy lattice is unrelated to the real one
        fake_seed = secrets.token_bytes(32)

        # Create fake CEL snapshot
        # NOTE(review): decoys use a smaller lattice and are never
        # differential-encoded, which may make them distinguishable from
        # the real package - confirm this is acceptable
        fake_cel = ContinuousEntropyLattice(lattice_size=64, depth=4)
        fake_cel.init(fake_seed)
        fake_cel.update({'operation': f'decoy_{i}'})
        fake_snapshot = fake_cel.snapshot()

        # Encrypt under a password derivative, not the real password
        fake_password = f"decoy_{i}_{password}_fake"
        fake_encrypted = encrypt_metadata(
            {'cel_snapshot': fake_snapshot},
            fake_password,
            use_differential=False
        )

        decoys.append(fake_encrypted)

    # Determine real index from password hash
    phe = ProbabilisticHashingEngine()
    phe_hash = phe.digest(password.encode('utf-8'))

    # Reduce by the number of vectors actually produced. This matches the
    # 'num_vectors' modulus used on extraction, and avoids a modulo-by-zero
    # when num_decoys is -1 (range() yields nothing, so there is one slot).
    slot_count = len(decoys) + 1
    real_index = int.from_bytes(phe_hash[-4:], 'big') % slot_count

    # Interleave decoys with real metadata
    all_vectors = decoys[:real_index] + [real_metadata] + decoys[real_index:]

    return {
        'vectors': all_vectors,
        'num_vectors': len(all_vectors),
        # Do NOT store real_index - must derive it during decryption
    }


def extract_real_vector(
    obfuscated: Dict[str, Any],
    password: str
) -> Dict[str, Any]:
    """
    Pick the real package out of a decoy-padded vector list (v0.2.0)

    The position is recomputed from a PHE digest of the password: its
    last four entries, read as a big-endian integer, reduced modulo
    ``obfuscated['num_vectors']``.

    Args:
        obfuscated: Obfuscated metadata with decoys
        password: Password

    Returns:
        The entry of ``obfuscated['vectors']`` at the derived position
    """
    digest = ProbabilisticHashingEngine().digest(password.encode('utf-8'))

    # Same derivation the injector uses, so no index needs to be stored
    slot = int.from_bytes(digest[-4:], 'big') % obfuscated['num_vectors']

    return obfuscated['vectors'][slot]
