"""
Progress tracking for scraping operations.

Key optimization: Uses a single progress file instead of per-property pickle saves.
Previous behavior saved progress after every property = hundreds of pickle operations.
"""

import asyncio
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from threading import Lock
from loguru import logger

from ..utils.paths import get_output_dir


class ProgressTracker:
    """
    Track scraping progress with efficient persistence.

    Features:
    - Track fetched, processed, and failed URLs
    - Auto-save after configurable number of changes
    - Resume capability from saved state
    - Statistics tracking
    """

    def __init__(
        self,
        progress_file: Path | str | None = None,
        auto_save_interval: int = 50,
    ):
        self.progress_file = Path(progress_file) if progress_file else get_output_dir() / "progress.json"
        self.auto_save_interval = auto_save_interval

        self._fetched_urls: Set[str] = set()
        self._processed_urls: Set[str] = set()
        self._failed_urls: Dict[str, str] = {}  # url -> error message
        self._changes_since_save = 0
        self._lock = Lock()

        self._start_time: Optional[datetime] = None
        self._stats = {
            "total_fetched": 0,
            "total_processed": 0,
            "total_failed": 0,
        }

        # Load existing progress
        self._load()

        self.progress_file.parent.mkdir(parents=True, exist_ok=True)
        logger.debug(f"Initialized ProgressTracker: {progress_file}")

    def _load(self):
        """Load progress from disk."""
        if self.progress_file.exists():
            try:
                with open(self.progress_file, "r") as f:
                    data = json.load(f)

                self._fetched_urls = set(data.get("fetched_urls", []))
                self._processed_urls = set(data.get("processed_urls", []))
                self._failed_urls = data.get("failed_urls", {})
                self._stats = data.get("stats", self._stats)

                logger.info(
                    f"Loaded progress: {len(self._fetched_urls)} fetched, "
                    f"{len(self._processed_urls)} processed, "
                    f"{len(self._failed_urls)} failed"
                )
            except Exception as e:
                logger.warning(f"Failed to load progress: {e}")

    def save(self):
        """Save progress to disk."""
        with self._lock:
            data = {
                "fetched_urls": list(self._fetched_urls),
                "processed_urls": list(self._processed_urls),
                "failed_urls": self._failed_urls,
                "stats": self._stats,
                "last_saved": datetime.now(timezone.utc).isoformat(),
            }

            with open(self.progress_file, "w") as f:
                json.dump(data, f, indent=2)

            self._changes_since_save = 0
            logger.debug("Saved progress to disk")

    def _maybe_auto_save(self):
        """Auto-save if enough changes have accumulated."""
        self._changes_since_save += 1
        if self._changes_since_save >= self.auto_save_interval:
            self.save()

    def start(self):
        """Mark the start of a scraping session."""
        self._start_time = datetime.now(timezone.utc)
        logger.info("Started scraping session")

    def mark_fetched(self, url: str):
        """Mark a URL as fetched."""
        with self._lock:
            self._fetched_urls.add(url)
            self._stats["total_fetched"] += 1
            self._maybe_auto_save()

    def mark_processed(self, url: str):
        """Mark a URL as successfully processed."""
        with self._lock:
            self._processed_urls.add(url)
            self._stats["total_processed"] += 1
            self._maybe_auto_save()

    def mark_failed(self, url: str, error: str):
        """Mark a URL as failed with an error message."""
        with self._lock:
            self._failed_urls[url] = error
            self._stats["total_failed"] += 1
            self._maybe_auto_save()

    def is_fetched(self, url: str) -> bool:
        """Check if URL was already fetched."""
        with self._lock:
            return url in self._fetched_urls

    def is_processed(self, url: str) -> bool:
        """Check if URL was already processed."""
        with self._lock:
            return url in self._processed_urls

    def get_pending_urls(self, all_urls: List[str]) -> List[str]:
        """Get URLs that haven't been processed yet."""
        with self._lock:
            return [url for url in all_urls if url not in self._processed_urls]

    def get_failed_urls(self) -> Dict[str, str]:
        """Get all failed URLs with their error messages."""
        with self._lock:
            return self._failed_urls.copy()

    def retry_failed(self, url: str):
        """Remove a URL from the failed list to retry it."""
        with self._lock:
            if url in self._failed_urls:
                del self._failed_urls[url]
                self._maybe_auto_save()

    def clear(self):
        """Clear all progress."""
        with self._lock:
            self._fetched_urls.clear()
            self._processed_urls.clear()
            self._failed_urls.clear()
            self._stats = {
                "total_fetched": 0,
                "total_processed": 0,
                "total_failed": 0,
            }
            if self.progress_file.exists():
                self.progress_file.unlink()
            logger.info("Cleared all progress")

    def get_stats(self) -> dict:
        """Get progress statistics."""
        with self._lock:
            elapsed = None
            rate = None
            if self._start_time:
                elapsed = (datetime.now(timezone.utc) - self._start_time).total_seconds()
                if elapsed > 0:
                    rate = self._stats["total_processed"] / elapsed

            return {
                "fetched": len(self._fetched_urls),
                "processed": len(self._processed_urls),
                "failed": len(self._failed_urls),
                "pending": len(self._fetched_urls) - len(self._processed_urls),
                "elapsed_seconds": elapsed,
                "rate_per_second": rate,
                **self._stats,
            }

    def get_summary(self) -> str:
        """Get a human-readable progress summary."""
        stats = self.get_stats()
        parts = [
            f"Fetched: {stats['fetched']}",
            f"Processed: {stats['processed']}",
            f"Failed: {stats['failed']}",
            f"Pending: {stats['pending']}",
        ]
        if stats["rate_per_second"]:
            parts.append(f"Rate: {stats['rate_per_second']:.2f}/s")
        return " | ".join(parts)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.save()
