"""Statistics tracking for image prompt optimization iterations."""

import csv
from pathlib import Path
from datetime import datetime
from typing import Dict, List

from dailystories_generator.image_evaluation import CoverImageEvaluation, PageImageEvaluation


class ImageStatisticsTracker:
    """Track image evaluation scores across optimization iterations.

    Each call to ``log_iteration`` appends one row to a CSV file. The columns
    are ``timestamp``, ``iteration``, ``prompt_version``, the cover and page
    score categories declared below, and ``overall_average``.
    """
    
    # Category names in order for CSV headers
    COVER_CATEGORIES = [
        "cover_child_resemblance",
        "cover_story_capture",
    ]
    
    PAGE_CATEGORIES = [
        "avg_page_child_resemblance",
        "avg_page_content_accuracy",
        "avg_page_style_consistency",
    ]

    # Keys read from each evaluation's ``get_scores_dict()``. They are
    # positionally aligned with COVER_CATEGORIES / PAGE_CATEGORIES above, so
    # the i-th key feeds the i-th CSV column of that group.
    _COVER_SCORE_KEYS = ("child_resemblance", "story_capture")
    _PAGE_SCORE_KEYS = (
        "child_resemblance",
        "page_content_accuracy",
        "style_consistency",
    )
    
    def __init__(self, csv_path: Path | str):
        """
        Initialize the statistics tracker.
        
        Args:
            csv_path: Path to the statistics CSV file (a ``str`` is accepted
                and converted to ``Path``)
        """
        self.csv_path = Path(csv_path)
        self._ensure_csv_exists()
    
    def _ensure_csv_exists(self) -> None:
        """Create the CSV file with a header row if it is missing or empty.

        Missing parent directories are created. An existing zero-byte file
        also gets a header. Without one, ``csv.DictReader`` would treat the
        first data row as the header.
        """
        if self.csv_path.exists() and self.csv_path.stat().st_size > 0:
            return

        self.csv_path.parent.mkdir(parents=True, exist_ok=True)
        headers = (
            ['timestamp', 'iteration', 'prompt_version']
            + self.COVER_CATEGORIES
            + self.PAGE_CATEGORIES
            + ['overall_average']
        )
        with open(self.csv_path, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow(headers)
    
    def log_iteration(
        self,
        iteration: int,
        prompt_version: int,
        cover_evaluation: CoverImageEvaluation,
        page_evaluations: List[PageImageEvaluation],
    ) -> None:
        """
        Log an iteration's evaluation scores to the CSV.

        Cover scores are written as returned by ``get_scores_dict()``. Page
        scores are averaged over all pages and written with two decimals.
        ``overall_average`` is the mean of the category values that were
        actually scored. When ``page_evaluations`` is empty, the page columns
        are written as ``0.00`` placeholders, and they are excluded from the
        overall average rather than counted as zeros.
        
        Args:
            iteration: The iteration number (1-based)
            prompt_version: The version number of the prompt
            cover_evaluation: The cover image evaluation results
            page_evaluations: List of page image evaluation results
        """
        timestamp = datetime.now().isoformat()

        # Only the cover categories that are written to the row take part in
        # the overall average, so that column always matches the row.
        cover_scores = cover_evaluation.get_scores_dict()
        cover_values = [cover_scores[key] for key in self._COVER_SCORE_KEYS]

        num_pages = len(page_evaluations)
        if num_pages:
            page_scores = [page.get_scores_dict() for page in page_evaluations]
            avg_page_values = [
                sum(scores[key] for scores in page_scores) / num_pages
                for key in self._PAGE_SCORE_KEYS
            ]
            scored_values = cover_values + avg_page_values
        else:
            # No page data: keep 0.0 placeholders in the row (so the column
            # stays parseable by get_latest_scores), but do not let them
            # drag the overall average down.
            avg_page_values = [0.0] * len(self._PAGE_SCORE_KEYS)
            scored_values = list(cover_values)

        overall_average = (
            sum(scored_values) / len(scored_values) if scored_values else 0.0
        )

        row = [
            timestamp,
            iteration,
            prompt_version,
            *cover_values,
            *(f"{value:.2f}" for value in avg_page_values),
            f"{overall_average:.2f}",
        ]

        with open(self.csv_path, 'a', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow(row)
    
    def get_latest_scores(self) -> Dict[str, float] | None:
        """
        Get the latest scores from the CSV.
        
        Returns:
            Dictionary mapping each cover/page category column to its value in
            the last data row, or None if the file or its data rows are missing
        """
        from collections import deque

        if not self.csv_path.exists():
            return None

        with open(self.csv_path, 'r', newline='', encoding='utf-8') as f:
            # Keep only the final row instead of materialising the whole file.
            tail = deque(csv.DictReader(f), maxlen=1)

        if not tail:
            return None

        last_row = tail[0]
        return {
            category: float(last_row[category])
            for category in self.COVER_CATEGORIES + self.PAGE_CATEGORIES
        }

