"""
State Management Module
Handles meta-state persistence, session reconstruction, and reproducibility

Stores compact persistence vectors (seed fingerprints, CEL snapshots,
meta-state signatures) and regenerates states deterministically
from compact data.

Key principles:
- State data never includes direct key material
- Versioned state format: a stored state is loadable only when its major version matches the current one
- Compact storage format
- Deterministic reconstruction
"""

import json
import numpy as np
from typing import Dict, Any, Optional, List
from datetime import datetime
import base64


class StateManager:
    """
    STATE - Persistence and reconstruction manager

    Handles saving and loading of complete STC context states,
    enabling deterministic reproduction without exposing secrets.
    Only a fingerprint of the seed is persisted, never the seed itself.
    """

    def __init__(self):
        """Initialize State Manager"""
        # Version stamped into every saved vector; load() accepts vectors
        # whose major version component matches this one.
        self.state_version = "0.1.0"
        # Most recent vector produced by save() or context produced by load().
        self.current_context: Optional[Dict[str, Any]] = None

    def save(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Store reproducible state vector

        Per STATE contract: STATE.save(context) → store reproducible state vector

        Args:
            context: Complete context dictionary containing:
                    - 'cel': CEL instance or snapshot
                    - 'phe': PHE instance (optional; not persisted by save())
                    - 'pcf': PCF instance (optional)
                    - 'seed': Original seed (only its fingerprint is stored)
                    - 'metadata': Optional metadata

        Returns:
            Compact persistence vector (also stored as current_context)
        """
        metadata = context.get('metadata', {})
        if isinstance(metadata, dict):
            # Shallow-copy so later mutation of the caller's dict does not
            # silently alter the stored persistence vector.
            metadata = dict(metadata)

        persistence_vector = {
            'version': self.state_version,
            'timestamp': datetime.now().isoformat(),
            'cel_state': self._extract_cel_state(context.get('cel')),
            'pcf_state': self._extract_pcf_state(context.get('pcf')),
            'seed_fingerprint': self._compute_seed_fingerprint(context.get('seed')),
            'metadata': metadata,
        }

        self.current_context = persistence_vector
        return persistence_vector

    def load(self, vector: Dict[str, Any]) -> Dict[str, Any]:
        """
        Reconstruct context from persistence vector

        Per STATE contract: STATE.load(vector) → reconstruct context

        Args:
            vector: Persistence vector from save()

        Returns:
            Reconstructed context dictionary (also stored as current_context).
            Missing optional fields come back as None ('metadata' as {}).

        Raises:
            ValueError: If the vector's version is missing or its major
                version differs from this manager's.
        """
        if not self._is_compatible_version(vector.get('version')):
            raise ValueError(f"Incompatible state version: {vector.get('version')}")

        context = {
            'version': vector['version'],
            'timestamp': vector.get('timestamp'),
            'cel_state': vector.get('cel_state'),
            'pcf_state': vector.get('pcf_state'),
            'seed_fingerprint': vector.get('seed_fingerprint'),
            'metadata': vector.get('metadata', {}),
        }

        self.current_context = context
        return context

    def sync(
        self,
        phe_instance: Any = None,
        cel_instance: Any = None,
        pcf_instance: Any = None
    ) -> Dict[str, Any]:
        """
        Synchronize all modules and create unified state

        Per STATE contract: STATE.sync(PHE, CEL, PCF) → synchronize all modules

        Args:
            phe_instance: PHE instance (optional); its trace() is recorded
            cel_instance: CEL instance (optional); its snapshot() is recorded
                as returned (numpy arrays are not converted here)
            pcf_instance: PCF instance (optional); its export_state() is recorded

        Returns:
            Synchronized state dictionary with a 'modules' entry per
            supplied (non-None) instance
        """
        sync_state = {
            'timestamp': datetime.now().isoformat(),
            'modules': {}
        }

        if cel_instance is not None:
            if hasattr(cel_instance, 'snapshot'):
                sync_state['modules']['cel'] = cel_instance.snapshot()
            else:
                sync_state['modules']['cel'] = self._extract_cel_state(cel_instance)

        if phe_instance is not None:
            if hasattr(phe_instance, 'trace'):
                sync_state['modules']['phe'] = phe_instance.trace()
            else:
                sync_state['modules']['phe'] = {}

        if pcf_instance is not None:
            if hasattr(pcf_instance, 'export_state'):
                sync_state['modules']['pcf'] = pcf_instance.export_state()
            else:
                sync_state['modules']['pcf'] = self._extract_pcf_state(pcf_instance)

        return sync_state

    def _extract_cel_state(self, cel: Any) -> Dict[str, Any]:
        """
        Extract CEL state for persistence

        A numpy 'lattice' entry is converted to a nested list so the result
        is JSON-friendly. The input (or the instance's snapshot) is copied
        first and never modified.

        Args:
            cel: CEL instance or snapshot

        Returns:
            CEL state dictionary ({} when cel is None or unrecognized)
        """
        if cel is None:
            return {}

        if isinstance(cel, dict):
            state = cel.copy()
        elif hasattr(cel, 'snapshot'):
            state = cel.snapshot()
            if not isinstance(state, dict):
                # Unknown snapshot shape: pass it through untouched.
                return state
            # Copy so the lattice conversion below never mutates a dict the
            # CEL instance may still hold internally.
            state = state.copy()
        else:
            return {}

        if isinstance(state.get('lattice'), np.ndarray):
            state['lattice'] = state['lattice'].tolist()
        return state

    def _extract_pcf_state(self, pcf: Any) -> Dict[str, Any]:
        """
        Extract PCF state for persistence

        Args:
            pcf: PCF instance or state dict

        Returns:
            PCF state dictionary ({} when pcf is None or unrecognized)
        """
        if pcf is None:
            return {}

        if isinstance(pcf, dict):
            return pcf.copy()

        if hasattr(pcf, 'export_state'):
            return pcf.export_state()

        return {}

    @staticmethod
    def _seed_to_bytes(seed: Any) -> bytes:
        """
        Canonical byte encoding of a seed for fingerprinting.

        str -> UTF-8; bytes -> unchanged; non-negative int -> minimal
        big-endian unsigned bytes (0 -> empty bytes); negative int -> a single
        0x00 byte followed by the big-endian bytes of its magnitude;
        anything else -> UTF-8 of str(seed).
        """
        if isinstance(seed, str):
            return seed.encode('utf-8')
        if isinstance(seed, bytes):
            return seed
        if isinstance(seed, int):
            magnitude = abs(seed)
            raw = magnitude.to_bytes((magnitude.bit_length() + 7) // 8, 'big')
            if seed < 0:
                # Unsigned int.to_bytes rejects negatives. Minimal unsigned
                # encodings of non-negative ints never begin with 0x00, so the
                # prefix keeps negative seeds distinct from all of them.
                return b'\x00' + raw
            return raw
        return str(seed).encode('utf-8')

    def _compute_seed_fingerprint(self, seed: Any) -> Optional[int]:
        """
        Compute fingerprint of seed without storing seed itself

        Args:
            seed: Seed value (str, bytes, int, or anything with a str())

        Returns:
            Seed fingerprint, or None when seed is None
        """
        if seed is None:
            return None

        # Local import: only reached when a seed is actually supplied.
        from utils.math_primitives import data_fingerprint_entropy

        return data_fingerprint_entropy(self._seed_to_bytes(seed))

    def _is_compatible_version(self, version: Optional[str]) -> bool:
        """
        Check if state version is compatible

        Compatible means the leading (major) component of a dotted version
        string equals this manager's major version.

        Args:
            version: State version string

        Returns:
            True if compatible; False for None, non-string or malformed input
        """
        if version is None:
            return False

        try:
            current_major = int(self.state_version.split('.')[0])
            state_major = int(version.split('.')[0])
        except (ValueError, IndexError, AttributeError):
            return False
        return current_major == state_major

    def serialize(self, state: Dict[str, Any]) -> str:
        """
        Serialize state to JSON string

        numpy arrays/scalars become lists/Python numbers; bytes become
        base64 text (and are NOT decoded back by deserialize()).

        Args:
            state: State dictionary

        Returns:
            JSON string (indented by 2 spaces)

        Raises:
            TypeError: For values with no JSON representation.
        """
        def custom_encoder(obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, bytes):
                return base64.b64encode(obj).decode('utf-8')
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        return json.dumps(state, default=custom_encoder, indent=2)

    def deserialize(self, json_string: str) -> Dict[str, Any]:
        """
        Deserialize state from JSON string

        A list-valued 'cel_state.lattice' is turned back into a numpy array.
        A null or missing 'cel_state' is tolerated and left as-is.

        Args:
            json_string: JSON string

        Returns:
            State dictionary
        """
        state = json.loads(json_string)

        cel_state = state.get('cel_state') if isinstance(state, dict) else None
        if isinstance(cel_state, dict) and isinstance(cel_state.get('lattice'), list):
            cel_state['lattice'] = np.array(cel_state['lattice'])

        return state

    def save_to_file(self, state: Dict[str, Any], filepath: str) -> None:
        """
        Save state to file (UTF-8 encoded JSON)

        Args:
            state: State dictionary
            filepath: File path for saving
        """
        json_str = self.serialize(state)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(json_str)

    def load_from_file(self, filepath: str) -> Dict[str, Any]:
        """
        Load state from file (UTF-8 encoded JSON)

        Args:
            filepath: File path to load from

        Returns:
            State dictionary
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            json_str = f.read()

        return self.deserialize(json_str)

    def validate_state(self, state: Dict[str, Any]) -> bool:
        """
        Validate state structure

        Args:
            state: State dictionary to validate

        Returns:
            True if state is a dict carrying a compatible 'version'
        """
        if not isinstance(state, dict):
            return False

        required_fields = ['version']
        if any(field not in state for field in required_fields):
            return False

        return self._is_compatible_version(state['version'])

    def get_state_summary(self, state: Optional[Dict[str, Any]] = None) -> str:
        """
        Get human-readable state summary

        Args:
            state: Optional state dictionary (uses current if None)

        Returns:
            Summary string ("No state available" when there is nothing to show)
        """
        if state is None:
            state = self.current_context

        if state is None:
            return "No state available"

        lines = [
            "=== State Summary ===",
            f"Version: {state.get('version', 'Unknown')}",
            f"Timestamp: {state.get('timestamp', 'Unknown')}",
            "",
            "Components:",
        ]

        if 'cel_state' in state and state['cel_state']:
            cel = state['cel_state']
            lines.append("  CEL:")
            lines.append(f"    Lattice Size: {cel.get('lattice_size', 'N/A')}")
            lines.append(f"    Depth: {cel.get('depth', 'N/A')}")
            lines.append(f"    Operation Count: {cel.get('operation_count', 'N/A')}")
            lines.append(f"    State Version: {cel.get('state_version', 'N/A')}")

        if 'pcf_state' in state and state['pcf_state']:
            pcf = state['pcf_state']
            lines.append("  PCF:")
            lines.append(f"    Morph Version: {pcf.get('morph_version', 'N/A')}")
            lines.append(f"    Operation Count: {pcf.get('operation_count', 'N/A')}")

        if 'metadata' in state and state['metadata']:
            lines.append("  Metadata:")
            for key, value in state['metadata'].items():
                lines.append(f"    {key}: {value}")

        return "\n".join(lines)


def create_state_manager() -> StateManager:
    """
    Factory for a fresh StateManager.

    Returns:
        A newly constructed StateManager (each call yields an independent
        instance with no current context)
    """
    manager = StateManager()
    return manager


def save_context(
    cel_instance: Any = None,
    pcf_instance: Any = None,
    seed: Any = None,
    metadata: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Save a complete context in one call using a throwaway StateManager.

    Args:
        cel_instance: CEL instance or snapshot
        pcf_instance: PCF instance or state dict
        seed: Original seed (only its fingerprint ends up in the vector)
        metadata: Optional metadata; any falsy value is replaced by {}

    Returns:
        Persistence vector produced by StateManager.save()
    """
    payload: Dict[str, Any] = dict(
        cel=cel_instance,
        pcf=pcf_instance,
        seed=seed,
        metadata=metadata or {},
    )
    return create_state_manager().save(payload)
