"""Mithril provider setup adapter.

Extracts Mithril-specific logic from the wizard to keep the wizard provider-agnostic
while preserving the UI and functionality.
"""

import base64
import hashlib
import json
import os
import shlex
from collections.abc import Callable
from pathlib import Path
from typing import Any

from rich.console import Console
from rich.markup import escape

from flow._internal.config_loader import ConfigLoader
from flow._internal.config_manager import ConfigManager
from flow._internal.io.http import HttpClient
from flow.api.models import ValidationResult as ApiValidationResult
from flow.cli.utils.mask_utils import mask_api_key
from flow.cli.utils.shell_completion import CompletionCommand
from flow.core.setup_adapters import ConfigField, FieldType, ProviderSetupAdapter, ValidationResult
from flow.links import WebLinks
from flow.providers.mithril.core.constants import DEFAULT_REGION, VALID_REGIONS


class MithrilSetupAdapter(ProviderSetupAdapter):
    """Setup-wizard adapter for the Mithril provider.

    Supplies the wizard with Mithril's configuration fields, validates field
    values (calling the Mithril API where needed), produces dynamic choices
    for projects and SSH keys, and saves / verifies the final configuration.
    """

    def __init__(self, console: Console | None = None):
        """Initialize Mithril setup adapter.

        Args:
            console: Rich console for output (creates one if not provided)
        """
        self.console = console or Console()
        # Canonical API URL env var; FLOW_API_URL is a last-resort dev fallback.
        # Chaining with ``or`` (instead of nested ``.get`` defaults) also skips
        # variables that are set but empty, which would give an unusable base URL.
        self.api_url = (
            os.environ.get("MITHRIL_API_URL")
            or os.environ.get("FLOW_API_URL")
            or "https://api.mithril.ai"
        )
        self.config_path = Path.home() / ".flow" / "config.yaml"
        # Latest wizard answers (e.g. api_key, project) seen through
        # get_dynamic_choices / validate_field; read by SSH key generation.
        self._current_context: dict[str, Any] = {}

    def get_provider_name(self) -> str:
        """Get the provider name."""
        return "mithril"

    def get_configuration_fields(self) -> list[ConfigField]:
        """Get Mithril configuration fields, in the order the wizard asks for them."""
        return [
            ConfigField(
                name="api_key",
                field_type=FieldType.PASSWORD,
                required=True,
                mask_display=True,
                help_url=WebLinks.api_keys(),
                help_text="Get your API key from Mithril",
                default=None,
                display_name="API Key",
            ),
            ConfigField(
                name="project",
                field_type=FieldType.CHOICE,
                required=True,
                dynamic_choices=True,
                help_text="Select your Mithril project",
                depends_on=["api_key"],
                empty_choices_hint="Requires API key to list projects",
            ),
            ConfigField(
                name="default_ssh_key",
                field_type=FieldType.CHOICE,
                required=True,
                dynamic_choices=True,
                help_url=WebLinks.ssh_keys(),
                help_text=(
                    "Pick an existing project key or generate one on Mithril (recommended). "
                    "Keys marked '(local copy)' have a matching local private key. "
                    "Private keys are never uploaded."
                ),
                display_name="Default SSH Key",
                depends_on=["api_key", "project"],
                empty_choices_hint="Requires API key and project to list SSH keys",
            ),
            ConfigField(
                name="region",
                field_type=FieldType.CHOICE,
                required=False,
                choices=VALID_REGIONS,
                default=DEFAULT_REGION,
                help_text="Default region for instances",
                display_name="Default Region",
            ),
        ]

    def validate_field(
        self, field_name: str, value: str, context: dict[str, Any] | None = None
    ) -> ValidationResult:
        """Validate a single field value.

        Args:
            field_name: One of the names returned by get_configuration_fields().
            value: The raw value entered/selected in the wizard.
            context: Optional wizard answers so far; merged into the adapter's
                own copy of the context (the caller's dict is never modified).

        Returns:
            The field's ValidationResult; unknown field names are reported invalid.
        """
        if context:
            self._current_context.update(context)

        validators = {
            "api_key": self._validate_api_key,
            "project": self._validate_project,
            "default_ssh_key": self._validate_ssh_key,
            "region": self._validate_region,
        }
        validator = validators.get(field_name)
        if validator is None:
            return ValidationResult(is_valid=False, message=f"Unknown field: {field_name}")
        return validator(value)

    def get_dynamic_choices(self, field_name: str, context: dict[str, Any]) -> list[str]:
        """Get dynamic choices for ``project`` and ``default_ssh_key``.

        The context is copied and remembered so later validation (e.g. SSH key
        generation) can use it. A copy is stored so that subsequent
        validate_field() merges do not mutate the wizard's own dict.
        """
        self._current_context = dict(context)

        if field_name == "project":
            return self._get_project_choices(context.get("api_key"))
        if field_name == "default_ssh_key":
            return self._get_ssh_key_choices(context.get("api_key"), context.get("project"))
        return []

    def detect_existing_config(self) -> dict[str, Any]:
        """Detect existing configuration via the centralized ConfigManager.

        Always includes a ``region`` entry (DEFAULT_REGION when none was found)
        so the UI has something meaningful to show.
        """
        detected = ConfigManager(self.config_path).detect_existing_config()
        detected.setdefault("region", DEFAULT_REGION)
        return detected

    def save_configuration(self, config: dict[str, Any]) -> bool:
        """Save the final configuration using the centralized ConfigManager.

        Also writes the env script (without the API key) and, on a best-effort
        basis, installs shell completion for the detected shell.

        Returns:
            True on success; False if saving the config or env script failed
            (the error is printed to the console).
        """
        try:
            manager = ConfigManager(self.config_path)
            payload = dict(config)  # copy so the caller's dict is left untouched
            payload.setdefault("provider", "mithril")
            saved = manager.save(payload)
            # Write canonical env script (no API key by default)
            manager.write_env_script(saved, include_api_key=False)
        except Exception as e:
            self.console.print(f"[red]Failed to save configuration: {escape(str(e))}[/red]")
            return False

        self._install_shell_completion()
        return True

    def _install_shell_completion(self) -> None:
        """Best-effort shell completion setup; failures never affect setup success."""
        try:
            completion_cmd = CompletionCommand()
            shell = completion_cmd._detect_shell()
            if shell:
                self.console.print(f"\n[dim]Setting up {shell} shell completion...[/dim]")
                completion_cmd._install_completion(shell, None)
        except Exception:
            # Completion is a convenience only; any failure is deliberately ignored.
            pass

    def verify_configuration(self, config: dict[str, Any]) -> tuple[bool, str | None]:
        """Verify that the configuration works end-to-end.

        Exports ``api_key`` / ``project`` into ``os.environ`` as MITHRIL_API_KEY /
        MITHRIL_PROJECT (presumably so the ``Flow()`` client below picks them
        up), then lists one task. A billing check follows; it can print a note
        but never fails verification.

        Returns:
            ``(True, None)`` on success, ``(False, error_message)`` otherwise.
        """
        try:
            # Note: this mutates the process environment for the rest of the run.
            if "api_key" in config:
                os.environ["MITHRIL_API_KEY"] = config["api_key"]
            if "project" in config:
                os.environ["MITHRIL_PROJECT"] = config["project"]

            from flow import Flow

            Flow().list_tasks(limit=1)
        except Exception as e:
            return False, str(e)

        self._check_billing(config.get("api_key"))
        return True, None

    def _check_billing(self, api_key: str | None) -> None:
        """Print a note when the account reports billing as not configured.

        When the billing endpoint returns ``configured`` falsy, sets the public
        attribute ``billing_not_configured = True`` so the UI can mention it in
        the completion message. The attribute is not set otherwise. Any error
        (network, unexpected payload) is ignored: billing must not block setup.
        """
        if not api_key:
            return
        try:
            billing_status = self._make_client(api_key).request("GET", "/v2/account/billing")
            configured = billing_status.get("configured", False)
        except Exception:
            # Don't fail setup for billing check
            return
        if not configured:
            self.billing_not_configured = True
            self.console.print("\n[yellow]Note: Billing not configured yet. Set it up at:[/yellow]")
            self.console.print(f"[cyan]{WebLinks.billing_settings()}[/cyan]")

    def get_welcome_message(self) -> tuple[str, list[str]]:
        """Get Mithril-specific welcome title and list of setup steps."""
        return (
            "Welcome to Flow SDK Setup",
            [
                "Get and validate your API key",
                "Select your project",
                "Configure SSH access",
                "Verify everything works",
            ],
        )

    def get_completion_message(self) -> str:
        """Get Mithril-specific completion message."""
        return "Setup Complete! Your Flow SDK is configured and ready to run GPU workloads."

    # Private helper methods

    def _make_client(self, api_key: str | None) -> HttpClient:
        """Build an HTTP client for ``self.api_url`` authenticated with ``api_key``."""
        return HttpClient(
            base_url=self.api_url,
            headers={"Authorization": f"Bearer {api_key}"},
        )

    @staticmethod
    def _record_field(record: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from a dict or from an attribute-style model object."""
        if isinstance(record, dict):
            return record.get(name, default)
        return getattr(record, name, default)

    def _find_project_id(self, projects: Any, project: str) -> Any:
        """Return the ``fid`` of the project whose ``name`` equals ``project``.

        Accepts an iterable of dicts or model objects; returns None when no
        project matches or ``projects`` is not a sequence of records.
        """
        if projects is None or isinstance(projects, (str, bytes, dict)):
            return None
        for proj in projects:
            if self._record_field(proj, "name") == project:
                return self._record_field(proj, "fid")
        return None

    def _validate_api_key(self, api_key: str) -> ValidationResult:
        """Validate the API key's format, then confirm it against the API.

        On success the display value is the key masked by mask_api_key().
        """
        if not api_key or not api_key.startswith("fkey_") or len(api_key) < 20:
            return ValidationResult(
                is_valid=False,
                message="Invalid API key format. Expected: fkey_XXXXXXXXXXXXXXXXXXXXXXXX",
            )

        try:
            self._make_client(api_key).request("GET", "/v2/projects", timeout_seconds=10.0)
        except Exception as e:
            return ValidationResult(is_valid=False, message=f"API validation failed: {e}")
        return ValidationResult(is_valid=True, display_value=mask_api_key(api_key))

    def _validate_project(self, project: str) -> ValidationResult:
        """Validate that the project name is non-empty (whitespace-only is rejected)."""
        if not project or not project.strip():
            return ValidationResult(is_valid=False, message="Project name cannot be empty")
        return ValidationResult(is_valid=True, display_value=project)

    def _validate_ssh_key(self, ssh_key: str) -> ValidationResult:
        """Validate an SSH key ID or perform a requested key generation.

        Special values:
            ``_auto_``          generate on the platform silently (legacy value).
            ``GENERATE_SERVER`` generate on the platform, with progress output.
            ``GENERATE_LOCAL``  generate locally via SSHKeyManager.generate_local_key().
        For generation the result's ``processed_value`` is the new key id.
        """
        if not ssh_key or not ssh_key.strip():
            return ValidationResult(is_valid=False, message="SSH key ID cannot be empty")

        # Handle platform auto-generation (recommended)
        if ssh_key == "_auto_":
            generated_key_id = self._generate_server_side_key()
            if generated_key_id:
                return ValidationResult(
                    is_valid=True, display_value=generated_key_id, processed_value=generated_key_id
                )
            return ValidationResult(is_valid=False, message="Failed to generate SSH key")

        if ssh_key == "GENERATE_SERVER":
            self.console.print("\n[yellow]Generating SSH key on Mithril platform...[/yellow]")
            generated_key_id = self._generate_server_side_key()
            if not generated_key_id:
                return ValidationResult(is_valid=False, message="Failed to generate SSH key")
            self.console.print(
                f"[green]✓[/green] Successfully generated SSH key: {escape(generated_key_id)}"
            )
            self.console.print("[dim]Private key saved to ~/.flow/keys/ for SSH access[/dim]\n")
            return ValidationResult(
                is_valid=True, display_value=generated_key_id, processed_value=generated_key_id
            )

        if ssh_key == "GENERATE_LOCAL":
            self.console.print("\n[yellow]Generating SSH key locally...[/yellow]")
            generated_key_id = self._generate_local_key()
            if not generated_key_id:
                return ValidationResult(
                    is_valid=False, message="Failed to generate SSH key locally"
                )
            self.console.print(
                f"[green]✓[/green] Successfully generated SSH key: {escape(generated_key_id)}"
            )
            self.console.print("[dim]Key pair stored in ~/.flow/keys/[/dim]\n")
            return ValidationResult(
                is_valid=True, display_value=generated_key_id, processed_value=generated_key_id
            )

        # Regular SSH key ID (``_auto_`` was fully handled above).
        if ssh_key.startswith("sshkey_"):
            display_value = f"Platform key ({ssh_key[:14]}...)"
        else:
            display_value = "Configured"
        return ValidationResult(is_valid=True, display_value=display_value)

    def _validate_region(self, region: str) -> ValidationResult:
        """Validate that ``region`` is one of VALID_REGIONS."""
        if region not in VALID_REGIONS:
            return ValidationResult(
                is_valid=False, message=f"Invalid region. Choose from: {', '.join(VALID_REGIONS)}"
            )
        return ValidationResult(is_valid=True, display_value=region)

    def _get_project_choices(self, api_key: str | None) -> list[str]:
        """Get available project names from the API ([] on any failure).

        Projects without a name are skipped rather than discarding the whole list.
        """
        if not api_key:
            return []

        try:
            projects = self._make_client(api_key).request(
                "GET", "/v2/projects", timeout_seconds=10.0
            )
        except Exception:
            return []
        if not isinstance(projects, list):
            return []
        names = (self._record_field(proj, "name") for proj in projects)
        return [name for name in names if name]

    def _get_ssh_key_choices(self, api_key: str | None, project: str | None) -> list[str]:
        """Get the SSH key choices: generation options first, then project keys.

        Each existing key is encoded as ``"fid|name[ (required)][ (local copy)]|created_at|fingerprint"``.
        Any API failure degrades gracefully to whatever choices were built so far.
        """
        choices = [
            "GENERATE_SERVER|Generate on Mithril (recommended; saves key locally)",
            "GENERATE_LOCAL|Generate locally (uploads public key)",
        ]

        if not api_key or not project:
            return choices

        try:
            client = self._make_client(api_key)
            # Prefer the typed API client; fall back to a raw request on any
            # import/runtime issue.
            try:
                from flow.providers.mithril.api.client import MithrilApiClient

                projects = MithrilApiClient(client).list_projects()
            except Exception:
                projects = client.request("GET", "/v2/projects", timeout_seconds=10.0)

            project_id = self._find_project_id(projects, project)
            if not project_id:
                return choices

            # Get existing SSH keys via the manager; fall back to the raw API.
            # ssh_manager stays set even if list_keys() fails, so the local-copy
            # lookup below can still use it.
            ssh_manager = None
            try:
                from flow.providers.mithril.resources.ssh import SSHKeyManager

                ssh_manager = SSHKeyManager(client, project_id)
                ssh_keys = ssh_manager.list_keys()
            except Exception:
                ssh_keys = client.request(
                    "GET", "/v2/ssh-keys", params={"project": project_id}, timeout_seconds=10.0
                )

            if not isinstance(ssh_keys, list):
                return choices

            local_ids = self._local_key_ids()
            # Keys may be dicts (raw API) or model objects (SSHKeyManager).
            for key in ssh_keys:
                fid = self._record_field(key, "fid")
                if not fid:
                    continue
                key_name = self._record_field(key, "name", "")
                created_at = self._record_field(key, "created_at", "")
                public_key = self._record_field(key, "public_key", "")
                required = self._record_field(key, "required") or self._record_field(
                    key, "is_required"
                )

                name_display = f"{key_name} (required)" if required else f"{key_name}"
                if self._has_local_copy(fid, local_ids, ssh_manager):
                    name_display = f"{name_display} (local copy)"

                fingerprint = self._extract_fingerprint(public_key or "")
                choices.append(f"{fid}|{name_display}|{created_at}|{fingerprint}")

            return choices
        except Exception:
            return choices

    @staticmethod
    def _local_key_ids() -> set[str]:
        """Return key ids in ~/.flow/keys/metadata.json whose private key file exists.

        Best-effort: unreadable or malformed metadata yields an empty set, and
        malformed individual entries are skipped.
        """
        local_ids: set[str] = set()
        meta_path = Path.home() / ".flow" / "keys" / "metadata.json"
        try:
            if not meta_path.exists():
                return local_ids
            data = json.loads(meta_path.read_text(encoding="utf-8"))
            if not isinstance(data, dict):
                return local_ids
            for key_id, info in data.items():
                private_key_path = info.get("private_key_path") if isinstance(info, dict) else None
                # Skip missing/empty paths: Path("") resolves to Path(".") which
                # always exists and would falsely mark the key as local.
                if isinstance(private_key_path, str) and private_key_path:
                    if Path(private_key_path).exists():
                        local_ids.add(key_id)
        except (OSError, ValueError):
            pass
        return local_ids

    @staticmethod
    def _has_local_copy(fid: Any, local_ids: set[str], ssh_manager: Any) -> bool:
        """Whether platform key ``fid`` has a matching local private key."""
        try:
            if fid in local_ids:
                return True
            if ssh_manager is None:
                return False
            return bool(ssh_manager.find_matching_local_key(fid))
        except Exception:
            return False

    def _generate_server_side_key(self) -> str | None:
        """Generate an SSH key on the platform; returns its id or None on failure."""
        return self._generate_key(lambda manager: manager.generate_server_key())

    def _generate_local_key(self) -> str | None:
        """Generate an SSH key locally; returns its id or None on failure."""
        return self._generate_key(lambda manager: manager.generate_local_key())

    def _generate_key(self, generate: Callable[[Any], str | None]) -> str | None:
        """Run ``generate`` on an SSHKeyManager for the current project.

        Any exception is printed (with the API response when present) and
        reported as ``None``.
        """
        try:
            ssh_manager = self._ssh_manager_for_context()
            if ssh_manager is None:
                return None
            return generate(ssh_manager)
        except Exception as e:
            self.console.print(
                f"[red]Error generating SSH key: {escape(type(e).__name__)}: {escape(str(e))}[/red]"
            )
            if hasattr(e, "response"):
                self.console.print(
                    f"[red]API Response: {escape(str(getattr(e, 'response', 'N/A')))}[/red]"
                )
            return None

    def _ssh_manager_for_context(self) -> Any:
        """Build an SSHKeyManager for the current API key and project.

        Credentials come from the wizard context first, then the detected
        config, then MITHRIL_API_KEY / MITHRIL_PROJECT. Prints a reason and
        returns None when they are missing or the project id can't be resolved.
        """
        config = self.detect_existing_config()
        api_key = (
            self._current_context.get("api_key")
            or config.get("api_key")
            or os.environ.get("MITHRIL_API_KEY")
        )
        project = (
            self._current_context.get("project")
            or config.get("project")
            or os.environ.get("MITHRIL_PROJECT")
        )

        if not api_key or not project:
            self.console.print("[red]API key and project required for SSH key generation[/red]")
            return None

        client = self._make_client(api_key)
        projects = client.request("GET", "/v2/projects", timeout_seconds=10.0)
        project_id = self._find_project_id(projects, project)
        if not project_id:
            self.console.print("[red]Could not resolve project ID[/red]")
            return None

        from flow.providers.mithril.resources.ssh import SSHKeyManager

        return SSHKeyManager(client, project_id)

    def _create_env_script(self, config: dict[str, Any]) -> None:
        """Write ``env.sh`` next to the config file exporting MITHRIL_* variables.

        Only project, region and the default SSH key are exported (never the
        API key). Values are shell-quoted with ``shlex.quote`` so characters
        such as ``"``, ``$`` or backticks in a config value cannot break or
        inject into the script when it is sourced. The file is chmod 0600.
        """
        env_script = self.config_path.parent / "env.sh"
        lines = [
            "#!/bin/bash\n",
            "# Flow SDK Mithril provider environment variables\n",
            "# Source this file: source ~/.flow/env.sh\n\n",
        ]
        for var_name, config_key in (
            ("MITHRIL_PROJECT", "project"),
            ("MITHRIL_REGION", "region"),
            ("MITHRIL_SSH_KEYS", "default_ssh_key"),
        ):
            if config_key in config:
                lines.append(f"export {var_name}={shlex.quote(str(config[config_key]))}\n")

        env_script.parent.mkdir(parents=True, exist_ok=True)
        env_script.write_text("".join(lines), encoding="utf-8")
        env_script.chmod(0o600)

    def _extract_fingerprint(self, public_key: str) -> str:
        """Extract a shortened SHA256 fingerprint from an SSH public key.

        Args:
            public_key: SSH public key content (``<type> <base64-data> [comment]``).

        Returns:
            ``"SHA256:<first 8 chars>..."``, or an empty string if extraction fails.
        """
        if not isinstance(public_key, str) or not public_key:
            return ""

        parts = public_key.strip().split()
        if len(parts) < 2:
            return ""
        try:
            key_data = base64.b64decode(parts[1])
        except ValueError:  # binascii.Error is a ValueError subclass
            return ""
        digest = hashlib.sha256(key_data).digest()
        # OpenSSH-style: unpadded base64 of the SHA256 digest.
        fingerprint = base64.b64encode(digest).decode("utf-8").rstrip("=")
        return f"SHA256:{fingerprint[:8]}..."

    def _load_existing_config(self) -> dict[str, Any]:
        """Load the existing YAML config file; {} if missing, unreadable or not a mapping."""
        if not self.config_path.exists():
            return {}

        try:
            import yaml

            with open(self.config_path, encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception:
            return {}
        return data if isinstance(data, dict) else {}
