#!/usr/bin/env python3
"""
Author: Vasiliy Zdanovskiy
email: vasilyvz@gmail.com

Comprehensive pipeline for testing all MCP Proxy Adapter modes.
For each mode:
- Generates the server config with the CLI generator and writes a matching client config
- Starts the proxy (unless the mode disables it) and the server
- Tests the basic commands (echo, health, help) and the queue commands
"""

import asyncio
import json
import subprocess
import sys
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

from mcp_proxy_adapter.client.jsonrpc_client.client import JsonRpcClient
from tests.config import TEST_MODES
from tests.utils.config_utils import BASE_DIR, PROXY_PORT
from tests.utils.process_utils import (
    start_proxy,
    stop_proxy,
    start_server,
    stop_server,
    wait_for_port,
)
from tests.utils.health_utils import check_server_health, check_proxy_registration
from tests.utils.logging_utils import log_result, get_results, clear_results


# Test certificates used by the https/mtls modes.  Server and client
# material is signed by the same CA file (ca/ca.crt).
CERTS_DIR = BASE_DIR / "mtls_certificates"
SERVER_CERT = str(CERTS_DIR.joinpath("server", "test-server.crt"))
SERVER_KEY = str(CERTS_DIR.joinpath("server", "test-server.key"))
SERVER_CA = str(CERTS_DIR.joinpath("ca", "ca.crt"))
CLIENT_CERT = str(CERTS_DIR.joinpath("client", "test-client.crt"))
CLIENT_KEY = str(CERTS_DIR.joinpath("client", "test-client.key"))
CLIENT_CA = SERVER_CA  # same ca/ca.crt path as the server side

# Generated server configs go next to the example application; generated
# client configs go under tests/.  The client directory is created eagerly
# at import time.
CONFIGS_DIR = BASE_DIR.joinpath("mcp_proxy_adapter", "examples", "full_application", "configs")
CLIENT_CONFIGS_DIR = BASE_DIR.joinpath("tests", "client_configs")
CLIENT_CONFIGS_DIR.mkdir(parents=True, exist_ok=True)

# Mode number i uses server port START_PORT + i and proxy port PROXY_START_PORT + i.
START_PORT = 5000
PROXY_START_PORT = 6000


def kill_process_on_port(port: int) -> None:
    """Force-kill (``kill -9``) every process listening on a local TCP port.

    Best-effort helper used to free ports before and after each mode.

    Only sockets in TCP LISTEN state on ``port`` are considered.  A bare
    ``lsof -i :PORT`` query matches a socket whose local *or* remote port is
    PORT, so it would also hit processes that merely hold an outgoing
    connection to that port.  The current process is never killed.

    Errors (``lsof``/``kill`` not installed, timeouts) are ignored so that
    port cleanup never aborts the pipeline.

    Args:
        port: TCP port number to free.
    """
    import os

    try:
        # -t: print PIDs only; -n/-P: skip host/port name lookups;
        # -sTCP:LISTEN: restrict matches to listening sockets.
        result = subprocess.run(
            ["lsof", "-t", "-n", "-P", "-i", f"tcp:{port}", "-sTCP:LISTEN"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (OSError, subprocess.SubprocessError):
        return
    if result.returncode != 0:
        # lsof exits non-zero when nothing matches.
        return

    own_pid = str(os.getpid())
    # dict.fromkeys de-duplicates PIDs (e.g. IPv4 + IPv6 listeners) in order.
    for pid in dict.fromkeys(result.stdout.split()):
        if pid == own_pid:
            continue
        try:
            subprocess.run(["kill", "-9", pid], capture_output=True, timeout=10)
        except (OSError, subprocess.SubprocessError):
            pass


def generate_server_config(mode_spec: dict, server_port: int, proxy_port: int) -> Path:
    """Generate a server configuration file with the CLI config generator.

    Frees both ports, then runs ``adapter-cfg-gen`` when it is on ``PATH``.
    Otherwise it runs ``python -m mcp_proxy_adapter.cli.commands.config_generate``.
    The config is written into ``CONFIGS_DIR`` under the basename of
    ``mode_spec["config"]``.  The generated config has proxy registration,
    the queue manager, TLS material for https/mtls, and token/role auth
    when the mode asks for it.

    Args:
        mode_spec: Mode description; reads ``config``, ``protocol``, ``name``
            and (optionally) ``token``.
        server_port: Port the generated server listens on.
        proxy_port: Proxy port the server registers with.

    Returns:
        Path of the generated config file.

    Raises:
        subprocess.CalledProcessError: If the generator exits non-zero.
    """
    import shutil

    config_path = CONFIGS_DIR / mode_spec["config"].split("/")[-1]
    CONFIGS_DIR.mkdir(parents=True, exist_ok=True)

    # Free both ports before generating the config.
    kill_process_on_port(server_port)
    kill_process_on_port(proxy_port)

    # Prefer the installed console script; fall back to running the module.
    if shutil.which("adapter-cfg-gen"):
        cmd = ["adapter-cfg-gen"]
    else:
        cmd = [sys.executable, "-m", "mcp_proxy_adapter.cli.commands.config_generate"]

    protocol = mode_spec["protocol"]
    cmd.extend([
        "--protocol", protocol,
        "--out", str(config_path),
        "--with-proxy",
        "--server-port", str(server_port),
        "--registration-port", str(proxy_port),
    ])

    # SSL certificates for https/mtls.
    if protocol in ("https", "mtls"):
        cmd.extend([
            "--server-cert-file", SERVER_CERT,
            "--server-key-file", SERVER_KEY,
        ])
        if protocol == "mtls":
            cmd.extend([
                "--server-ca-cert-file", SERVER_CA,
                "--registration-cert-file", CLIENT_CERT,
                "--registration-key-file", CLIENT_KEY,
                "--registration-ca-cert-file", CLIENT_CA,
            ])
        else:
            cmd.extend([
                "--registration-cert-file", CLIENT_CERT,
                "--registration-key-file", CLIENT_KEY,
            ])

    # Authentication: token for any mode with a token, roles only for "Roles" modes.
    if mode_spec.get("token"):
        cmd.append("--use-token")
        if "Roles" in mode_spec["name"]:
            cmd.append("--use-roles")

    # Queue manager config.
    cmd.extend([
        "--queue-enabled",
        "--max-queue-size", "1000",
        "--per-job-type-limits", "command_execution:100,data_processing:50,api_call:200",
    ])

    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # Fall back to stdout when stderr is empty so the failure is never
        # reported with a blank message.
        details = (e.stderr or "").strip() or (e.stdout or "").strip()
        print(f"❌ Error generating server config: {details}", file=sys.stderr)
        raise
    return config_path


def generate_client_config(mode_spec: dict, port: int) -> Path:
    """Write a JSON client config that matches the server settings of a mode.

    The file is named ``client_<mode name>.json``.  The mode name is
    lower-cased, and spaces and ``+`` are replaced by ``_``.  It is written
    into ``CLIENT_CONFIGS_DIR``.

    Args:
        mode_spec: Mode description; reads ``name``, ``protocol`` and
            (optionally) ``token``.
        port: Server port the client should connect to.

    Returns:
        Path of the written config file.
    """
    slug = mode_spec["name"].lower().replace(" ", "_").replace("+", "_")
    target = CLIENT_CONFIGS_DIR / f"client_{slug}.json"

    # The client speaks the same protocol as the server.
    protocol = mode_spec["protocol"]
    settings = {"protocol": protocol, "host": "localhost", "port": port}

    # Client certificate/key for TLS; the CA only for mutual TLS.
    if protocol in ("https", "mtls"):
        settings.update(cert=CLIENT_CERT, key=CLIENT_KEY)
        if protocol == "mtls":
            settings["ca"] = CLIENT_CA
        settings["check_hostname"] = False

    token = mode_spec.get("token")
    if token:
        settings["token"] = token
        settings["token_header"] = "X-API-Key"

    target.write_text(json.dumps(settings, indent=2), encoding="utf-8")
    return target


async def test_basic_commands(client: JsonRpcClient, mode_name: str) -> bool:
    """Check the echo, health and help commands of a running server.

    Results are recorded with ``log_result`` under ``command_echo``,
    ``command_health`` and ``command_help``, and the first failure ends the
    check.  An unexpected error from echo or help is logged as
    ``basic_commands`` and its traceback is printed.

    Args:
        client: JSON-RPC client connected to the server under test.
        mode_name: Mode label used in the result records.

    Returns:
        True when all three checks succeed, False otherwise.
    """
    print("  🧪 Testing basic commands...")

    try:
        echo_reply = await client.echo("Hello from pipeline")
        if not echo_reply.get("success"):
            log_result(mode_name, "command_echo", "FAIL", str(echo_reply))
            return False
        log_result(mode_name, "command_echo", "PASS")

        # Health is a plain endpoint rather than a JSON-RPC command, so it has
        # its own error handling and success criterion (status == "ok").
        try:
            health_reply = await client.health()
            if health_reply.get("status") != "ok":
                log_result(mode_name, "command_health", "FAIL", str(health_reply))
                return False
            log_result(mode_name, "command_health", "PASS")
        except Exception as health_error:
            log_result(mode_name, "command_health", "FAIL", str(health_error))
            return False

        help_reply = await client.help()
        if not help_reply.get("success"):
            log_result(mode_name, "command_help", "FAIL", str(help_reply))
            return False
        log_result(mode_name, "command_help", "PASS")
    except Exception as error:
        log_result(mode_name, "basic_commands", "FAIL", str(error))
        import traceback
        traceback.print_exc()
        return False
    return True


async def _queue_rpc(
    client: JsonRpcClient,
    mode_name: str,
    method: str,
    params: Dict[str, Any],
    label: str,
) -> Optional[Dict[str, Any]]:
    """Call a queue JSON-RPC method and return its result if it succeeded.

    When the extracted result has a falsy ``success`` field, a FAIL record is
    logged under ``label`` and None is returned.  Nothing is logged on
    success, so the caller can log PASS with its own details.
    """
    response = await client.jsonrpc_call(method, params)
    result = client._extract_result(response)
    if not result.get("success"):
        log_result(mode_name, label, "FAIL", str(result))
        return None
    return result


async def test_queue_commands(client: JsonRpcClient, mode_name: str) -> bool:
    """Exercise the queue commands end to end against a running server.

    Steps, in order:
    1. Call ``queue_health``.
    2. Delete any pre-existing jobs (best effort).
    3. Add one ``command_execution``, one ``data_processing`` and one
       ``api_call`` job.
    4. Check ``queue_get_job_status`` and ``queue_list_jobs``.
    5. Start the first job, query it again, then stop it.
    6. Delete all three jobs.

    Each step is recorded with ``log_result``, and the first failure stops
    the sequence.

    Args:
        client: JSON-RPC client connected to the server under test.
        mode_name: Mode label used in the result records.

    Returns:
        True if every step succeeded, False otherwise.  An unexpected
        exception also returns False; it is logged as ``queue_commands`` and
        its traceback is printed.
    """
    print("  🧪 Testing queue commands...")

    # (job id prefix, job type, result label, job params) for each job added.
    job_specs = (
        (
            "test_cmd",
            "command_execution",
            "queue_add_job_cmd",
            {"command": "echo", "params": {"message": "Test command execution"}},
        ),
        (
            "test_data",
            "data_processing",
            "queue_add_job_data",
            {"data": {"test": "data"}, "operation": "process"},
        ),
        (
            "test_api",
            "api_call",
            "queue_add_job_api",
            {"url": "http://localhost:8080/health", "method": "GET"},
        ),
    )

    try:
        if await _queue_rpc(client, mode_name, "queue_health", {}, "queue_health") is None:
            return False
        log_result(mode_name, "queue_health", "PASS")

        # Best-effort removal of existing jobs so the job count checked below
        # reflects this run; failures here are deliberately ignored.
        response = await client.jsonrpc_call("queue_list_jobs", {})
        list_result = client._extract_result(response)
        if list_result.get("success"):
            for job in list_result.get("data", {}).get("jobs", []):
                job_id = job.get("job_id")
                if job_id:
                    await client.jsonrpc_call("queue_delete_job", {"job_id": job_id})

        job_ids: List[str] = []
        for prefix, job_type, label, job_params in job_specs:
            job_id = f"{prefix}_{uuid.uuid4().hex[:8]}"
            added = await _queue_rpc(
                client,
                mode_name,
                "queue_add_job",
                {"job_type": job_type, "job_id": job_id, "params": job_params},
                label,
            )
            if added is None:
                return False
            log_result(mode_name, label, "PASS")
            job_ids.append(job_id)
        first_job = {"job_id": job_ids[0]}

        if await _queue_rpc(
            client, mode_name, "queue_get_job_status", first_job, "queue_get_job_status"
        ) is None:
            return False
        log_result(mode_name, "queue_get_job_status", "PASS")

        listed = await _queue_rpc(client, mode_name, "queue_list_jobs", {}, "queue_list_jobs")
        if listed is None:
            return False
        jobs = listed.get("data", {}).get("jobs", [])
        if len(jobs) < len(job_ids):
            log_result(
                mode_name,
                "queue_list_jobs",
                "FAIL",
                f"Expected at least {len(job_ids)} jobs, got {len(jobs)}",
            )
            return False
        log_result(mode_name, "queue_list_jobs", "PASS", f"Found {len(jobs)} jobs")

        if await _queue_rpc(client, mode_name, "queue_start_job", first_job, "queue_start_job") is None:
            return False
        log_result(mode_name, "queue_start_job", "PASS")

        # Give the job a moment to start before re-querying its status.
        await asyncio.sleep(1)

        status_result = await _queue_rpc(
            client,
            mode_name,
            "queue_get_job_status",
            first_job,
            "queue_get_job_status_after_start",
        )
        if status_result is None:
            return False
        status = status_result.get("data", {}).get("status")
        log_result(mode_name, "queue_get_job_status_after_start", "PASS", f"Status: {status}")

        if await _queue_rpc(client, mode_name, "queue_stop_job", first_job, "queue_stop_job") is None:
            return False
        log_result(mode_name, "queue_stop_job", "PASS")

        for job_id in job_ids:
            deleted = await _queue_rpc(
                client,
                mode_name,
                "queue_delete_job",
                {"job_id": job_id},
                f"queue_delete_job_{job_id}",
            )
            if deleted is None:
                return False
        log_result(mode_name, "queue_delete_job", "PASS")

        return True
    except Exception as e:
        log_result(mode_name, "queue_commands", "FAIL", str(e))
        import traceback
        traceback.print_exc()
        return False


def _terminate_process(process: Optional[subprocess.Popen]) -> None:
    """Stop a child process, escalating from ``terminate()`` to ``kill()``.

    Waits up to 3 s after ``terminate()`` and up to 2 s after ``kill()``.
    Any error while stopping is swallowed after one last ``kill()`` attempt,
    because this runs during teardown and must not mask the test result.

    Args:
        process: Process to stop; a falsy value is a no-op.
    """
    if not process:
        return
    try:
        process.terminate()
        try:
            process.wait(timeout=3)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait(timeout=2)
    except Exception:
        try:
            process.kill()
        except Exception:
            pass


async def test_mode(mode_spec: dict, server_port: int, proxy_port: int) -> bool:
    """Run the full check sequence for a single mode.

    Steps:
    1. Free both ports.
    2. Generate the server and client configs.
    3. Start the proxy (unless ``mode_spec["with_proxy"]`` is false) and
       the server.
    4. Hit the health endpoint.
    5. Run the basic and queue command tests through a ``JsonRpcClient``.

    Every step is recorded with ``log_result``.  Child processes are always
    stopped and both ports freed before returning.

    Args:
        mode_spec: Mode description (an entry of ``TEST_MODES``).
        server_port: Port for the adapter server.
        proxy_port: Port for the proxy.

    Returns:
        True if every step passed, False otherwise.
    """
    mode_name = mode_spec["name"]
    banner = "=" * 80
    print(f"\n{banner}")
    print(f"Testing: {mode_name}")
    print(f"Server port: {server_port}, Proxy port: {proxy_port}")
    print(banner)

    server_process: Optional[subprocess.Popen] = None
    proxy_process: Optional[subprocess.Popen] = None

    try:
        # Free the ports before starting anything.
        kill_process_on_port(server_port)
        kill_process_on_port(proxy_port)
        await asyncio.sleep(0.5)

        print("📝 Generating server config...")
        server_config_path = generate_server_config(mode_spec, server_port, proxy_port)
        log_result(mode_name, "config_generation_server", "PASS")

        print("📝 Generating client config...")
        client_config_path = generate_client_config(mode_spec, server_port)
        log_result(mode_name, "config_generation_client", "PASS")

        # The client config was generated for server_port above.
        client_config = json.loads(client_config_path.read_text(encoding="utf-8"))

        scheme = "https" if mode_spec["protocol"] in ("https", "mtls") else "http"
        server_url = f"{scheme}://localhost:{server_port}"

        if mode_spec.get("with_proxy", True):
            print("🚀 Starting proxy...")
            proxy_process = start_proxy("http", "localhost", proxy_port)
            if not proxy_process:
                log_result(mode_name, "proxy_start", "FAIL", "Failed to start proxy")
                return False
            if not wait_for_port("localhost", proxy_port, timeout=10):
                log_result(mode_name, "proxy_start", "FAIL", "Proxy port not ready")
                return False
            log_result(mode_name, "proxy_start", "PASS")

        print("🚀 Starting server...")
        server_process = start_server(str(server_config_path), server_port)

        if not wait_for_port("localhost", server_port, timeout=20):
            # If the server already exited, show the start of its stderr.
            if server_process and server_process.poll() is not None:
                try:
                    if server_process.stderr:
                        raw = server_process.stderr.read(2048)
                        # Handle both binary and text-mode pipes.
                        error_output = (
                            raw.decode("utf-8", errors="ignore")
                            if isinstance(raw, bytes)
                            else raw
                        )
                        if error_output:
                            print(f"   Server error output: {error_output[:500]}")
                except (OSError, ValueError):
                    pass
            log_result(mode_name, "server_start", "FAIL", "Server port not ready")
            return False
        log_result(mode_name, "server_start", "PASS")

        # Give the server time to finish initializing (queue manager, etc.).
        await asyncio.sleep(3)

        print("🏥 Testing health endpoint...")
        health_ok, health_data = check_server_health(
            server_url,
            mode_spec["use_ssl"],
            mode_spec.get("token"),
            mode_spec.get("cert_file"),
            mode_spec.get("key_file"),
            mode_spec.get("use_mtls", False),
        )
        if not health_ok:
            log_result(mode_name, "health_endpoint", "FAIL", str(health_data))
            return False
        log_result(mode_name, "health_endpoint", "PASS")

        print("🔌 Creating client...")
        client = JsonRpcClient(
            protocol=client_config["protocol"],
            host=client_config["host"],
            port=client_config["port"],
            token=client_config.get("token"),
            token_header=client_config.get("token_header", "X-API-Key"),
            cert=client_config.get("cert"),
            key=client_config.get("key"),
            ca=client_config.get("ca"),
            check_hostname=client_config.get("check_hostname", False),
        )

        try:
            print("📋 Testing basic commands...")
            if not await test_basic_commands(client, mode_name):
                return False

            print("📋 Testing queue commands...")
            if not await test_queue_commands(client, mode_name):
                return False

            log_result(mode_name, "all_tests", "PASS")
            return True
        finally:
            await client.close()

    except Exception as e:
        log_result(mode_name, "test_exception", "FAIL", str(e))
        import traceback
        traceback.print_exc()
        return False
    finally:
        _terminate_process(server_process)
        _terminate_process(proxy_process)

        # Wait for ports to be released, then force-free them.
        await asyncio.sleep(1)
        kill_process_on_port(server_port)
        kill_process_on_port(proxy_port)
        await asyncio.sleep(0.5)


def cleanup_test_configs() -> None:
    """Delete configuration files left over from previous pipeline runs.

    Removes the server configs in ``CONFIGS_DIR`` whose names start with
    ``http_``, ``https_`` or ``mtls_``, and every ``*.json`` file in
    ``CLIENT_CONFIGS_DIR``.  A file that cannot be removed is reported and
    skipped, so cleanup never aborts the pipeline.
    """
    print("🧹 Cleaning up test configs...")

    # Materialize the listings before deleting anything from the directories.
    server_configs = [
        path
        for path in CONFIGS_DIR.glob("*.json")
        if path.name.startswith(("http_", "https_", "mtls_"))
    ]
    client_configs = list(CLIENT_CONFIGS_DIR.glob("*.json"))

    for config_file in server_configs + client_configs:
        try:
            config_file.unlink()
        except OSError as exc:
            print(f"   Could not delete {config_file.name}: {exc}")
            continue
        print(f"   Deleted: {config_file.name}")


def cleanup_ports(start_port: int, end_port: int) -> None:
    """Free every port in the inclusive range ``start_port``..``end_port``.

    Calls ``kill_process_on_port`` for each port in turn, then pauses for one
    second before returning.
    """
    print(f"🧹 Cleaning up ports {start_port}-{end_port}...")
    current = start_port
    while current <= end_port:
        kill_process_on_port(current)
        current += 1
    time.sleep(1)


async def main() -> int:
    """Run the pipeline over every mode in ``TEST_MODES``.

    Steps:
    1. Clear previous results and leftover configs.
    2. Free the server and proxy port ranges.
    3. Test mode number i on ports ``START_PORT + i`` and
       ``PROXY_START_PORT + i``.
    4. Print a per-mode summary and per-check counts.

    Returns:
        Process exit code: 0 if every mode passed, 1 otherwise.
    """
    print("🚀 Starting comprehensive MCP Proxy Adapter testing pipeline")
    print("=" * 80)

    clear_results()

    # Clean up test configs before starting
    cleanup_test_configs()

    # Free the port ranges derived from START_PORT / PROXY_START_PORT: at
    # least 100 ports above each start, more if there are more modes.
    port_span = max(100, len(TEST_MODES))
    cleanup_ports(START_PORT, START_PORT + port_span)
    cleanup_ports(PROXY_START_PORT, PROXY_START_PORT + port_span)

    CONFIGS_DIR.mkdir(parents=True, exist_ok=True)
    CLIENT_CONFIGS_DIR.mkdir(parents=True, exist_ok=True)

    banner = "=" * 80
    total_modes = len(TEST_MODES)
    results = []
    for i, mode_spec in enumerate(TEST_MODES):
        print(f"\n{banner}")
        print(f"Progress: {i + 1}/{total_modes} modes")
        print(banner)

        server_port = START_PORT + i
        proxy_port = PROXY_START_PORT + i

        success = await test_mode(mode_spec, server_port, proxy_port)
        results.append((mode_spec["name"], success))

        # test_mode already frees its ports on exit; this is a safety net.
        kill_process_on_port(server_port)
        kill_process_on_port(proxy_port)

        # Short pause between modes
        if i < total_modes - 1:
            await asyncio.sleep(1)

    print(f"\n{banner}")
    print("TEST SUMMARY")
    print(banner)

    passed = sum(1 for _, success in results if success)
    total = len(results)

    for name, success in results:
        status = "✅ PASS" if success else "❌ FAIL"
        print(f"{status}: {name}")

    print(f"\n🎯 SUMMARY: {passed}/{total} modes passed")

    # Per-mode tally of individual check records.
    by_mode: Dict[str, Dict[str, int]] = {}
    for record in get_results():
        counts = by_mode.setdefault(record["mode"], {"PASS": 0, "FAIL": 0})
        counts[record["status"]] = counts.get(record["status"], 0) + 1

    print("\nDetailed results by mode:")
    for mode, counts in by_mode.items():
        pass_count = counts.get("PASS", 0)
        fail_count = counts.get("FAIL", 0)
        total_count = pass_count + fail_count
        status_icon = "✅" if fail_count == 0 else "❌"
        print(f"  {status_icon} {mode}: {pass_count}/{total_count} tests passed")

    if passed == total:
        print("\n🎉 ALL MODES PASSED!")
        return 0
    print("\n⚠️  Some modes failed")
    return 1


if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

