"""Meraki MS (Switch) metrics collector."""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any

from ...core.constants import MSMetricName
from ...core.error_handling import validate_response_format, with_error_handling
from ...core.logging import get_logger
from ...core.logging_decorators import log_api_call
from ...core.logging_helpers import LogContext
from ...core.metrics import LabelName
from .base import BaseDeviceCollector

# The guard body is empty: no imports are currently made under TYPE_CHECKING.
if TYPE_CHECKING:
    pass

# Module-level logger used by the collector below.
logger = get_logger(__name__)


class MSCollector(BaseDeviceCollector):
    """Collector for Meraki MS (Switch) devices.

    Registers the switch gauges and, per device, fills the port status,
    port traffic and POE gauges from one ``getDeviceSwitchPortsStatuses``
    response.
    """

    def _initialize_metrics(self) -> None:
        """Initialize MS-specific metrics.

        All gauges are created through ``self.parent._create_gauge``.
        """
        # Switch port metrics
        self._switch_port_status = self.parent._create_gauge(
            MSMetricName.MS_PORT_STATUS,
            "Switch port status (1 = connected, 0 = disconnected)",
            labelnames=[LabelName.SERIAL, LabelName.NAME, LabelName.PORT_ID, LabelName.PORT_NAME],
        )

        self._switch_port_traffic = self.parent._create_gauge(
            MSMetricName.MS_PORT_TRAFFIC_BYTES,
            "Switch port traffic in bytes",
            labelnames=[
                LabelName.SERIAL,
                LabelName.NAME,
                LabelName.PORT_ID,
                LabelName.PORT_NAME,
                LabelName.DIRECTION,
            ],
        )

        # Switch power metrics
        self._switch_power = self.parent._create_gauge(
            MSMetricName.MS_POWER_USAGE_WATTS,
            "Switch power usage in watts",
            labelnames=[LabelName.SERIAL, LabelName.NAME, LabelName.MODEL],
        )

        # POE metrics
        self._switch_poe_port_power = self.parent._create_gauge(
            MSMetricName.MS_POE_PORT_POWER_WATTS,
            "Per-port POE power consumption in watt-hours (Wh)",
            labelnames=[LabelName.SERIAL, LabelName.NAME, LabelName.PORT_ID, LabelName.PORT_NAME],
        )

        self._switch_poe_total_power = self.parent._create_gauge(
            MSMetricName.MS_POE_TOTAL_POWER_WATTS,
            "Total POE power consumption for switch in watt-hours (Wh)",
            labelnames=[LabelName.SERIAL, LabelName.NAME, LabelName.MODEL, LabelName.NETWORK_ID],
        )

        # Registered here but not populated by collect() in this class.
        self._switch_poe_budget = self.parent._create_gauge(
            MSMetricName.MS_POE_BUDGET_WATTS,
            "Total POE power budget for switch in watts",
            labelnames=[LabelName.SERIAL, LabelName.NAME, LabelName.MODEL, LabelName.NETWORK_ID],
        )

        # Registered here but not populated by collect() in this class.
        self._switch_poe_network_total = self.parent._create_gauge(
            MSMetricName.MS_POE_NETWORK_TOTAL_WATTS,
            "Total POE power consumption for all switches in network in watt-hours (Wh)",
            labelnames=[LabelName.NETWORK_ID, LabelName.NETWORK_NAME],
        )

    @log_api_call("getDeviceSwitchPortsStatuses")
    @with_error_handling(
        operation="Collect MS device metrics",
        continue_on_error=True,
    )
    async def collect(self, device: dict[str, Any]) -> None:
        """Collect switch-specific metrics.

        Fetches the port statuses for the switch once and, for every port,
        sets the status, traffic and POE gauges. The per-port POE values are
        summed into the switch-level POE total, which is also used as the
        switch power usage value.

        Any exception raised while fetching or processing is logged and
        swallowed, so a failing switch never propagates an error to the caller.

        Parameters
        ----------
        device : dict[str, Any]
            Switch device data. ``serial`` is required; ``name``, ``model``
            and ``networkId`` are optional.

        """
        serial = device["serial"]
        name = device.get("name", serial)
        model = device.get("model", "")
        network_id = device.get("networkId", "")

        try:
            # The SDK call is synchronous, so run it in a worker thread to keep
            # the event loop responsive.
            with LogContext(serial=serial, name=name):
                port_statuses = await asyncio.to_thread(
                    self.api.switch.getDeviceSwitchPortsStatuses,
                    serial,
                )
                port_statuses = validate_response_format(
                    port_statuses, expected_type=list, operation="getDeviceSwitchPortsStatuses"
                )

            # POE data is included in the port status payload, so a single pass
            # over the ports covers status, traffic and POE.
            total_poe_consumption = 0
            for port in port_statuses:
                port_labels = self._build_port_labels(serial, name, port)
                self._set_port_status(port, port_labels)
                self._set_port_traffic(port, port_labels)
                total_poe_consumption += self._set_port_poe(port, port_labels)

            # Set switch-level POE total
            self._switch_poe_total_power.labels(
                serial=serial,
                name=name,
                model=model,
                network_id=network_id,
            ).set(total_poe_consumption)

            # Set total switch power usage (POE consumption is the main power draw)
            # This is an approximation - actual switch base power consumption varies by model
            self._switch_power.labels(
                serial=serial,
                name=name,
                model=model,
            ).set(total_poe_consumption)

            # Note: POE budget is not available via API, would need a lookup table by model

        except Exception:
            logger.exception(
                "Failed to collect switch metrics",
                serial=serial,
            )

    @staticmethod
    def _build_port_labels(serial: str, name: str, port: dict[str, Any]) -> dict[str, str]:
        """Return the label values shared by every per-port gauge."""
        port_id = str(port.get("portId", ""))
        port_name = port.get("name")
        if port_name is None:
            # Covers both a missing key and an explicit null, which would
            # otherwise be passed through as the label value.
            port_name = f"Port {port_id}"
        return {
            "serial": serial,
            "name": name,
            "port_id": port_id,
            "port_name": port_name,
        }

    def _set_port_status(self, port: dict[str, Any], port_labels: dict[str, str]) -> None:
        """Set the port status gauge to 1 when the port is "Connected", else 0."""
        is_connected = 1 if port.get("status") == "Connected" else 0
        self._switch_port_status.labels(**port_labels).set(is_connected)

    def _set_port_traffic(self, port: dict[str, Any], port_labels: dict[str, str]) -> None:
        """Set the rx/tx traffic gauges from the port's ``trafficInKbps`` data."""
        traffic_counters = port.get("trafficInKbps")
        if not isinstance(traffic_counters, dict):
            # Missing or null traffic data: leave the gauges untouched.
            return

        for api_key, direction in (("recv", "rx"), ("sent", "tx")):
            kbps = traffic_counters.get(api_key)
            if kbps is None:
                continue
            # Kilobits per second -> bytes per second.
            self._switch_port_traffic.labels(**port_labels, direction=direction).set(
                kbps * 1000 / 8
            )

    def _set_port_poe(self, port: dict[str, Any], port_labels: dict[str, str]) -> float:
        """Set the per-port POE gauge and return the value that was recorded.

        Ports without an allocated POE budget are recorded as 0.
        """
        # ``or`` guards against explicit nulls, which ``.get(key, default)``
        # would pass through.
        poe_info = port.get("poe") or {}
        if poe_info.get("isAllocated", False):
            power_used = port.get("powerUsageInWh") or 0
        else:
            power_used = 0

        self._switch_poe_port_power.labels(**port_labels).set(power_used)
        return power_used
