"""Main UnifiedLogger class providing async interface to multiple logging backends."""

import asyncio
import logging
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Any

from .backend import LoggingBackend, LogRecord


class UnifiedLogger:
    """Unified logging interface supporting multiple backends asynchronously.

    Each accepted message is optionally mirrored to a standard-library
    ``logging.Logger`` and, when backends are configured, queued as a
    ``LogRecord`` for delivery to every connected backend. Queued records are
    sent as soon as ``batch_size`` of them have accumulated, or by a
    background timer task that wakes every ``batch_timeout`` seconds while
    records are pending. ``flush()`` (and therefore ``shutdown()`` and
    ``async with``) waits for queued and in-flight deliveries to finish.
    """

    def __init__(
        self,
        name: str = "unified_logger",
        level: int = logging.INFO,
        format_type: str = "json",
        backends: Sequence[LoggingBackend] | None = None,
        enable_standard_logging: bool = True,
        max_workers: int = 4,
        batch_size: int = 100,
        batch_timeout: float = 5.0,
    ) -> None:
        """Initialize the UnifiedLogger.

        Args:
            name: Logger name (also used for the standard-library logger).
            level: Minimum logging level (from the logging module).
            format_type: Default formatter type ("json", "plain", etc.);
                stored but not interpreted by this class.
            backends: Logging backends to use (copied into a list).
            enable_standard_logging: Whether to also log to standard Python logging.
            max_workers: Maximum worker threads for the internal executor.
            batch_size: Maximum number of logs to batch together.
            batch_timeout: Seconds the timer task waits between batch sends.
        """
        self.name = name
        self.level = level
        self.format_type = format_type
        self.backends = list(backends or [])
        self.enable_standard_logging = enable_standard_logging
        self.max_workers = max_workers
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout

        # Internal state
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._pending_logs: list[LogRecord] = []
        self._batch_lock = asyncio.Lock()
        # Background timer task (see _process_batch); None when idle.
        self._batch_task: asyncio.Task[None] | None = None
        # Strong references to delivery tasks that are still running. The
        # event loop only keeps weak references to tasks, and flush() needs
        # these references to wait for the deliveries.
        self._send_tasks: set[asyncio.Task[None]] = set()
        self._shutdown = False

        # Standard Python logger setup
        if enable_standard_logging:
            self._standard_logger = logging.getLogger(name)
            self._standard_logger.setLevel(level)

            # Add a default stream handler only if the named logger has none,
            # so re-creating a UnifiedLogger does not duplicate output.
            if not self._standard_logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter(
                    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
                )
                handler.setFormatter(formatter)
                self._standard_logger.addHandler(handler)
        else:
            self._standard_logger = None

    def add_backend(self, backend: LoggingBackend) -> None:
        """Add a logging backend.

        The backend is not connected here; call ``connect_backends`` (or
        connect it yourself) before it will receive records.

        Args:
            backend: The backend to add
        """
        self.backends.append(backend)

    def remove_backend(self, backend_name: str) -> bool:
        """Remove the first logging backend whose ``name`` matches.

        The backend is not disconnected.

        Args:
            backend_name: Name of the backend to remove

        Returns:
            True if backend was found and removed, False otherwise
        """
        for i, backend in enumerate(self.backends):
            if backend.name == backend_name:
                self.backends.pop(i)
                return True
        return False

    async def connect_backends(self) -> dict[str, bool]:
        """Connect all backends, one after another.

        Failures (a False return or an exception) are reported through the
        standard logger and recorded as False; they are not re-raised.

        Returns:
            Dictionary mapping backend names to connection success status
        """
        results = {}
        for backend in self.backends:
            try:
                success = await backend.connect()
                results[backend.name] = success
                if not success:
                    await self._log_internal_error(
                        f"Backend {backend.name} failed to connect"
                    )
            except Exception as e:
                results[backend.name] = False
                await self._log_internal_error(
                    f"Failed to connect backend {backend.name}: {e}"
                )
        return results

    async def disconnect_backends(self) -> dict[str, bool]:
        """Disconnect all backends, one after another.

        Exceptions are reported through the standard logger and recorded as
        False; they are not re-raised.

        Returns:
            Dictionary mapping backend names to disconnection success status
        """
        results = {}
        for backend in self.backends:
            try:
                await backend.disconnect()
                results[backend.name] = True
            except Exception as e:
                results[backend.name] = False
                await self._log_internal_error(
                    f"Failed to disconnect backend {backend.name}: {e}"
                )
        return results

    def _should_log(self, level: int) -> bool:
        """Return True if a message at ``level`` meets the configured threshold."""
        return level >= self.level

    async def debug(
        self,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Log a debug message."""
        if self._should_log(logging.DEBUG):
            await self._log(logging.DEBUG, message, extra, exc_info)

    async def info(
        self,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Log an info message."""
        if self._should_log(logging.INFO):
            await self._log(logging.INFO, message, extra, exc_info)

    async def warning(
        self,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Log a warning message."""
        if self._should_log(logging.WARNING):
            await self._log(logging.WARNING, message, extra, exc_info)

    # Alias for warning
    warn = warning

    async def error(
        self,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Log an error message."""
        if self._should_log(logging.ERROR):
            await self._log(logging.ERROR, message, extra, exc_info)

    async def critical(
        self,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Log a critical message."""
        if self._should_log(logging.CRITICAL):
            await self._log(logging.CRITICAL, message, extra, exc_info)

    # Alias for critical
    fatal = critical

    async def log(
        self,
        level: int,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Log a message at the specified level."""
        if self._should_log(level):
            await self._log(level, message, extra, exc_info)

    async def _log(
        self,
        level: int,
        message: str,
        extra: dict[str, Any] | None = None,
        exc_info: tuple[type, Exception, Any] | None = None
    ) -> None:
        """Build a LogRecord, mirror it to stdlib logging, and queue it.

        Messages logged after ``shutdown()`` has started are dropped.
        """
        if self._shutdown:
            return

        # Create log record
        record = LogRecord(
            timestamp=datetime.now(timezone.utc),
            level=level,
            message=message,
            logger_name=self.name,
            extra=extra or {},
            exc_info=exc_info
        )

        # Log to standard Python logging if enabled.
        # NOTE(review): stdlib Logger.makeRecord raises KeyError when an
        # ``extra`` key collides with a LogRecord attribute (e.g. "message",
        # "name"), which would propagate to the caller here.
        if self._standard_logger:
            self._standard_logger.log(level, message, extra=extra, exc_info=exc_info)

        # Add to batch for backend processing
        if self.backends:
            await self._add_to_batch(record)

    async def _add_to_batch(self, record: LogRecord) -> None:
        """Queue ``record``; dispatch a full batch immediately if reached."""
        async with self._batch_lock:
            self._pending_logs.append(record)

            # If batch size is reached, start delivering it right away in a
            # tracked background task (the send itself happens outside the lock).
            if len(self._pending_logs) >= self.batch_size:
                batch = self._pending_logs[:self.batch_size]
                self._pending_logs = self._pending_logs[self.batch_size:]
                self._schedule_send(batch)

            # Start batch processing task if not already running (for timeout-based processing)
            if self._batch_task is None or self._batch_task.done():
                self._batch_task = asyncio.create_task(self._process_batch())

    def _schedule_send(self, records: list[LogRecord]) -> asyncio.Task[None]:
        """Start delivering ``records`` in a background task and track it.

        The task is held in ``self._send_tasks`` until it completes. That
        keeps it from being garbage-collected and lets ``flush()`` wait for it.
        """
        task = asyncio.create_task(self._send_batch_to_backends(records))
        self._send_tasks.add(task)
        task.add_done_callback(self._send_tasks.discard)
        return task

    async def _process_batch(self) -> None:
        """Timer loop: every ``batch_timeout`` seconds send up to ``batch_size`` records.

        Exits as soon as nothing is pending (``_add_to_batch`` starts a new
        timer when more records arrive), or once shutdown has begun.
        """
        while not self._shutdown:
            # Wait for batch to fill or timeout
            await asyncio.sleep(self.batch_timeout)

            async with self._batch_lock:
                batch = self._pending_logs[:self.batch_size]
                self._pending_logs = self._pending_logs[self.batch_size:]

            if batch:
                # Shield the delivery. If flush() cancels this timer while the
                # batch is being sent, the tracked send task keeps running and
                # flush() awaits it, so the batch is not dropped.
                await asyncio.shield(self._schedule_send(batch))

            # Stop if no more pending logs
            async with self._batch_lock:
                if not self._pending_logs:
                    # Only clear the slot if it still refers to this task;
                    # flush() may already have detached it.
                    if self._batch_task is asyncio.current_task():
                        self._batch_task = None
                    return

    async def _send_batch_to_backends(self, records: list[LogRecord]) -> None:
        """Send ``records`` concurrently to every backend that reports being connected."""
        tasks = []
        for backend in self.backends:
            if backend.is_connected:
                task = asyncio.create_task(self._send_to_backend(backend, records))
                tasks.append(task)

        if tasks:
            # Wait for all backends to complete
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _send_to_backend(
        self,
        backend: LoggingBackend,
        records: list[LogRecord]
    ) -> None:
        """Send records to one backend; report (not raise) any failure."""
        try:
            if len(records) == 1:
                await backend.send_log(records[0])
            else:
                await backend.send_logs_batch(records)
        except Exception as e:
            await self._log_internal_error(
                f"Failed to send logs to backend {backend.name}: {e}"
            )

    async def _log_internal_error(self, message: str) -> None:
        """Log internal errors to the standard logger only (dropped if it is disabled)."""
        if self._standard_logger:
            self._standard_logger.error(f"UnifiedLogger: {message}")

    async def flush(self) -> None:
        """Deliver all pending logs and wait for in-flight deliveries.

        Stops the background timer task and sends whatever is still queued as
        one batch. Then waits for every delivery task started so far,
        including size-triggered batches from ``_add_to_batch``.
        """
        # Detach the timer before cancelling it, so a timer started by a
        # concurrent log call while we wait is not overwritten afterwards.
        timer = self._batch_task
        self._batch_task = None
        if timer is not None and not timer.done():
            timer.cancel()
            try:
                await timer
            except asyncio.CancelledError:
                pass

        # Swap the queue out under the lock, but do the network I/O after
        # releasing it so concurrent log calls are not blocked on backends.
        async with self._batch_lock:
            batch, self._pending_logs = self._pending_logs, []
        if batch:
            self._schedule_send(batch)

        in_flight = list(self._send_tasks)
        if in_flight:
            await asyncio.gather(*in_flight, return_exceptions=True)

    async def shutdown(self) -> None:
        """Stop accepting logs, flush everything, disconnect backends, and clean up."""
        self._shutdown = True

        # Flush any pending logs (and wait for deliveries already in flight)
        await self.flush()

        # Disconnect backends
        await self.disconnect_backends()

        # Shutdown executor. No method of this class submits work to it, so
        # the blocking wait should return immediately.
        self._executor.shutdown(wait=True)

    async def health_check(self) -> dict[str, bool]:
        """Check health of all backends.

        An exception from a backend's health check is recorded as False.

        Returns:
            Dictionary mapping backend names to health status
        """
        results = {}
        for backend in self.backends:
            try:
                results[backend.name] = await backend.health_check()
            except Exception:
                results[backend.name] = False
        return results

    async def __aenter__(self) -> "UnifiedLogger":
        """Async context manager entry: connect all backends."""
        await self.connect_backends()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: shut the logger down (exceptions are not suppressed)."""
        await self.shutdown()

    def _get_level_name(self, level: int) -> str:
        """Get level name from level number.

        Args:
            level: Logging level number

        Returns:
            Level name as string; unknown levels map to ``"LEVEL_<n>"``.
        """
        level_names = {
            logging.DEBUG: "DEBUG",
            logging.INFO: "INFO",
            logging.WARNING: "WARNING",
            logging.ERROR: "ERROR",
            logging.CRITICAL: "CRITICAL"
        }
        return level_names.get(level, f"LEVEL_{level}")
