"""
Parallel Token Generator with vLLM-style optimizations
Implements one-shot generation through iterative refinement
"""
import torch
import torch.nn.functional as F
from typing import Optional, List, Dict, Tuple
from dataclasses import dataclass
import math


@dataclass
class GenerationConfig:
    """Sampling and refinement settings for the parallel generator.

    Out-of-range values are not rejected: ``__post_init__`` rewrites them in
    place to safe fallbacks so construction never raises.
    """
    max_new_tokens: int = 512
    temperature: float = 1.0  # non-positive values are reset to 1.0
    top_k: int = 50  # 0 disables top-k filtering; negatives become 0
    top_p: float = 0.95  # clipped into [0, 1]; 1.0 disables nucleus filtering
    repetition_penalty: float = 1.0  # non-positive values are reset to 1.0
    num_refinement_steps: int = 5  # raised to at least 1
    confidence_threshold: float = 0.9  # clipped into [0, 1]
    use_adaptive_steps: bool = True
    batch_size: int = 1
    num_parallel_tokens: int = 64  # For compatibility with config from inference module
    use_torch_compile: bool = False # Use torch.compile for faster inference (requires PyTorch 2.0+)

    def __post_init__(self):
        """Coerce invalid settings to safe values instead of raising."""
        def clip_unit(value):
            # Clamp into the closed interval [0, 1].
            return max(0.0, min(1.0, value))

        # Strictly positive quantities fall back to a neutral 1.0.
        self.temperature = 1.0 if self.temperature <= 0 else self.temperature
        self.repetition_penalty = (
            1.0 if self.repetition_penalty <= 0 else self.repetition_penalty
        )

        # Integer lower bounds.
        self.top_k = 0 if self.top_k < 0 else self.top_k
        self.num_refinement_steps = (
            1 if self.num_refinement_steps < 1 else self.num_refinement_steps
        )

        # Probabilities / thresholds live in [0, 1].
        self.top_p = clip_unit(self.top_p)
        self.confidence_threshold = clip_unit(self.confidence_threshold)


class PagedKVCache:
    """
    Paged KV cache for memory-efficient attention
    Inspired by vLLM's PagedAttention

    Keys and values are stored in two pools allocated once, up front, with
    shape ``(num_layers, max_blocks, block_size, num_heads, head_dim)``.
    Each sequence owns a list of block indices into those pools; token
    position ``p`` of a sequence lives in block ``blocks[p // block_size]``
    at offset ``p % block_size``.
    """
    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int = 16,
        max_blocks: int = 2048,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.max_blocks = max_blocks
        self.dtype = dtype
        self.device = device

        # Pre-allocate the whole memory pool once; blocks are handed out by index.
        pool_shape = (num_layers, max_blocks, block_size, num_heads, head_dim)
        self.key_cache = torch.zeros(pool_shape, dtype=dtype, device=device)
        self.value_cache = torch.zeros(pool_shape, dtype=dtype, device=device)

        # Block allocation tracking
        self.free_blocks = list(range(max_blocks))
        self.allocated_blocks: Dict[int, List[int]] = {}  # sequence_id -> block_ids

    def allocate(self, sequence_id: int, num_tokens: int) -> List[int]:
        """Allocate enough blocks to hold ``num_tokens`` tokens for a sequence.

        If ``sequence_id`` already owns blocks, they are released first, so
        re-allocating a sequence never leaks pool capacity.

        Returns:
            The block indices now owned by the sequence.

        Raises:
            RuntimeError: if the pool cannot supply the required blocks. The
                sequence's existing allocation is left untouched in that case.
        """
        num_blocks_needed = max(0, (num_tokens + self.block_size - 1) // self.block_size)

        # Blocks already held by this sequence become available again on re-allocation.
        already_held = len(self.allocated_blocks.get(sequence_id, ()))
        if len(self.free_blocks) + already_held < num_blocks_needed:
            raise RuntimeError("Out of KV cache memory")

        self.free(sequence_id)
        blocks = [self.free_blocks.pop() for _ in range(num_blocks_needed)]
        self.allocated_blocks[sequence_id] = blocks
        return blocks

    def free(self, sequence_id: int):
        """Return a sequence's blocks to the free list (no-op if it owns none).

        Block contents are not cleared; a later owner sees stale data until it
        overwrites the slots.
        """
        blocks = self.allocated_blocks.pop(sequence_id, None)
        if blocks:
            self.free_blocks.extend(blocks)

    def get_kv(
        self, sequence_id: int, layer_idx: int
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Get KV cache for a sequence at a specific layer.

        Returns:
            ``(keys, values)`` each of shape
            ``(num_blocks, block_size, num_heads, head_dim)`` gathered (copied)
            from the sequence's blocks in allocation order, or ``(None, None)``
            if the sequence owns no blocks. Slots that were never written hold
            zeros or stale data.
        """
        blocks = self.allocated_blocks.get(sequence_id, [])
        if not blocks:
            return None, None

        # Advanced indexing with a list of block ids gathers a copy.
        keys = self.key_cache[layer_idx, blocks]
        values = self.value_cache[layer_idx, blocks]

        return keys, values

    def update_kv(
        self,
        sequence_id: int,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        token_pos: int,
    ):
        """Write one token's key/value into the sequence's paged storage.

        ``key`` / ``value`` must be broadcastable to ``(num_heads, head_dim)``.

        Raises:
            KeyError: if the sequence has no allocation.
            IndexError: if ``token_pos`` is negative or beyond the allocated
                capacity (a negative position would otherwise wrap around into
                the last block).
        """
        blocks = self.allocated_blocks[sequence_id]
        capacity = len(blocks) * self.block_size
        if not 0 <= token_pos < capacity:
            raise IndexError(
                f"token_pos {token_pos} out of range for sequence {sequence_id} "
                f"(capacity {capacity} tokens)"
            )

        block_idx, offset = divmod(token_pos, self.block_size)
        self.key_cache[layer_idx, blocks[block_idx], offset] = key
        self.value_cache[layer_idx, blocks[block_idx], offset] = value


class ParallelGenerator:
    """
    Parallel token generator based on iterative masked refinement.

    All new positions start as the model's mask token (``model.mask_token_id``,
    default 0). Each refinement step runs one forward pass over the whole
    sequence, samples a token for every position, and commits the still-masked
    positions whose confidence exceeds ``config.confidence_threshold``. Any
    position still masked after ``config.num_refinement_steps`` steps is filled
    greedily (argmax) by one final forward pass.

    Optional accelerations: ``torch.compile`` (``config.use_torch_compile``)
    and a single CUDA graph captured at a fixed input shape.
    """

    # Sequence length (prompt + new tokens) at which the CUDA graph is captured.
    _CUDA_GRAPH_SEQ_LEN = 512

    def __init__(
        self,
        model: torch.nn.Module,
        config: "GenerationConfig",
        use_kv_cache: bool = True,
        use_cuda_graphs: bool = True,
    ):
        """
        Args:
            model: denoiser called as ``model(tokens, timestep, **kwargs)`` and
                returning either logits or a ``(logits, confidence)`` tuple.
                Must expose ``model.config.vocab_size``; with ``use_kv_cache``
                also ``num_hidden_layers``, ``num_attention_heads`` and
                ``hidden_size``.
            config: sampling and refinement settings.
            use_kv_cache: pre-allocate a ``PagedKVCache`` as ``self.kv_cache``.
                ``generate`` does not read it.
            use_cuda_graphs: capture a CUDA graph for fixed-shape forward
                passes; turned off automatically when the model is not on a
                CUDA device.
        """
        self.model = model
        self.config = config
        self.use_kv_cache = use_kv_cache
        self.use_cuda_graphs = use_cuda_graphs
        self.graph = None

        # Setup KV cache if enabled
        if use_kv_cache:
            # NOTE(review): sized with num_attention_heads; models using
            # grouped-query attention may need num_key_value_heads — confirm.
            self.kv_cache = PagedKVCache(
                num_layers=model.config.num_hidden_layers,
                num_heads=model.config.num_attention_heads,
                head_dim=model.config.hidden_size // model.config.num_attention_heads,
                device=next(model.parameters()).device,
            )

        # Compile model for faster inference
        if config.use_torch_compile and hasattr(torch, "compile"):
            self.model = torch.compile(self.model)

        if use_cuda_graphs:
            self._setup_cuda_graphs()

    def _setup_cuda_graphs(self):
        """Capture one CUDA graph of a forward pass at a fixed input shape.

        The graph covers inputs of shape
        ``(config.batch_size, _CUDA_GRAPH_SEQ_LEN)``. If the model is not on a
        CUDA device, ``use_cuda_graphs`` is set to False and nothing is
        captured, because ``torch.cuda.CUDAGraph`` requires CUDA.
        """
        device = next(self.model.parameters()).device
        if device.type != "cuda" or not torch.cuda.is_available():
            self.use_cuda_graphs = False
            return

        # Static buffers: every replay reads these exact tensors and writes
        # ``static_output``, so inputs are copied in rather than rebound.
        # Token id 0 is always inside [0, vocab_size - 1].
        self.static_input = torch.zeros(
            (self.config.batch_size, self._CUDA_GRAPH_SEQ_LEN),
            dtype=torch.long,
            device=device,
        )
        # generate() builds long timesteps, so capture with the same dtype.
        self.static_timestep = torch.zeros(
            self.config.batch_size, dtype=torch.long, device=device
        )
        self.static_output = None

        with torch.no_grad():
            # Warm up at the capture shape on a side stream before capturing,
            # so one-time initialisation does not happen inside the capture.
            warmup_stream = torch.cuda.Stream(device=device)
            warmup_stream.wait_stream(torch.cuda.current_stream(device))
            with torch.cuda.stream(warmup_stream):
                for _ in range(3):
                    self.model(self.static_input, self.static_timestep)
            torch.cuda.current_stream(device).wait_stream(warmup_stream)

            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                output = self.model(self.static_input, self.static_timestep)
            self.static_output = output[0] if isinstance(output, tuple) else output
            self.graph = graph

    @torch.no_grad()
    def generate(
        self,
        prompt_tokens: torch.Tensor,
        max_new_tokens: Optional[int] = None,
        pixel_values: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate tokens in parallel

        Args:
            prompt_tokens: [batch, prompt_len] input tokens
            max_new_tokens: number of tokens to generate; ``None`` uses
                ``config.max_new_tokens`` (an explicit 0 generates nothing)
            pixel_values: [batch, channels, height, width] for multimodal generation
            **kwargs: additional arguments passed to the model

        Returns:
            generated_tokens: [batch, prompt_len + max_new_tokens]

        Raises:
            ValueError: if the prompt contains ids outside [0, vocab_size).
        """
        prompt_tokens = prompt_tokens.long()
        batch_size, prompt_len = prompt_tokens.shape
        if max_new_tokens is None:
            max_new_tokens = self.config.max_new_tokens
        device = prompt_tokens.device
        vocab_size = self.model.config.vocab_size

        # Out-of-range ids would index past the embedding table. Reductions are
        # skipped for an empty prompt (max()/min() of an empty tensor raise).
        if prompt_tokens.numel() > 0:
            if prompt_tokens.max() >= vocab_size:
                raise ValueError(
                    f"Prompt contains token IDs greater than model vocabulary size "
                    f"({self.model.config.vocab_size}). This will cause CUDA errors."
                )
            if prompt_tokens.min() < 0:
                raise ValueError("Prompt contains negative token IDs.")

        # Initialize full sequence with MASK tokens (using 0 as mask token)
        total_len = prompt_len + max_new_tokens
        mask_token_id = getattr(self.model, 'mask_token_id', 0)
        generated = torch.cat([
            prompt_tokens,
            torch.full(
                (batch_size, max_new_tokens),
                mask_token_id,
                dtype=torch.long,
                device=device,
            ),
        ], dim=1)

        # Track which positions are still masked
        is_masked = torch.zeros((batch_size, total_len), dtype=torch.bool, device=device)
        is_masked[:, prompt_len:] = True

        # Extra model inputs shared by every forward pass (including the final one).
        extra_kwargs = dict(kwargs)
        if pixel_values is not None:
            extra_kwargs['pixel_values'] = pixel_values

        num_steps = self.config.num_refinement_steps
        for step in range(num_steps):
            if self.config.use_adaptive_steps:
                current_step = self._get_adaptive_timestep(is_masked, step, num_steps)
            else:
                current_step = num_steps - step
            timestep = torch.full((batch_size,), current_step, device=device, dtype=torch.long)

            # Predict all positions simultaneously, then sample each one.
            logits, confidence = self._forward(
                generated, timestep, extra_kwargs, request_confidence=True
            )
            next_tokens, token_probs = self._sample_tokens(logits, generated)

            # Without a model-provided confidence, use the probability of the
            # token that was actually sampled (not the max probability), so a
            # low-probability sample is never committed on another token's score.
            if confidence is None:
                confidence = token_probs

            keep_mask = (confidence > self.config.confidence_threshold) & is_masked
            generated = torch.where(keep_mask, next_tokens, generated)
            is_masked = is_masked & ~keep_mask

            # Early stopping if all tokens are decided
            if not is_masked.any():
                break

        # Final greedy pass to fill any remaining masked positions
        if is_masked.any():
            timestep = torch.zeros(batch_size, device=device, dtype=torch.long)
            logits, _ = self._forward(
                generated, timestep, extra_kwargs, request_confidence=False
            )
            final_tokens = torch.argmax(logits, dim=-1).clamp(0, vocab_size - 1)
            generated = torch.where(is_masked, final_tokens, generated)

        return generated

    def _forward(
        self,
        generated: torch.Tensor,
        timestep: torch.Tensor,
        extra_kwargs: Dict,
        request_confidence: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one forward pass and return ``(logits, confidence_or_None)``.

        The captured CUDA graph is replayed only when it reproduces the eager
        call: no extra model kwargs (the graph was captured without them), a
        batch no larger than the captured one, and a sequence length equal to
        the captured length. The graph always processes its full captured
        length, so a shorter sequence would otherwise be run together with
        whatever tokens remain in the unused tail of ``static_input``.
        """
        batch_size, seq_len = generated.shape
        if (
            self.use_cuda_graphs
            and getattr(self, "graph", None) is not None
            and not extra_kwargs
            and batch_size <= self.static_input.shape[0]
            and seq_len == self.static_input.shape[1]
        ):
            self.static_input[:batch_size].copy_(generated)
            self.static_timestep[:batch_size].copy_(timestep)
            self.graph.replay()
            # Rows past batch_size hold stale data; batch rows are assumed to be
            # computed independently, so they do not affect this slice.
            return self.static_output[:batch_size], None

        call_kwargs = dict(extra_kwargs)
        if request_confidence:
            call_kwargs.setdefault('return_confidence', True)
        output = self.model(generated, timestep, **call_kwargs)
        if isinstance(output, tuple):
            return output[0], (output[1] if len(output) > 1 else None)
        return output, None

    def _sample_tokens(
        self, logits: torch.Tensor, generated: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Filter ``logits`` and sample one token for every position.

        Order: temperature, NaN/inf cleanup, repetition penalty, top-k, top-p.
        The penalty is applied before filtering so the top-k / nucleus cut is
        taken over the penalized distribution. No operation here modifies
        ``logits`` in place (it may be the CUDA graph's static output).

        Returns:
            ``(next_tokens, token_probs)``: sampled ids clamped to
            ``[0, vocab_size - 1]`` shaped like ``generated``, and the
            probability each sampled id had under the filtered distribution.
        """
        config = self.config
        if config.temperature > 0:
            logits = logits / config.temperature

        # Replace NaN and inf values with large finite numbers
        logits = torch.nan_to_num(logits, nan=-1e9, posinf=1e9, neginf=-1e9)

        if config.repetition_penalty != 1.0:
            logits = self._apply_repetition_penalty(logits, generated, config.repetition_penalty)
        if config.top_k > 0:
            logits = self._top_k_filtering(logits, config.top_k)
        if config.top_p < 1.0:
            logits = self._top_p_filtering(logits, config.top_p)

        probs = F.softmax(logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=1.0, neginf=0.0)
        # Renormalize in case the cleanup above perturbed the row sums.
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-10)

        # Sample over the logits' own width, which may exceed config.vocab_size.
        sampled = torch.multinomial(probs.reshape(-1, probs.shape[-1]), num_samples=1)
        sampled = sampled.view(generated.shape)
        token_probs = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        next_tokens = torch.clamp(sampled, 0, self.model.config.vocab_size - 1)
        return next_tokens, token_probs

    def _get_adaptive_timestep(
        self,
        is_masked: torch.Tensor,
        current_step: int,
        total_steps: int,
    ) -> int:
        """Map the fraction of still-masked positions to a timestep in [1, total_steps].

        The ratio is taken over the whole sequence, prompt included.
        ``current_step`` is accepted for interface compatibility but unused.
        """
        if is_masked.numel() == 0:
            return 1
        mask_ratio = is_masked.float().mean()
        # Higher timestep for more masked tokens
        timestep = int(mask_ratio * total_steps)
        return max(1, timestep)

    def _top_k_filtering(self, logits: torch.Tensor, top_k: int) -> torch.Tensor:
        """Set every logit below the k-th largest (per row) to -inf.

        Ties with the k-th value are kept, so more than k entries may survive.
        ``top_k`` is clamped to the vocabulary width; ``top_k <= 0`` is a no-op.
        """
        vocab_size = logits.shape[-1]
        top_k = min(top_k, vocab_size)
        if top_k <= 0:
            return logits

        top_k_values = torch.topk(logits, top_k, dim=-1)[0]
        indices_to_remove = logits < top_k_values[..., -1, None]
        logits = logits.masked_fill(indices_to_remove, float('-inf'))
        return logits

    def _top_p_filtering(self, logits: torch.Tensor, top_p: float) -> torch.Tensor:
        """Keep the smallest set of top tokens whose probability mass reaches ``top_p``.

        Everything else is set to -inf. The most likely token is always kept.
        ``top_p >= 1.0`` is a no-op.
        """
        if top_p >= 1.0:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Shift right by one so the token that crosses the threshold is kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False  # Always keep the top token

        # Scatter back to original (unsorted) vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, float('-inf'))
        return logits

    def _apply_repetition_penalty(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor,
        penalty: float,
    ) -> torch.Tensor:
        """Penalize every vocabulary id already present in each row of ``generated``.

        For ids present in row ``b`` (prompt included), negative scores are
        multiplied by ``penalty`` and non-negative scores divided by it, at
        every position of that row. The mask token and out-of-range ids are
        ignored. Returns a new tensor; ``logits`` is not modified.

        Args:
            logits: ``[batch, seq_len, vocab]`` scores.
            generated: ``[batch, seq_len]`` token ids.
            penalty: penalty factor (> 1 discourages repeats).
        """
        mask_token_id = getattr(self.model, 'mask_token_id', 0)
        vocab_size = min(self.model.config.vocab_size, logits.shape[-1])

        valid = (generated != mask_token_id) & (generated >= 0) & (generated < vocab_size)
        # Count occurrences per id; invalid entries contribute 0, so clamping
        # their indices into range cannot mark a token as seen.
        counts = torch.zeros(
            (generated.shape[0], logits.shape[-1]), dtype=torch.long, device=logits.device
        )
        counts.scatter_add_(
            1, generated.clamp(0, vocab_size - 1).to(logits.device), valid.long().to(logits.device)
        )
        seen = (counts > 0).unsqueeze(1)  # [batch, 1, vocab], broadcast over positions

        penalized = torch.where(logits < 0, logits * penalty, logits / penalty)
        return torch.where(seen, penalized, logits)

    @torch.no_grad()
    def generate_batch(
        self,
        prompts: List[torch.Tensor],
        max_new_tokens: Optional[List[int]] = None,
    ) -> List[torch.Tensor]:
        """
        Generate for multiple 1-D prompts of different lengths in one batch.

        Prompts are left-padded with the mask token so that every row's
        generation region starts at the same column. Each returned tensor is
        the prompt followed by its own ``max_new_tokens[i]`` generated tokens
        (length ``len(prompt) + max_new_tokens[i]``). The padding positions are
        never filled and are excluded from the results, but they are still
        visible to the model as context.

        Raises:
            ValueError: if ``max_new_tokens`` and ``prompts`` differ in length.
        """
        if not prompts:
            return []
        if max_new_tokens is None:
            max_new_tokens = [self.config.max_new_tokens] * len(prompts)
        elif len(max_new_tokens) != len(prompts):
            raise ValueError(
                f"max_new_tokens has {len(max_new_tokens)} entries but "
                f"{len(prompts)} prompts were given"
            )

        max_prompt_len = max(p.shape[0] for p in prompts)
        batch_size = len(prompts)
        device = prompts[0].device

        mask_token_id = getattr(self.model, 'mask_token_id', 0)
        padded_prompts = torch.full(
            (batch_size, max_prompt_len),
            mask_token_id,
            dtype=torch.long,
            device=device,
        )

        offsets = []
        for i, prompt in enumerate(prompts):
            offset = max_prompt_len - prompt.shape[0]
            padded_prompts[i, offset:] = prompt.to(device)
            offsets.append(offset)

        generated = self.generate(padded_prompts, max(max_new_tokens))

        # Drop the left padding and any tokens beyond this prompt's budget.
        return [
            generated[i, offset:max_prompt_len + new_tokens]
            for i, (offset, new_tokens) in enumerate(zip(offsets, max_new_tokens))
        ]


def create_generator(
    model: torch.nn.Module,
    config: Optional[GenerationConfig] = None,
    use_kv_cache: bool = True,
    use_cuda_graphs: bool = True,
) -> ParallelGenerator:
    """Build a ``ParallelGenerator`` for ``model``.

    Args:
        model: the denoising model to wrap.
        config: generation settings; a default ``GenerationConfig()`` is
            created when omitted.
        use_kv_cache: forwarded to ``ParallelGenerator``.
        use_cuda_graphs: forwarded to ``ParallelGenerator``.

    Returns:
        A freshly constructed ``ParallelGenerator``.
    """
    effective_config = GenerationConfig() if config is None else config
    return ParallelGenerator(
        model,
        effective_config,
        use_kv_cache=use_kv_cache,
        use_cuda_graphs=use_cuda_graphs,
    )
