"""
TLV (Type-Length-Value) Binary Format for Metadata Serialization
Replaces JSON for ~25% size reduction and versioning support

Format: [Type:1 byte][Length:4 bytes big-endian][Value:variable]
"""

import numpy as np
from typing import Dict, Any, Union


# TLV Type Definitions
# One-byte tags written in front of every field (layout in the module
# docstring). The deserializer dispatches on these values, so they are part
# of the binary format.
TLV_TYPE_METADATA_VERSION = 0x00  # Format version, 1 byte; always the first field written
TLV_TYPE_CEL_SNAPSHOT = 0x01  # CEL snapshot: 19-byte header, optionally followed by lattice data
TLV_TYPE_ORIGINAL_LENGTH = 0x02  # 8-byte big-endian unsigned integer
TLV_TYPE_WAS_STRING = 0x03  # Boolean flag (single byte, 0 or 1)
TLV_TYPE_PHE_HASH = 0x04  # Raw hash bytes (hex strings are converted before writing)
TLV_TYPE_METADATA_MAC = 0x05  # Raw MAC bytes (hex strings are converted before writing)
TLV_TYPE_DECOY_SNAPSHOT = 0x06  # Full CEL snapshot; one field per decoy, may repeat
TLV_TYPE_DIFFERENTIAL_DELTAS = 0x07  # 4-byte count followed by 13-byte delta records
TLV_TYPE_EPHEMERAL_SEED = 0x08  # 4-byte big-endian unsigned integer
# The next three tags are not referenced by the serializer or deserializer
# in the visible code; the corresponding values travel inside the CEL
# snapshot header instead.
TLV_TYPE_CEL_SEED_FINGERPRINT = 0x09
TLV_TYPE_CEL_OPERATION_COUNT = 0x0A
TLV_TYPE_CEL_STATE_VERSION = 0x0B
TLV_TYPE_ENCRYPTED_METADATA = 0x0C  # Encrypted metadata blob (written as-is)
TLV_TYPE_DIFFERENTIAL_ENCODED = 0x0D  # Boolean flag (single byte, 0 or 1)
TLV_TYPE_OBFUSCATED = 0x0E  # Boolean flag for decoy presence (single byte, 0 or 1)
TLV_TYPE_NUM_VECTORS = 0x0F  # Number of obfuscated vectors, 4-byte big-endian unsigned integer
TLV_TYPE_VECTOR = 0x10  # Single obfuscated vector: a nested TLV stream; may repeat

# Metadata format versions
# Returned by detect_metadata_version(); METADATA_VERSION_TLV is also the
# default value stored in the TLV_TYPE_METADATA_VERSION field.
METADATA_VERSION_JSON = 0x00  # v0.1.x JSON format
METADATA_VERSION_TLV = 0x01   # v0.2.0 TLV format


def serialize_metadata_tlv(metadata: Dict[str, Any], version: int = METADATA_VERSION_TLV) -> bytes:
    """
    Serialize metadata to TLV binary format

    The version field is always written first. Every other field is written
    only when its key is present in ``metadata``, and each key is written at
    most once (``vectors`` and ``decoy_snapshots`` produce one field per list
    entry). Fields are emitted in a fixed order; the deserializer in this
    module does not depend on that order.

    Args:
        metadata: Metadata dictionary
        version: Metadata format version (must fit in one byte)

    Returns:
        Binary TLV-encoded bytes

    Raises:
        OverflowError: If an integer value does not fit its fixed field width
    """
    buffer = bytearray()

    def write_uint(key: str, tlv_type: int, width: int) -> None:
        # int() lets NumPy integer scalars through; they have no to_bytes()
        if key in metadata:
            _write_tlv_field(buffer, tlv_type, int(metadata[key]).to_bytes(width, 'big'))

    def write_flag(key: str, tlv_type: int) -> None:
        # Flags are stored as a single byte: 1 for truthy, 0 for falsy
        if key in metadata:
            _write_tlv_field(buffer, tlv_type, bytes([1 if metadata[key] else 0]))

    def write_digest(key: str, tlv_type: int) -> None:
        # Digests may arrive as raw bytes or as a hex string
        if key in metadata:
            digest = metadata[key]
            if isinstance(digest, str):
                digest = bytes.fromhex(digest)
            _write_tlv_field(buffer, tlv_type, digest)

    # Always write version first
    _write_tlv_field(buffer, TLV_TYPE_METADATA_VERSION, int(version).to_bytes(1, 'big'))

    # Encrypted metadata blob (if present - from encrypt_metadata())
    if 'encrypted_metadata' in metadata:
        _write_tlv_field(buffer, TLV_TYPE_ENCRYPTED_METADATA, metadata['encrypted_metadata'])

    # Ephemeral seed and metadata MAC are written here only. Writing them a
    # second time further down added bytes without adding information: the
    # deserializer keeps a single value per key.
    write_uint('ephemeral_seed', TLV_TYPE_EPHEMERAL_SEED, 4)
    write_digest('metadata_mac', TLV_TYPE_METADATA_MAC)

    write_flag('differential_encoded', TLV_TYPE_DIFFERENTIAL_ENCODED)
    write_flag('obfuscated', TLV_TYPE_OBFUSCATED)
    write_uint('num_vectors', TLV_TYPE_NUM_VECTORS, 4)

    # Obfuscated vectors (real + decoys)
    if 'vectors' in metadata:
        for vector in metadata['vectors']:
            # Each vector is itself a metadata dict - serialize recursively
            _write_tlv_field(buffer, TLV_TYPE_VECTOR, serialize_metadata_tlv(vector, version))

    # CEL snapshot (differential or full)
    if 'cel_snapshot' in metadata:
        cel_snapshot = metadata['cel_snapshot']

        if cel_snapshot.get('differential'):
            # Header only (no lattice); the deltas follow in their own field
            _write_tlv_field(buffer, TLV_TYPE_CEL_SNAPSHOT,
                             _serialize_cel_snapshot_header(cel_snapshot))
            _write_tlv_field(buffer, TLV_TYPE_DIFFERENTIAL_DELTAS,
                             _serialize_differential_deltas(cel_snapshot.get('deltas', [])))
        else:
            # Full snapshot with lattice
            _write_tlv_field(buffer, TLV_TYPE_CEL_SNAPSHOT,
                             _serialize_cel_snapshot(cel_snapshot))

    # Differential deltas at top level (backward compat); only written when
    # there is no cel_snapshot key
    elif 'differential_deltas' in metadata:
        _write_tlv_field(buffer, TLV_TYPE_DIFFERENTIAL_DELTAS,
                         _serialize_differential_deltas(metadata['differential_deltas']))

    write_uint('original_length', TLV_TYPE_ORIGINAL_LENGTH, 8)
    write_flag('was_string', TLV_TYPE_WAS_STRING)
    write_digest('phe_hash', TLV_TYPE_PHE_HASH)

    # Decoy snapshots (always serialized as full snapshots)
    if 'decoy_snapshots' in metadata:
        for decoy in metadata['decoy_snapshots']:
            _write_tlv_field(buffer, TLV_TYPE_DECOY_SNAPSHOT, _serialize_cel_snapshot(decoy))

    return bytes(buffer)


def deserialize_metadata_tlv(data: bytes) -> Dict[str, Any]:
    """
    Deserialize TLV binary format to metadata dictionary

    Fields may appear in any order. If a non-repeating field occurs more
    than once the last occurrence wins; vector and decoy-snapshot fields are
    collected into lists. Unknown field types are skipped. Fewer than five
    trailing bytes (an incomplete type + length header) are ignored.

    Args:
        data: TLV-encoded bytes

    Returns:
        Metadata dictionary

    Raises:
        ValueError: If a field's declared length runs past the end of
            ``data``, if a boolean field has an empty value, or if a CEL
            snapshot field is shorter than its header
    """
    # Field types grouped by how their value is decoded
    uint_fields = {
        TLV_TYPE_METADATA_VERSION: 'metadata_version',
        TLV_TYPE_ORIGINAL_LENGTH: 'original_length',
        TLV_TYPE_EPHEMERAL_SEED: 'ephemeral_seed',
        TLV_TYPE_NUM_VECTORS: 'num_vectors',
    }
    raw_fields = {
        TLV_TYPE_PHE_HASH: 'phe_hash',
        TLV_TYPE_METADATA_MAC: 'metadata_mac',
        TLV_TYPE_ENCRYPTED_METADATA: 'encrypted_metadata',
    }
    flag_fields = {
        TLV_TYPE_WAS_STRING: 'was_string',
        TLV_TYPE_DIFFERENTIAL_ENCODED: 'differential_encoded',
        TLV_TYPE_OBFUSCATED: 'obfuscated',
    }

    metadata: Dict[str, Any] = {}
    decoy_snapshots = []
    vectors = []
    offset = 0
    total = len(data)

    # Each field needs at least 5 bytes: 1 type byte + 4 length bytes
    while offset + 5 <= total:
        tlv_type = data[offset]
        length = int.from_bytes(data[offset + 1:offset + 5], 'big')
        offset += 5

        if offset + length > total:
            raise ValueError(f"Invalid TLV: length {length} exceeds remaining data")

        value_data = data[offset:offset + length]
        offset += length

        if tlv_type in uint_fields:
            metadata[uint_fields[tlv_type]] = int.from_bytes(value_data, 'big')

        elif tlv_type in raw_fields:
            metadata[raw_fields[tlv_type]] = value_data

        elif tlv_type in flag_fields:
            # Indexing an empty value would raise IndexError; report it the
            # same way as other malformed input instead
            if not value_data:
                raise ValueError(
                    f"Invalid TLV: empty value for flag field 0x{tlv_type:02X}")
            metadata[flag_fields[tlv_type]] = bool(value_data[0])

        elif tlv_type == TLV_TYPE_CEL_SNAPSHOT:
            metadata['cel_snapshot'] = _deserialize_cel_snapshot(value_data)

        elif tlv_type == TLV_TYPE_DIFFERENTIAL_DELTAS:
            metadata['differential_deltas'] = _deserialize_differential_deltas(value_data)

        elif tlv_type == TLV_TYPE_VECTOR:
            # Each vector is a nested TLV stream - deserialize recursively
            vectors.append(deserialize_metadata_tlv(value_data))

        elif tlv_type == TLV_TYPE_DECOY_SNAPSHOT:
            decoy_snapshots.append(_deserialize_cel_snapshot(value_data))

    # Post-processing: fold differential_deltas into cel_snapshot if both present
    if 'cel_snapshot' in metadata and 'differential_deltas' in metadata:
        cel_snapshot = metadata['cel_snapshot']
        deltas = metadata.pop('differential_deltas')  # Remove from top level

        cel_snapshot['differential'] = True
        cel_snapshot['deltas'] = deltas
        cel_snapshot['num_deltas'] = len(deltas)

    # List-valued keys are only added when at least one entry was read
    if decoy_snapshots:
        metadata['decoy_snapshots'] = decoy_snapshots

    if vectors:
        metadata['vectors'] = vectors

    return metadata


def _write_tlv_field(buffer: bytearray, tlv_type: int, value: bytes) -> None:
    """Append one [type:1][length:4 big-endian][value] record to buffer"""
    buffer.extend((tlv_type,))
    length_prefix = len(value).to_bytes(4, 'big')
    buffer.extend(length_prefix)
    buffer.extend(value)


def _serialize_cel_snapshot_header(snapshot: Dict[str, Any]) -> bytes:
    """
    Serialize only CEL snapshot header (for differential snapshots)

    Format (19 bytes, integers big-endian), with the default used when a
    key is missing:
    - lattice_size: 2 bytes unsigned (default 256)
    - depth: 1 byte unsigned (default 8)
    - seed_fingerprint: 8 bytes signed (default 0)
    - operation_count: 4 bytes unsigned (default 0)
    - state_version: 4 bytes unsigned (default 0)

    Values are coerced with int() so NumPy integer scalars, which have no
    to_bytes() method, are accepted as well as Python ints.

    Args:
        snapshot: CEL snapshot dictionary

    Returns:
        The 19-byte header

    Raises:
        OverflowError: If a multi-byte value does not fit its field
        ValueError: If depth is outside 0-255
    """
    buffer = bytearray()

    buffer.extend(int(snapshot.get('lattice_size', 256)).to_bytes(2, 'big'))
    buffer.append(int(snapshot.get('depth', 8)))
    buffer.extend(int(snapshot.get('seed_fingerprint', 0)).to_bytes(8, 'big', signed=True))
    buffer.extend(int(snapshot.get('operation_count', 0)).to_bytes(4, 'big'))
    buffer.extend(int(snapshot.get('state_version', 0)).to_bytes(4, 'big'))

    return bytes(buffer)


def _serialize_cel_snapshot(snapshot: Dict[str, Any]) -> bytes:
    """
    Serialize CEL snapshot to compact binary format

    Format:
    - 19-byte header as written by _serialize_cel_snapshot_header
    - lattice data: variable (compressed), present only when
      snapshot['lattice'] is a NumPy array

    A missing lattice, or a lattice that is not a NumPy array, is skipped
    and only the header is returned.

    Args:
        snapshot: CEL snapshot dictionary

    Returns:
        Header bytes, followed by the compressed lattice if there is one
    """
    # Reuse the header writer so both snapshot forms share one layout
    buffer = bytearray(_serialize_cel_snapshot_header(snapshot))

    lattice = snapshot.get('lattice')
    if isinstance(lattice, np.ndarray):
        # Flatten to 1-D, then apply the variable-length integer encoding
        buffer.extend(_compress_int_array(lattice.flatten()))

    return bytes(buffer)


def _deserialize_cel_snapshot(data: bytes) -> Dict[str, Any]:
    """
    Deserialize CEL snapshot from binary format

    Reads the fixed 19-byte header and, when more bytes follow, decodes
    them into a lattice of shape (depth, lattice_size, lattice_size).

    Args:
        data: Binary snapshot data

    Returns:
        CEL snapshot dictionary; 'lattice' is None for a header-only snapshot

    Raises:
        ValueError: If data is shorter than the header
    """
    header_size = 19  # 2 + 1 + 8 + 4 + 4
    if len(data) < header_size:
        raise ValueError("Invalid CEL snapshot: too short")

    # Header fields sit at fixed positions
    size = int.from_bytes(data[0:2], 'big')
    layers = data[2]
    fingerprint = int.from_bytes(data[3:11], 'big', signed=True)
    op_count = int.from_bytes(data[11:15], 'big')
    version = int.from_bytes(data[15:19], 'big')

    # Anything after the header is the compressed lattice
    if len(data) > header_size:
        cells = _decompress_int_array(data[header_size:], layers * size * size)
        grid = cells.reshape((layers, size, size))
    else:
        grid = None

    return {
        'lattice_size': size,
        'depth': layers,
        'seed_fingerprint': fingerprint,
        'operation_count': op_count,
        'state_version': version,
        'lattice': grid
    }


def _serialize_differential_deltas(deltas: list) -> bytes:
    """
    Serialize differential deltas to binary format

    Format: a 4-byte big-endian count, then for each delta
    [layer:1][row:2][col:2][delta:8 signed], all big-endian.

    Every component is coerced with int(), so NumPy integer scalars (which
    have no to_bytes() method) are accepted for layer, row and col as well
    as for delta.

    Args:
        deltas: Sequence of (layer, row, col, delta) tuples

    Returns:
        Encoded bytes: 4 + 13 * len(deltas) bytes long

    Raises:
        OverflowError: If row, col or delta does not fit its field
        ValueError: If layer is outside 0-255
    """
    buffer = bytearray()

    # Write number of deltas
    buffer.extend(len(deltas).to_bytes(4, 'big'))

    # Write each delta as a fixed 13-byte record
    for layer, row, col, delta in deltas:
        buffer.append(int(layer))
        buffer.extend(int(row).to_bytes(2, 'big'))
        buffer.extend(int(col).to_bytes(2, 'big'))
        buffer.extend(int(delta).to_bytes(8, 'big', signed=True))

    return bytes(buffer)


def _deserialize_differential_deltas(data: bytes) -> list:
    """
    Deserialize differential deltas from binary format

    Reads the 4-byte count, then as many complete 13-byte records
    ([layer:1][row:2][col:2][delta:8 signed]) as are both declared and
    actually present. Input shorter than the count field yields [].
    """
    record_size = 13
    if len(data) < 4:
        return []

    declared = int.from_bytes(data[:4], 'big')
    # A trailing partial record is ignored, so only whole records count
    available = (len(data) - 4) // record_size

    records = []
    for index in range(min(declared, available)):
        start = 4 + index * record_size
        record = data[start:start + record_size]
        records.append((
            record[0],
            int.from_bytes(record[1:3], 'big'),
            int.from_bytes(record[3:5], 'big'),
            int.from_bytes(record[5:13], 'big', signed=True),
        ))

    return records


# Byte that introduces a zero run. Followed by a varint count >= 1 it means
# "that many zeros"; followed by a zero count it escapes the literal that
# comes next (see _compress_int_array).
_RLE_MARKER = 0xFF
# Zero runs shorter than this are written as literal 0x00 bytes
_RLE_MIN_RUN = 3


def _compress_int_array(arr: np.ndarray) -> bytes:
    """
    Compress integer array using run-length and variable-length encoding

    Layout: a 4-byte big-endian element count, then a sequence of tokens:
    - 0xFF + varint(n), n >= 3: a run of n zeros
    - 0xFF 0x00 + signed varint: an escaped literal
    - signed varint (zigzag + LEB128-style): any other value

    The escape is needed because the varint of some values starts with the
    byte 0xFF (for example -128 encodes as FF 01), which the decoder would
    otherwise take for the run marker. A zero run count was never produced
    before the escape existed, so previously written streams decode exactly
    as they did.

    Args:
        arr: Array of integers; flattened in C order before encoding

    Returns:
        Compressed bytes
    """
    values = [int(v) for v in np.asarray(arr).ravel().tolist()]
    total = len(values)

    # Write array length
    buffer = bytearray(total.to_bytes(4, 'big'))

    i = 0
    while i < total:
        val = values[i]

        if val == 0:
            # Measure the whole run of zeros starting here
            run_end = i + 1
            while run_end < total and values[run_end] == 0:
                run_end += 1
            run_length = run_end - i

            if run_length >= _RLE_MIN_RUN:
                buffer.append(_RLE_MARKER)
                _write_varint(buffer, run_length)
            else:
                # zigzag(0) is the single byte 0x00
                buffer.extend(bytes(run_length))
            i = run_end
            continue

        literal = bytearray()
        _write_signed_varint(literal, val)
        if literal[0] == _RLE_MARKER:
            # Escape: marker + zero count tells the decoder a literal follows
            buffer.append(_RLE_MARKER)
            buffer.append(0)
        buffer.extend(literal)
        i += 1

    return bytes(buffer)


def _decompress_int_array(data: bytes, expected_length: int) -> np.ndarray:
    """
    Decompress integer array from variable-length encoding with RLE

    Decoding stops once min(declared count, expected_length) elements have
    been produced or the input is exhausted. The result is then padded with
    zeros to exactly expected_length elements.

    Args:
        data: Bytes produced by _compress_int_array
        expected_length: Number of elements the caller needs

    Returns:
        1-D int64 array with expected_length elements
    """
    if len(data) < 4:
        return np.zeros(expected_length, dtype=np.int64)

    declared_length = int.from_bytes(data[0:4], 'big')
    target = min(declared_length, expected_length)
    offset = 4

    values = []
    while len(values) < target and offset < len(data):
        if data[offset] == _RLE_MARKER:
            offset += 1
            count, bytes_read = _read_varint(data, offset)
            offset += bytes_read

            if count > 0:
                # Zero run. Clamp it: the count comes from the input and
                # must not be allowed to drive an arbitrarily large
                # allocation. Anything past the target would be cut off
                # below anyway.
                values.extend([0] * min(count, target - len(values)))
                continue

            if bytes_read == 0 or offset >= len(data):
                break  # Input ended right after the marker or the escape

            # Zero count: escaped literal follows
            val, bytes_read = _read_signed_varint(data, offset)
            offset += bytes_read
            values.append(val)
        else:
            # Read signed varint
            val, bytes_read = _read_signed_varint(data, offset)
            offset += bytes_read
            values.append(val)

    # Pad if necessary
    values.extend([0] * (expected_length - len(values)))

    return np.array(values[:expected_length], dtype=np.int64)


def _write_varint(buffer: bytearray, value: int) -> None:
    """
    Write unsigned integer in LEB128 varint format

    Seven payload bits per byte, least significant group first; the high
    bit of each byte is set when more bytes follow.

    Raises:
        ValueError: If value is negative (it has no unsigned encoding)
    """
    if value < 0:
        raise ValueError(f"Cannot encode negative value {value} as unsigned varint")

    while value >= 0x80:
        buffer.append((value & 0x7F) | 0x80)
        value >>= 7
    buffer.append(value)


def _read_varint(data: bytes, offset: int) -> tuple:
    """
    Read unsigned varint from data at offset
    Returns (value, bytes_read)

    If the data ends before a terminating byte (high bit clear) is found,
    the bits read so far are returned; bytes_read is 0 when offset is
    already at or past the end of data.
    """
    result = 0
    shift = 0
    bytes_read = 0

    while offset + bytes_read < len(data):
        byte = data[offset + bytes_read]
        bytes_read += 1

        result |= (byte & 0x7F) << shift
        if not (byte & 0x80):
            break
        shift += 7

    return result, bytes_read


def _write_signed_varint(buffer: bytearray, value: int) -> None:
    """
    Write signed integer using zigzag encoding + varint
    Zigzag maps signed to unsigned: 0,-1,1,-2,2... → 0,1,2,3,4...
    """
    # ~(v << 1) equals -2v - 1, the odd code for a negative v, for any magnitude
    zigzag = (value << 1) if value >= 0 else ~(value << 1)
    _write_varint(buffer, zigzag)


def _read_signed_varint(data: bytes, offset: int) -> tuple:
    """
    Read signed varint using zigzag decoding
    Returns (value, bytes_read)
    """
    zigzag, bytes_read = _read_varint(data, offset)
    # Even codes are non-negative values, odd codes negative ones
    value = (zigzag >> 1) ^ -(zigzag & 1)
    return value, bytes_read


def detect_metadata_version(data: Union[bytes, str]) -> int:
    """
    Detect metadata format version

    Detection rules, in order:
    1. A str is JSON.
    2. Bytes longer than 5 bytes whose first byte is the TLV version tag
       are TLV.
    3. Any other bytes are JSON if they decode as UTF-8, TLV otherwise.
    4. Any other input type is reported as JSON.

    Args:
        data: Metadata bytes (bytes or bytearray) or JSON string

    Returns:
        METADATA_VERSION_JSON (0x00) or METADATA_VERSION_TLV (0x01)
    """
    if isinstance(data, str):
        # String data is JSON (v0.1.x)
        return METADATA_VERSION_JSON

    if isinstance(data, (bytes, bytearray)):
        # A TLV stream opens with the version field: tag, 4-byte length and
        # a 1-byte value, i.e. at least 6 bytes
        if len(data) > 5 and data[0] == TLV_TYPE_METADATA_VERSION:
            return METADATA_VERSION_TLV

        # Only a decoding failure means "not JSON"; anything else should
        # propagate rather than be swallowed
        try:
            data.decode('utf-8')
        except UnicodeDecodeError:
            return METADATA_VERSION_TLV
        return METADATA_VERSION_JSON

    return METADATA_VERSION_JSON
