#!/usr/bin/env python3
"""
Prompt optimization script for iterative improvement based on evaluation feedback.

Usage:
    # Set your API key
    export GOOGLE_API_KEY="your-gemini-api-key-here"
    
    # Run 50 iterations optimizing outline prompt
    python optimize_prompts.py --mode outline --iterations 50 --child-name Ludwig --child-age 6
    
    # Run 20 iterations optimizing page prompt
    python optimize_prompts.py --mode pages --iterations 20 --child-name Emma --child-age 7
    
    # Run 10 iterations optimizing image prompts
    python optimize_prompts.py --mode images --iterations 10 --child-name Ludwig --child-age 6 --reference-image story/original.png
"""

import asyncio
import argparse
import os
import sys
from pathlib import Path
from datetime import datetime

# Add package to path for local testing: the in-repo "src" directory is inserted
# at position 0 of sys.path, so the dailystories_generator imports below resolve
# to this checkout ahead of any other copy on the path.
package_path = Path(__file__).parent / "src"
sys.path.insert(0, str(package_path))

from dailystories_generator import (
    StoryGenerator,
    GenerationRequest,
    Update,
    UpdateType,
)
from dailystories_generator.gemini_client import GeminiClient
from dailystories_generator.evaluation import evaluate_content
from dailystories_generator.prompt_improver import suggest_improvements
from dailystories_generator.statistics_tracker import StatisticsTracker
from dailystories_generator.image_evaluation import evaluate_cover_image, evaluate_page_image
from dailystories_generator.image_prompt_improver import suggest_cover_prompt_improvements, suggest_page_prompt_improvements
from dailystories_generator.image_statistics_tracker import ImageStatisticsTracker
from dailystories_generator.optimization_utils import Colors, load_reference_image


async def generate_story_content(
    generator: StoryGenerator,
    request: GenerationRequest,
    mode: str,
):
    """
    Run the generator for one optimization mode and return what should be evaluated.

    The request is modified in place before generation: ``outline_only`` is set
    for every mode, and ``generate_images`` is set for every mode except 'outline'.

    Args:
        generator: StoryGenerator used to produce the story
        request: Generation request (mutated as described above)
        mode: 'outline', 'pages', or any other value (handled as 'images')

    Returns:
        'outline': the story outline
        'pages': each page as a "Page N:" header line followed by its text,
            with pages separated by a blank line
        otherwise: the full generated story object, including images
    """
    async def silent_update(update: Update) -> None:
        """Progress callback that discards every update."""

    if mode == "outline":
        request.outline_only = True
        outline_story = await generator.generate(request, on_update=silent_update)
        return outline_story.outline

    # Both remaining modes need the complete story; only 'images' needs pictures,
    # so text-only optimization skips image generation.
    wants_images = mode != "pages"
    request.outline_only = False
    request.generate_images = wants_images
    full_story = await generator.generate(request, on_update=silent_update)

    if wants_images:
        return full_story

    return "\n\n".join(
        f"Page {entry.page_number}:\n{entry.text_content}" for entry in full_story.pages
    )


def _score_color(score) -> str:
    """Return the console color for a score: OKGREEN for >= 4, WARNING for >= 3, FAIL otherwise."""
    return Colors.OKGREEN if score >= 4 else (Colors.WARNING if score >= 3 else Colors.FAIL)


def _report_scores(log, evaluation, console_prefix: str, improvement_label: str) -> None:
    """
    Log every category score of an evaluation and echo one colored line per category.

    Args:
        log: Callable that writes a message to the console and the run log
        evaluation: Evaluation result providing get_scores_dict() and, per category,
            an attribute with ``explanation`` and ``improvement_suggestion``
        console_prefix: Text shown before the colored category on the console
            (e.g. "Cover " or "Page 2 "; empty for text modes)
        improvement_label: Label for the improvement line written to the log
    """
    for category, score in evaluation.get_scores_dict().items():
        details = getattr(evaluation, category)
        log(f"  {category}: {score}/5")
        log(f"    Explanation: {details.explanation}")
        # Suggestions are only recorded for categories that still need work.
        if score <= 3 and details.improvement_suggestion:
            log(f"    {improvement_label}: {details.improvement_suggestion}")
        log("\n")
        print(f"  {console_prefix}{_score_color(score)}{category}: {score}/5{Colors.ENDC}")


async def optimize_prompts(
    mode: str,
    iterations: int,
    child_name: str,
    child_age: int,
    title: str,
    summary: str,
    num_pages: int,
    language: str,
    illustration_style: str = "colorful watercolor illustration",
    reference_image_path: Path = None,
) -> None:
    """
    Main optimization loop.

    Each iteration does the following:
      1. Generates content with the current prompt template(s).
      2. Evaluates it.
      3. Records the scores with the statistics tracker.
      4. Except on the final iteration, asks the improver helpers for revised
         prompt(s) and overwrites the template file(s) in place.

    The loop stops early once every evaluated category scores 4 or above. The
    run log file is overwritten at the start of each run.

    The function returns without doing any work, after printing an error, if
    GOOGLE_API_KEY is not set, or if images mode has no reference image or it
    cannot be loaded.

    Args:
        mode: 'outline', 'pages', or 'images'
        iterations: Number of optimization iterations
        child_name: Child's name for story
        child_age: Child's age
        title: Story title
        summary: Story summary/theme
        num_pages: Number of pages
        language: Story language
        illustration_style: Illustration style (for images mode)
        reference_image_path: Path to reference image (required for images mode)
    """
    # Get API key
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        print(f"{Colors.FAIL}❌ Error: GOOGLE_API_KEY environment variable not set{Colors.ENDC}")
        print(f"{Colors.WARNING}Set it with: export GOOGLE_API_KEY='your-key-here'{Colors.ENDC}")
        return

    # Validate and load the reference image before the log file is opened:
    # opening it in 'w' mode truncates the previous run's log, which a failed
    # start should leave intact.
    reference_image = None
    reference_image_bytes = None
    if mode == "images":
        if not reference_image_path:
            print(f"{Colors.FAIL}❌ Error: --reference-image required for images mode{Colors.ENDC}")
            return
        try:
            reference_image, reference_image_bytes = load_reference_image(reference_image_path)
        except Exception as e:
            print(f"{Colors.FAIL}❌ Failed to load reference image: {e}{Colors.ENDC}")
            return

    # Initialize components
    print(f"\n{Colors.HEADER}{'='*80}{Colors.ENDC}")
    print(f"{Colors.HEADER}🚀 Prompt Optimization System{Colors.ENDC}")
    print(f"{Colors.HEADER}{'='*80}{Colors.ENDC}\n")

    print(f"{Colors.OKBLUE}Mode: {mode}{Colors.ENDC}")
    print(f"{Colors.OKBLUE}Iterations: {iterations}{Colors.ENDC}")
    print(f"{Colors.OKBLUE}Test story: {title} (age {child_age}){Colors.ENDC}\n")

    generator = StoryGenerator(gemini_api_key=api_key)
    gemini_client = GeminiClient(api_key=api_key)

    # Setup paths and the statistics tracker
    package_root = Path(__file__).parent
    templates_dir = package_root / "src" / "dailystories_generator" / "prompt_templates"

    if mode == "images":
        # Image mode optimizes two system instruction files
        cover_prompt_template_path = templates_dir / "cover_image_system_instruction.txt"
        page_prompt_template_path = templates_dir / "page_image_system_instruction.txt"
        statistics_path = package_root / "image_statistics.csv"
        log_path = package_root / "image_optimization.log"
        tracker = ImageStatisticsTracker(statistics_path)
    else:
        # Text modes optimize a single prompt template
        prompt_template_path = templates_dir / (
            "story_outline_prompt.txt" if mode == "outline" else "story_page_prompt.txt"
        )
        statistics_path = package_root / "statistics.csv"
        log_path = package_root / "optimization.log"
        tracker = StatisticsTracker(statistics_path)

    # Number of iterations actually started; differs from `iterations` on early stop.
    iterations_run = 0

    # Create log file for this optimization run (overwrites previous)
    log_file = open(log_path, 'w', encoding='utf-8')
    try:
        def log(message: str) -> None:
            """Write to both console and log file."""
            print(message)
            log_file.write(message + '\n')
            log_file.flush()

        log(f"{'='*80}\n")
        log("PROMPT OPTIMIZATION RUN\n")
        log(f"Mode: {mode}\n")
        log(f"Iterations: {iterations}\n")
        log(f"Test Story: {title} (age {child_age}, language: {language})\n")
        log(f"Child Name: {child_name}\n")
        log(f"Summary: {summary}\n")
        log(f"Number of Pages: {num_pages}\n")
        if mode == "images":
            log(f"Illustration Style: {illustration_style}\n")
            log(f"Reference Image: {reference_image_path}\n")
        log(f"Started: {datetime.now().isoformat()}\n")
        log(f"{'='*80}\n\n")

        print(f"{Colors.OKGREEN}✓ Initialized components{Colors.ENDC}")
        if mode == "images":
            print(f"{Colors.OKGREEN}✓ Cover prompt template: {cover_prompt_template_path.name}{Colors.ENDC}")
            print(f"{Colors.OKGREEN}✓ Page prompt template: {page_prompt_template_path.name}{Colors.ENDC}")
        else:
            print(f"{Colors.OKGREEN}✓ Prompt template: {prompt_template_path.name}{Colors.ENDC}")
        print(f"{Colors.OKGREEN}✓ Statistics: {statistics_path.name}{Colors.ENDC}")
        print(f"{Colors.OKGREEN}✓ Log file: {log_path.name}{Colors.ENDC}\n")
        if mode == "images":
            print(f"{Colors.OKGREEN}✓ Loaded reference image{Colors.ENDC}\n")

        # Load initial prompt(s)
        if mode == "images":
            current_cover_prompt = cover_prompt_template_path.read_text(encoding="utf-8")
            current_page_prompt = page_prompt_template_path.read_text(encoding="utf-8")
            log(f"INITIAL COVER SYSTEM INSTRUCTION:\n{current_cover_prompt}\n")
            log(f"{'='*80}\n\n")
            log(f"INITIAL PAGE SYSTEM INSTRUCTION:\n{current_page_prompt}\n")
            log(f"{'='*80}\n\n")
        else:
            current_prompt = prompt_template_path.read_text(encoding="utf-8")
            log(f"INITIAL PROMPT:\n{current_prompt}\n")
            log(f"{'='*80}\n\n")

        # Create test request
        request = GenerationRequest(
            title=title,
            summary=summary,
            num_pages=num_pages,
            child_name=child_name,
            child_age=child_age,
            language=language,
            illustration_style=illustration_style,
            generate_images=(mode == "images"),
            reference_images=[reference_image] if reference_image else [],
        )

        # Main optimization loop
        for iteration in range(1, iterations + 1):
            iterations_run = iteration

            log(f"\n{'='*80}\n")
            log(f"ITERATION {iteration}/{iterations}\n")
            log(f"{'='*80}\n\n")

            print(f"{Colors.BOLD}{Colors.HEADER}{'='*80}{Colors.ENDC}")
            print(f"{Colors.BOLD}{Colors.HEADER}Iteration {iteration}/{iterations}{Colors.ENDC}")
            print(f"{Colors.BOLD}{Colors.HEADER}{'='*80}{Colors.ENDC}\n")

            # Step 1: Generate content
            if mode == "images":
                log("STEP 1: GENERATING STORY WITH IMAGES\n")
                print(f"{Colors.OKCYAN}⏳ Step 1/5: Generating story with images...{Colors.ENDC}")
            else:
                log(f"STEP 1: GENERATING {mode.upper()}\n")
                print(f"{Colors.OKCYAN}⏳ Step 1/4: Generating {mode}...{Colors.ENDC}")

            try:
                content = await generate_story_content(generator, request, mode)

                if mode == "images":
                    if not content.cover_image_data:
                        raise ValueError("No cover image generated")
                    if not content.pages or not all(p.image_data for p in content.pages):
                        raise ValueError("Not all page images generated")

                    # Save iteration output to disk
                    iteration_dir = package_root / "optimization_output" / f"iteration{iteration}"
                    iteration_dir.mkdir(parents=True, exist_ok=True)

                    # Save story.md
                    story_md = f"# {request.title}\n\n## Outline\n\n{content.outline}\n\n## Pages\n\n"
                    story_md += "".join(
                        f"### Page {page.page_number}\n\n{page.text_content}\n\n"
                        for page in content.pages
                    )
                    (iteration_dir / "story.md").write_text(story_md, encoding="utf-8")

                    # Save cover and page images
                    (iteration_dir / "cover.png").write_bytes(content.cover_image_data)
                    for page in content.pages:
                        (iteration_dir / f"page{page.page_number}.png").write_bytes(page.image_data)

                    log(f"✓ Generated story with {len(content.pages)} pages and cover\n")
                    log(f"✓ Saved to {iteration_dir}\n")
                    print(f"{Colors.OKGREEN}✓ Generated story with {len(content.pages)} pages and cover{Colors.ENDC}")
                    print(f"{Colors.OKGREEN}✓ Saved to {iteration_dir}{Colors.ENDC}\n")
                else:
                    log(f"✓ Generated {len(content)} characters\n")
                    log(f"GENERATED CONTENT:\n{content}\n")
                    print(f"{Colors.OKGREEN}✓ Generated {len(content)} characters{Colors.ENDC}\n")

                log(f"{'-'*80}\n\n")
            except Exception as e:
                log(f"❌ Generation failed: {e}\n\n")
                print(f"{Colors.FAIL}❌ Generation failed: {e}{Colors.ENDC}\n")
                continue

            # Step 2: Evaluate content
            if mode == "images":
                # Image mode: evaluate cover and pages separately
                log("STEP 2: EVALUATING COVER IMAGE\n")
                print(f"{Colors.OKCYAN}⏳ Step 2/5: Evaluating images...{Colors.ENDC}")

                try:
                    cover_eval = await evaluate_cover_image(
                        reference_image_bytes,
                        content.cover_image_data,
                        content.outline,
                        gemini_client,
                    )

                    cover_avg = cover_eval.get_average_score()
                    log(f"✓ Cover evaluation complete - Average: {cover_avg:.2f}/5\n\n")
                    log("COVER EVALUATION RESULTS:\n")
                    _report_scores(log, cover_eval, "Cover ", "Improvement")

                    # Evaluate pages; each evaluation also receives the images of
                    # all preceding pages.
                    log("\n STEP 3: EVALUATING PAGE IMAGES\n")
                    page_evals = []
                    for i, page in enumerate(content.pages):
                        previous_images = [p.image_data for p in content.pages[:i]]

                        page_eval = await evaluate_page_image(
                            reference_image_bytes,
                            page.image_data,
                            page.text_content,
                            previous_images,
                            gemini_client,
                        )
                        page_evals.append(page_eval)

                        page_avg = page_eval.get_average_score()
                        log(f"Page {page.page_number} - Average: {page_avg:.2f}/5\n")
                        _report_scores(log, page_eval, f"Page {page.page_number} ", "Improvement")

                    # page_evals is non-empty: Step 1 rejects stories without pages.
                    overall_page_avg = sum(e.get_average_score() for e in page_evals) / len(page_evals)
                    overall_avg = (cover_avg + overall_page_avg) / 2

                    log(f"Overall Page Average: {overall_page_avg:.2f}/5\n")
                    log(f"Overall Average: {overall_avg:.2f}/5\n")
                    log(f"{'-'*80}\n\n")
                    print(f"{Colors.OKGREEN}✓ Image evaluation complete - Average: {overall_avg:.2f}/5{Colors.ENDC}\n")

                except Exception as e:
                    log(f"❌ Image evaluation failed: {e}\n\n")
                    print(f"{Colors.FAIL}❌ Image evaluation failed: {e}{Colors.ENDC}\n")
                    continue
            else:
                # Text mode: evaluate content
                log("STEP 2: EVALUATING CONTENT\n")
                print(f"{Colors.OKCYAN}⏳ Step 2/4: Evaluating content...{Colors.ENDC}")
                try:
                    evaluation = await evaluate_content(content, mode, child_age, gemini_client)
                    avg_score = evaluation.get_average_score()
                    log(f"✓ Evaluation complete - Average: {avg_score:.2f}/5\n\n")
                    log("EVALUATION RESULTS:\n")
                    _report_scores(log, evaluation, "", "Improvement Suggestion")

                    log(f"\nAverage Score: {avg_score:.2f}/5\n")
                    log(f"{'-'*80}\n\n")
                    print(f"{Colors.OKGREEN}✓ Evaluation complete - Average: {avg_score:.2f}/5{Colors.ENDC}\n")
                    print()

                except Exception as e:
                    log(f"❌ Evaluation failed: {e}\n\n")
                    print(f"{Colors.FAIL}❌ Evaluation failed: {e}{Colors.ENDC}\n")
                    continue

            # Step 3: Log statistics
            if mode == "images":
                log("STEP 4: LOGGING STATISTICS\n")
                print(f"{Colors.OKCYAN}⏳ Step 4/5: Logging statistics...{Colors.ENDC}")
            else:
                log("STEP 3: LOGGING STATISTICS\n")
                print(f"{Colors.OKCYAN}⏳ Step 3/4: Logging statistics...{Colors.ENDC}")

            try:
                if mode == "images":
                    tracker.log_iteration(iteration, iteration, cover_eval, page_evals)
                else:
                    tracker.log_iteration(mode, iteration, iteration, evaluation)
                log("✓ Statistics logged to CSV\n")
                log(f"{'-'*80}\n\n")
                print(f"{Colors.OKGREEN}✓ Statistics logged{Colors.ENDC}\n")
            except Exception as e:
                log(f"❌ Logging failed: {e}\n\n")
                print(f"{Colors.FAIL}❌ Logging failed: {e}{Colors.ENDC}\n")

            # Step 4: Improve prompt (skip on last iteration or if all scores are 4+)
            step_num = "5" if mode == "images" else "4"
            if iteration < iterations:
                # Check if all scores are 4 or above
                if mode == "images":
                    all_above_threshold = (
                        cover_eval.all_scores_above_threshold(threshold=4)
                        and all(e.all_scores_above_threshold(threshold=4) for e in page_evals)
                    )
                else:
                    all_above_threshold = evaluation.all_scores_above_threshold(threshold=4)

                if all_above_threshold:
                    log(f"STEP {step_num}: SKIPPED (All scores are 4 or above - optimization complete!)\n")
                    log(f"{'-'*80}\n\n")
                    log("🎉 OPTIMIZATION COMPLETE - All categories scored 4 or above!\n")
                    log(f"Stopping optimization early at iteration {iteration}/{iterations}\n\n")
                    print(f"{Colors.OKGREEN}✓ All scores are 4 or above - optimization complete!{Colors.ENDC}")
                    print(f"{Colors.OKGREEN}Stopping optimization early at iteration {iteration}/{iterations}{Colors.ENDC}\n")
                    break

                log(f"STEP {step_num}: IMPROVING PROMPT{'S' if mode == 'images' else ''}\n")
                print(f"{Colors.OKCYAN}⏳ Step {step_num}/{step_num}: Generating improved prompt{'s' if mode == 'images' else ''}...{Colors.ENDC}")

                # Log improvement suggestions
                if mode == "images":
                    cover_suggestions = cover_eval.get_improvement_suggestions()
                    if cover_suggestions:
                        log("COVER IMPROVEMENT SUGGESTIONS:\n")
                        for category, suggestion in cover_suggestions.items():
                            log(f"  {category}: {suggestion}\n")
                        log("\n")
                else:
                    improvement_suggestions = evaluation.get_improvement_suggestions()
                    if improvement_suggestions:
                        log("IMPROVEMENT SUGGESTIONS:\n")
                        for category, suggestion in improvement_suggestions.items():
                            log(f"  {category}: {suggestion}\n")
                        log("\n")

                try:
                    if mode == "images":
                        # The page improver takes a single evaluation. Early stopping
                        # requires every page to pass, so feed it the weakest page
                        # rather than always the first one.
                        worst_page_eval = min(page_evals, key=lambda e: e.get_average_score())

                        improved_cover_prompt = await suggest_cover_prompt_improvements(
                            current_cover_prompt, cover_eval, gemini_client
                        )
                        improved_page_prompt = await suggest_page_prompt_improvements(
                            current_page_prompt, worst_page_eval, gemini_client
                        )

                        log("✓ System instruction improvements successful\n")
                        log(f"IMPROVED COVER INSTRUCTION:\n{improved_cover_prompt}\n")
                        log(f"IMPROVED PAGE INSTRUCTION:\n{improved_page_prompt}\n")
                        log(f"{'-'*80}\n\n")

                        # Save improved system instructions
                        cover_prompt_template_path.write_text(improved_cover_prompt, encoding="utf-8")
                        page_prompt_template_path.write_text(improved_page_prompt, encoding="utf-8")
                        current_cover_prompt = improved_cover_prompt
                        current_page_prompt = improved_page_prompt

                        print(f"{Colors.OKGREEN}✓ System instructions improved and saved{Colors.ENDC}\n")
                    else:
                        # Improve single prompt
                        improved_prompt = await suggest_improvements(
                            current_prompt, evaluation, mode, gemini_client
                        )

                        log("✓ Prompt improvement successful\n")
                        log(f"IMPROVED PROMPT:\n{improved_prompt}\n")
                        log(f"{'-'*80}\n\n")

                        # Save improved prompt
                        prompt_template_path.write_text(improved_prompt, encoding="utf-8")
                        current_prompt = improved_prompt

                        print(f"{Colors.OKGREEN}✓ Prompt improved and saved{Colors.ENDC}\n")

                except Exception as e:
                    log(f"❌ Improvement failed: {e}\n")
                    log(f"⚠️  Continuing with current prompt{'s' if mode == 'images' else ''}\n\n")
                    print(f"{Colors.FAIL}❌ Improvement failed: {e}{Colors.ENDC}")
                    print(f"{Colors.WARNING}⚠️  Continuing with current prompt{'s' if mode == 'images' else ''}{Colors.ENDC}\n")
            else:
                log(f"STEP {step_num}: SKIPPED (Final iteration)\n")
                log(f"{'-'*80}\n\n")
                print(f"{Colors.OKBLUE}ℹ️  Final iteration - skipping prompt improvement{Colors.ENDC}\n")

            # Small delay between iterations
            if iteration < iterations:
                await asyncio.sleep(1)

        # Final summary (reports iterations actually run, which is fewer on early stop)
        log(f"\n{'='*80}\n")
        log("OPTIMIZATION COMPLETE\n")
        log(f"Completed: {datetime.now().isoformat()}\n")
        log(f"Total Iterations: {iterations_run}/{iterations}\n")
        if mode == "images":
            log("Final prompts saved to:\n")
            log(f"  - {cover_prompt_template_path}\n")
            log(f"  - {page_prompt_template_path}\n")
        else:
            log(f"Final prompt saved to: {prompt_template_path}\n")
        log(f"Statistics saved to: {statistics_path}\n")
        log(f"{'='*80}\n")
    finally:
        # Release the log file even if an error or cancellation escapes the loop.
        log_file.close()

    print(f"\n{Colors.OKGREEN}{'='*80}{Colors.ENDC}")
    print(f"{Colors.OKGREEN}🎉 Optimization complete!{Colors.ENDC}")
    print(f"{Colors.OKGREEN}{'='*80}{Colors.ENDC}")
    print(f"{Colors.OKGREEN}✓ Completed {iterations_run}/{iterations} iterations{Colors.ENDC}")
    print(f"{Colors.OKGREEN}✓ Statistics saved to: {statistics_path}{Colors.ENDC}")
    if mode == "images":
        print(f"{Colors.OKGREEN}✓ Cover prompt saved to: {cover_prompt_template_path}{Colors.ENDC}")
        print(f"{Colors.OKGREEN}✓ Page prompt saved to: {page_prompt_template_path}{Colors.ENDC}")
    else:
        print(f"{Colors.OKGREEN}✓ Final prompt saved to: {prompt_template_path}{Colors.ENDC}")
    print(f"{Colors.OKGREEN}✓ Log file saved to: {log_path}{Colors.ENDC}\n")


def main():
    """
    Parse command-line arguments, validate them, and run the optimization loop.

    Exit behavior:
      - Status 2 (via ``parser.error``) when ``--iterations`` or ``--num-pages``
        is below 1.
      - Status 1 when images mode is requested without an existing reference
        image file.
    """
    parser = argparse.ArgumentParser(
        description="Optimize story generation prompts through iterative evaluation"
    )

    parser.add_argument(
        "--mode",
        choices=["outline", "pages", "images"],
        required=True,
        help="Type of prompt to optimize: outline (story structure), pages (page text), images (image generation)"
    )

    parser.add_argument(
        "--iterations",
        type=int,
        default=10,
        help="Number of optimization iterations (default: 10)"
    )

    parser.add_argument(
        "--child-name",
        default="Alex",
        help="Child's name for test story (default: Alex)"
    )

    parser.add_argument(
        "--child-age",
        type=int,
        default=6,
        help="Child's age for test story (default: 6)"
    )

    parser.add_argument(
        "--reference-image",
        type=str,
        default=None,
        help="Path to reference image of child (required for --mode images)"
    )

    parser.add_argument(
        "--title",
        default="The Magical Adventure",
        help="Story title (default: The Magical Adventure)"
    )

    parser.add_argument(
        "--summary",
        default="A child discovers a magical object and goes on an adventure",
        help="Story summary/theme"
    )

    parser.add_argument(
        "--num-pages",
        type=int,
        default=5,
        help="Number of pages (default: 5)"
    )

    parser.add_argument(
        "--language",
        default="English",
        help="Story language (default: English)"
    )

    parser.add_argument(
        "--illustration-style",
        default="colorful watercolor illustration",
        help="Illustration style for images mode (default: colorful watercolor illustration)"
    )

    args = parser.parse_args()

    # Reject counts that would make a run meaningless before any API work starts:
    # zero iterations does nothing, and a story without pages cannot be evaluated.
    if args.iterations < 1:
        parser.error("--iterations must be at least 1")
    if args.num_pages < 1:
        parser.error("--num-pages must be at least 1")

    # Validate reference image path for images mode
    reference_image_path = None
    if args.mode == "images":
        if not args.reference_image:
            print("Error: --reference-image is required when --mode is 'images'", file=sys.stderr)
            sys.exit(1)
        reference_image_path = Path(args.reference_image)
        # is_file() rather than exists(): a directory is not a usable image.
        if not reference_image_path.is_file():
            print(f"Error: Reference image not found: {reference_image_path}", file=sys.stderr)
            sys.exit(1)

    # Run optimization
    asyncio.run(optimize_prompts(
        mode=args.mode,
        iterations=args.iterations,
        child_name=args.child_name,
        child_age=args.child_age,
        title=args.title,
        summary=args.summary,
        num_pages=args.num_pages,
        language=args.language,
        illustration_style=args.illustration_style,
        reference_image_path=reference_image_path,
    ))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

