#!/usr/bin/env python3
"""
Cross-Platform MCP Diagnostic Script for shebang-fix-test-mcp
Automatically diagnoses common MCP loading issues across all operating systems
Generated by KEN-MCP
"""

import sys
import os
import subprocess
import json
import ast
import platform
import shutil
from pathlib import Path
import re

# Send this script's own log records to stderr. The diagnostic report itself
# is printed to stdout with print().
# Because the root logger is configured here first, a later
# logging.basicConfig() call made by server.py when check_import_issues()
# imports it has no effect (basicConfig is a no-op once the root logger has
# handlers).
# NOTE(review): `logger` is not used anywhere else in this file.
import logging
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
logger = logging.getLogger(__name__)

class CrossPlatformDiagnostics:
    """Diagnose common reasons the MCP server in this directory fails to load.

    Every check appends human-readable lines to ``self.results``. Failed
    checks are also summarised in ``self.errors``, and advisory notes go to
    ``self.warnings``. ``run_diagnostics()`` runs all checks and prints the
    report to stdout.
    """

    # Leading project name of a requirements.txt line (approximately PEP 508).
    _REQUIREMENT_NAME = re.compile(r"[A-Za-z0-9][A-Za-z0-9._-]*")

    def __init__(self):
        base_dir = Path(__file__).parent
        self.server_path = base_dir / "server.py"
        self.requirements_path = base_dir / "requirements.txt"
        self.results = []
        self.errors = []
        self.warnings = []
        self.os_type = self._detect_os()
        self.python_commands = self._find_python_commands()
        self.best_python = self._select_best_python()

    # ------------------------------------------------------------------
    # Environment discovery
    # ------------------------------------------------------------------

    def _detect_os(self):
        """Return "windows", "macos", "wsl", "linux" or "unknown"."""
        system = platform.system().lower()
        if system == "windows":
            return "windows"
        if system == "darwin":
            return "macos"
        if system != "linux":
            return "unknown"
        # A Linux kernel whose /proc/version mentions Microsoft is WSL.
        try:
            with open("/proc/version", "r", encoding="utf-8", errors="replace") as f:
                if "microsoft" in f.read().lower():
                    return "wsl"
        except OSError:
            pass
        return "linux"

    def _find_python_commands(self):
        """Return ``{"command": ..., "version": ...}`` for each Python on PATH.

        Candidates are probed in a fixed order. Commands that are on PATH
        but do not report a ``Python X.Y.Z`` version are skipped.
        """
        candidates = [
            "python3", "python", "py",
            "python3.14", "python3.13", "python3.12",
            "python3.11", "python3.10", "python3.9",
        ]
        commands = []
        for cmd in candidates:
            if shutil.which(cmd):
                version = self._get_python_version(cmd)
                if version:
                    commands.append({"command": cmd, "version": version})
        return commands

    def _get_python_version(self, command):
        """Return the version printed by ``command --version``, or None.

        None is returned when the command cannot be started, times out,
        exits non-zero, or does not print a ``Python <version>`` banner.
        """
        try:
            result = subprocess.run(
                [command, "--version"], capture_output=True, text=True, timeout=5
            )
        except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
            return None
        if result.returncode != 0:
            return None
        # Python 2 printed its banner on stderr, Python 3 prints it on stdout.
        parts = (result.stdout.strip() or result.stderr.strip()).split()
        if len(parts) >= 2 and parts[0] == "Python":
            return parts[1]
        return None

    @staticmethod
    def _version_key(version):
        """Turn ``"3.12.1"`` into ``(3, 12, 1)`` for ordering.

        Parsing stops at the first component without a leading number, and
        trailing tags are ignored, e.g. ``"3.13.0rc1"`` -> ``(3, 13, 0)``.
        """
        key = []
        for part in version.split("."):
            match = re.match(r"\d+", part)
            if match is None:
                break
            key.append(int(match.group()))
        return tuple(key)

    def _select_best_python(self):
        """Pick the interpreter command used in printed instructions.

        ``python3`` is preferred when available. Otherwise the highest
        reported version wins, with candidate order breaking ties.
        Returns None when no interpreter was found.
        """
        if not self.python_commands:
            return None
        for cmd_info in self.python_commands:
            if cmd_info["command"] == "python3":
                return cmd_info["command"]
        # max() keeps the first of equal keys, so probe order breaks ties.
        best = max(self.python_commands, key=lambda info: self._version_key(info["version"]))
        return best["command"]

    # ------------------------------------------------------------------
    # Result bookkeeping
    # ------------------------------------------------------------------

    def add_result(self, category: str, test: str, passed: bool, details: str = ""):
        """Record one check outcome; failures are also added to ``errors``."""
        status = "✅" if passed else "❌"
        self.results.append(f"{status} [{category}] {test}")
        if details:
            self.results.append(f"   → {details}")
        if not passed:
            self.errors.append(f"[{category}] {test}: {details}")

    def add_warning(self, message: str):
        """Record an advisory message shown in the WARNINGS section."""
        self.warnings.append(f"⚠️  {message}")

    # ------------------------------------------------------------------
    # Driver
    # ------------------------------------------------------------------

    def run_diagnostics(self):
        """Run every check and print the report."""
        print("🔍 Cross-Platform MCP Diagnostics for shebang-fix-test-mcp")
        print("=" * 70)
        print(f"🖥️  Operating System: {self.os_type.upper()}")
        print(f"🐍 Best Python: {self.best_python or 'NONE FOUND'}")
        print("=" * 70)

        self.check_platform_compatibility()

        if not self.server_path.exists():
            self.add_result("PATH", "server.py exists", False, "server.py not found")
            # Still print the report so this failure is actually shown.
            self.generate_report()
            return

        self.add_result("PATH", "server.py exists", True)

        self.check_wrapper_scripts()
        self.check_python_compatibility()
        self.check_dependencies()
        self.check_server_syntax()
        self.check_stdout_pollution()
        self.check_logging_config()
        self.check_import_issues()
        self.test_server_execution()
        self.test_json_rpc_compliance()

        self.generate_report()

    # ------------------------------------------------------------------
    # Checks
    # ------------------------------------------------------------------

    def check_platform_compatibility(self):
        """Report whether the OS was recognised and any Python was found."""
        if self.os_type == "unknown":
            self.add_result("PLATFORM", "OS detection", False, "Unknown operating system")
        else:
            self.add_result("PLATFORM", "OS detection", True, f"Detected: {self.os_type}")

        if not self.python_commands:
            self.add_result("PLATFORM", "Python availability", False, "No Python installation found")
        else:
            self.add_result("PLATFORM", "Python availability", True,
                            f"Found {len(self.python_commands)} Python installation(s)")

    def check_wrapper_scripts(self):
        """Report which run_server.{py,bat,sh} wrappers sit next to this script."""
        wrapper_files = ["run_server.py", "run_server.bat", "run_server.sh"]
        base_dir = Path(__file__).parent
        found_wrappers = [name for name in wrapper_files if (base_dir / name).exists()]

        if found_wrappers:
            self.add_result("WRAPPERS", "Wrapper scripts exist", True,
                            f"Found: {', '.join(found_wrappers)}")
        else:
            self.add_result("WRAPPERS", "Wrapper scripts exist", False,
                            "No wrapper scripts found - MCP may fail to connect")

    def check_python_compatibility(self):
        """Check every discovered interpreter against the Python 3.10+ requirement."""
        if not self.python_commands:
            self.add_result("PYTHON", "Python installations", False, "No Python found")
            return

        compatible_count = 0
        for cmd_info in self.python_commands:
            version = cmd_info["version"]
            test_name = f"{cmd_info['command']} compatibility"
            key = self._version_key(version)
            if len(key) < 2:
                self.add_result("PYTHON", test_name, False, f"Invalid version: {version}")
            elif key[:2] >= (3, 10):
                # Tuple comparison, so 4.x also counts as compatible.
                compatible_count += 1
                self.add_result("PYTHON", test_name, True, f"Version {version} (compatible)")
            else:
                self.add_result("PYTHON", test_name, False,
                                f"Version {version} (requires Python 3.10+)")

        if compatible_count == 0:
            self.add_warning("No compatible Python versions found. Consider upgrading Python.")

    def check_python_version(self):
        """Check that the interpreter running this script is Python 3.10+."""
        version = sys.version_info
        version_str = f"{version.major}.{version.minor}.{version.micro}"

        # Compare as a tuple: "major >= 3 and minor >= 10" would reject 4.0.
        if (version.major, version.minor) >= (3, 10):
            self.add_result("PYTHON", "Python version", True, f"Python {version_str}")
        else:
            self.add_result("PYTHON", "Python version", False,
                            f"Python {version_str} (requires 3.10+)")

    @classmethod
    def _requirement_name(cls, line):
        """Return the project name from one requirements.txt line, or None.

        Blank lines, comments, pip options (``-r``, ``-e``, ``--index-url``)
        and bare URLs/VCS links yield None. Extras, version specifiers and
        environment markers are dropped:
        ``"fastmcp[cli]>=2.0 ; python_version>='3.10'"`` -> ``"fastmcp"``.
        """
        line = line.split("#", 1)[0].strip()
        if not line or line.startswith("-"):
            return None
        match = cls._REQUIREMENT_NAME.match(line)
        if match is None:
            return None
        # "https://..." or "git+https://..." are links, not project names.
        if line[match.end():].startswith(("://", "+")):
            return None
        return match.group()

    @staticmethod
    def _is_installed(dep_name):
        """Return True if ``dep_name`` is an installed distribution or importable.

        Distribution metadata handles projects whose import name differs from
        the project name (e.g. ``python-dotenv`` -> ``dotenv``). The import
        fallback handles modules available without distribution metadata.
        Exceptions other than ImportError raised while importing propagate.
        """
        import importlib.metadata

        try:
            importlib.metadata.distribution(dep_name)
            return True
        except importlib.metadata.PackageNotFoundError:
            pass
        try:
            __import__(dep_name.replace("-", "_"))
        except ImportError:
            return False
        return True

    def check_dependencies(self):
        """Check FastMCP and every project listed in requirements.txt."""
        try:
            import fastmcp
            self.add_result("DEPS", "FastMCP installed", True,
                            f"Version: {getattr(fastmcp, '__version__', 'unknown')}")
        except ImportError:
            self.add_result("DEPS", "FastMCP installed", False,
                            "Install with: pip install fastmcp")

        if not self.requirements_path.exists():
            return
        try:
            # utf-8-sig strips a BOM left by some Windows editors.
            lines = self.requirements_path.read_text(encoding="utf-8-sig").splitlines()
        except (OSError, UnicodeDecodeError) as e:
            self.add_warning(f"Failed to parse requirements.txt: {e}")
            return

        for line in lines:
            dep_name = self._requirement_name(line)
            if dep_name is None or dep_name.lower() == "fastmcp":  # fastmcp checked above
                continue
            test_name = f"Dependency: {dep_name}"
            try:
                installed = self._is_installed(dep_name)
            except Exception as e:
                # Importing the package raised something other than ImportError.
                self.add_result("DEPS", test_name, False,
                                f"Import failed: {type(e).__name__}: {e}")
                continue
            if installed:
                self.add_result("DEPS", test_name, True)
            else:
                self.add_result("DEPS", test_name, False, "Not installed")

    def _read_server_text(self):
        """Return server.py's text decoded as UTF-8, independent of the OS locale."""
        return self.server_path.read_text(encoding="utf-8", errors="replace")

    def check_server_syntax(self):
        """Check server.py for syntax errors, then scan its AST for print() calls."""
        try:
            # Parsing bytes lets Python honour a PEP 263 coding cookie and
            # default to UTF-8, instead of the locale codec (e.g. cp1252).
            source = self.server_path.read_bytes()
            tree = ast.parse(source, filename=str(self.server_path))
            # compile() also reports errors ast.parse accepts
            # (e.g. 'return' outside function).
            compile(tree, str(self.server_path), "exec")
        except SyntaxError as e:
            self.add_result("SYNTAX", "server.py syntax", False, f"Line {e.lineno}: {e.msg}")
        except Exception as e:
            self.add_result("SYNTAX", "server.py syntax", False, str(e))
        else:
            self.add_result("SYNTAX", "server.py syntax", True)
            self.analyze_ast(tree)

    @staticmethod
    def _print_targets_stdout(call):
        """Return True if a ``print(...)`` call node may write to stdout.

        No ``file=`` argument, ``file=None`` and ``file=sys.stdout`` all mean
        stdout. Any other explicit ``file=`` target (e.g. ``sys.stderr``) is
        treated as safe.
        """
        for keyword in call.keywords:
            if keyword.arg != "file":
                continue
            target = keyword.value
            if isinstance(target, ast.Constant) and target.value is None:
                return True
            return (
                isinstance(target, ast.Attribute)
                and target.attr in ("stdout", "__stdout__")
                and isinstance(target.value, ast.Name)
                and target.value.id == "sys"
            )
        return True

    def analyze_ast(self, tree):
        """Flag print() calls in ``tree`` that may write to stdout."""
        targets_stdout = self._print_targets_stdout

        class PrintChecker(ast.NodeVisitor):
            def __init__(self):
                self.print_calls = []

            def visit_Call(self, node):
                if (
                    isinstance(node.func, ast.Name)
                    and node.func.id == "print"
                    and targets_stdout(node)
                ):
                    self.print_calls.append(node.lineno)
                self.generic_visit(node)

        checker = PrintChecker()
        checker.visit(tree)

        if checker.print_calls:
            self.add_result("STDOUT", "No print() statements", False,
                            f"Found print() at lines: {', '.join(map(str, checker.print_calls))}")
            self.add_warning("print() statements will break MCP protocol! Use logger.info() instead")
        else:
            self.add_result("STDOUT", "No print() statements", True)

    def check_stdout_pollution(self):
        """Look for a ``sys.stdout = sys.stderr`` redirect in server.py."""
        content = self._read_server_text()
        if "sys.stdout = sys.stderr" in content:
            self.add_result("STDOUT", "Stdout redirection present", True)
        else:
            self.add_warning("Consider adding stdout redirection during imports")

    def check_logging_config(self):
        """Check that server.py configures logging with ``stream=sys.stderr``."""
        content = self._read_server_text()
        if "logging.basicConfig" in content and "stream=sys.stderr" in content:
            self.add_result("LOGGING", "Logging to stderr", True)
        else:
            self.add_result("LOGGING", "Logging to stderr", False,
                            "Ensure logging is configured to stderr")

    def check_import_issues(self):
        """Import server.py as a module to surface import-time errors.

        Anything printed during the import goes to stderr so it cannot mix
        into this report. ``sys.stdout`` is always restored afterwards, even
        if the server rebinds it or aborts the import.
        """
        import importlib.util

        failure = None
        original_stdout = sys.stdout
        sys.stdout = sys.stderr
        try:
            spec = importlib.util.spec_from_file_location("test_server", self.server_path)
            if spec is None or spec.loader is None:
                failure = "Could not create an import spec for server.py"
            else:
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
        except SystemExit as e:
            # A module-level sys.exit() must not terminate the diagnostics run.
            failure = f"server.py called sys.exit({e.code!r}) while being imported"
        except Exception as e:
            failure = f"{type(e).__name__}: {e}"
        finally:
            sys.stdout = original_stdout

        if failure is None:
            self.add_result("IMPORT", "Server imports successfully", True)
        else:
            self.add_result("IMPORT", "Server imports successfully", False, failure)

    def test_server_execution(self):
        """Run server.py for up to 2 seconds and check it does not crash.

        Any of these counts as a pass: the server is still running when the
        timeout fires, it exits 0, or its stderr contains "Server started".
        """
        try:
            result = subprocess.run(
                [sys.executable, str(self.server_path)],
                capture_output=True,
                text=True,
                errors="replace",  # undecodable stderr must not look like a crash
                timeout=2,
                env={**os.environ, "MCP_TEST_MODE": "1"},
            )
        except subprocess.TimeoutExpired:
            # Timeout is expected for a running server.
            self.add_result("EXEC", "Server starts without errors", True,
                            "Server running (timeout expected)")
            return
        except Exception as e:
            self.add_result("EXEC", "Server starts without errors", False, str(e))
            return

        if result.returncode == 0 or "Server started" in result.stderr:
            self.add_result("EXEC", "Server starts without errors", True)
        else:
            self.add_result("EXEC", "Server starts without errors", False,
                            f"Exit code: {result.returncode}")
            if result.stderr:
                # The actual exception is at the end of a traceback; show that part.
                self.add_warning(f"Stderr: {result.stderr.strip()[-200:]}")

    @staticmethod
    def _stop_process(proc, grace: float = 1.0):
        """Terminate ``proc`` and reap it, escalating to kill() after ``grace`` seconds."""
        if proc.poll() is not None:
            return
        proc.terminate()
        try:
            proc.wait(timeout=grace)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()

    def test_json_rpc_compliance(self, timeout: float = 5.0):
        """Send an MCP ``initialize`` request and validate the first reply line.

        The first non-blank line server.py writes to stdout must be a JSON
        object with a ``jsonrpc`` field. No output, or non-JSON text such as
        a stray print(), is reported as a failure. The server process is
        always terminated (killed if needed) before returning.

        Args:
            timeout: Seconds to wait for the first line of output. Server
                start-up can take a while, so the default is generous.
        """
        import threading

        test_name = "Valid JSON-RPC response"
        test_request = {
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "mcp-diagnostics", "version": "1.0"},
            },
            "id": 1,
        }

        try:
            proc = subprocess.Popen(
                [sys.executable, str(self.server_path)],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                # stderr is never inspected. An undrained PIPE could fill up
                # and block a chatty server before it answers.
                stderr=subprocess.DEVNULL,
                encoding="utf-8",
                errors="replace",
            )
        except OSError as e:
            self.add_result("JSON-RPC", test_name, False, str(e))
            return

        first_line = []

        def read_first_line():
            # Runs in a thread so that waiting for output can be time-bounded.
            for line in proc.stdout:
                if line.strip():
                    first_line.append(line.strip())
                    return

        reader = threading.Thread(target=read_first_line, daemon=True)
        reader.start()
        try:
            try:
                proc.stdin.write(json.dumps(test_request) + "\n")
                proc.stdin.flush()
            except OSError:
                pass  # server already exited; reported below as "No response received"
            reader.join(timeout)
        finally:
            self._stop_process(proc)
            try:
                proc.stdin.close()
            except OSError:
                pass
            reader.join(1)
            if not reader.is_alive():  # never close a pipe another thread is reading
                proc.stdout.close()

        if not first_line:
            self.add_result("JSON-RPC", test_name, False, "No response received")
            return
        try:
            response = json.loads(first_line[0])
        except json.JSONDecodeError:
            self.add_result("JSON-RPC", test_name, False, "Invalid JSON in stdout")
            return
        if isinstance(response, dict) and "jsonrpc" in response:
            self.add_result("JSON-RPC", test_name, True)
        else:
            self.add_result("JSON-RPC", test_name, False, "Missing jsonrpc field")

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def generate_report(self):
        """Print results, warnings, failures and platform-specific guidance."""
        print("\n📊 DIAGNOSTIC RESULTS")
        print("=" * 70)

        for result in self.results:
            print(result)

        if self.warnings:
            print("\n⚠️  WARNINGS")
            print("-" * 70)
            for warning in self.warnings:
                print(warning)

        if self.errors:
            print("\n❌ FAILURES SUMMARY")
            print("-" * 70)
            for error in self.errors:
                print(f"  • {error}")

            print("\n🔧 PLATFORM-SPECIFIC FIXES")
            print("-" * 70)
            self._generate_platform_specific_fixes()
        else:
            print("\n✅ All checks passed! MCP should load successfully.")

        print("\n📋 NEXT STEPS FOR YOUR PLATFORM")
        print("-" * 70)
        self._generate_platform_specific_steps()

    def _generate_platform_specific_fixes(self):
        """Print fix recommendations for the kinds of failure recorded in ``errors``."""
        print(f"📱 Detected Platform: {self.os_type.upper()}")
        print()

        if any("print()" in e for e in self.errors):
            print("🔧 Fix stdout pollution:")
            print("   • Replace all print() statements with logger.info()")
            print()

        # All dependency failures (FastMCP and requirements.txt entries) are
        # recorded under the [DEPS] category.
        if any("[DEPS]" in e or "FastMCP" in e for e in self.errors):
            print("🔧 Install dependencies:")
            commands = []
            if self.best_python:
                commands.append(f"{self.best_python} -m pip install -r requirements.txt")
            if self.os_type == "windows":
                commands += ["py -m pip install -r requirements.txt",
                             "pip install -r requirements.txt"]
            else:
                commands += ["python3 -m pip install -r requirements.txt",
                             "pip3 install -r requirements.txt"]
            # dict.fromkeys keeps order and drops a fallback equal to the best_python line.
            for command in dict.fromkeys(commands):
                print(f"   • {command}")
            print()

        if any(e.startswith("[PYTHON]") or "Python availability" in e for e in self.errors):
            print("🔧 Python compatibility issues:")
            if self.os_type == "windows":
                print("   • Install Python 3.10+ from python.org")
                print("   • Or use Microsoft Store Python")
                print("   • Or use Anaconda/Miniconda")
            elif self.os_type == "macos":
                print("   • Install via Homebrew: brew install python@3.10")
                print("   • Or download from python.org")
                print("   • Or use pyenv: pyenv install 3.10")
            elif self.os_type in ["linux", "wsl"]:
                print("   • Ubuntu/Debian: sudo apt install python3.10")
                print("   • CentOS/RHEL: sudo yum install python3.10")
                print("   • Or compile from source")
            print()

    def _generate_platform_specific_steps(self):
        """Print consecutively numbered next steps for the detected platform."""
        python_cmd = self.best_python or "python3"
        project_path = Path(__file__).parent.absolute()
        # Path renders with the native separator ("\" on Windows, "/" elsewhere).
        run_server_py = project_path / "run_server.py"
        run_server_bat = project_path / "run_server.bat"

        steps = []
        if self.errors:
            steps.append(("Fix the issues listed above", []))
            steps.append(("Re-run diagnostics:", [f"{python_cmd} {Path(__file__).name}"]))
        steps.append(("Install dependencies:",
                      [f"{python_cmd} -m pip install -r requirements.txt"]))
        steps.append(("Test the server:", [f"{python_cmd} server.py"]))

        # `claude mcp add <name> <command> [args...]`: the interpreter and the
        # script are separate arguments, so only the path is quoted.
        if self.os_type == "windows":
            add_commands = [
                f'claude mcp add shebang-fix-test-mcp {python_cmd} "{run_server_py}"',
                "# Alternative if above fails:",
                f'claude mcp add shebang-fix-test-mcp "{run_server_bat}"',
            ]
        else:
            add_commands = [
                f'claude mcp add shebang-fix-test-mcp "{run_server_py}"',
                "# Alternative if above fails:",
                f'claude mcp add shebang-fix-test-mcp {python_cmd} "{run_server_py}"',
            ]
        steps.append(("Add to Claude Code:", add_commands))
        steps.append(("Exit and restart Claude Code",
                      ["Type 'exit' or press Ctrl+C, then run 'claude' again"]))
        steps.append(("Verify connection:", ["Use /mcp command in Claude Code"]))

        for number, (title, lines) in enumerate(steps, start=1):
            print(f"{number}. {title}")
            for line in lines:
                print(f"   {line}")
            if number < len(steps):
                print()


def main():
    """Run the cross-platform diagnostics and print the report.

    Returns:
        A process exit code: 0 when every check passed, 1 when at least one
        check failed. Callers that ignore the return value are unaffected.
    """
    # The report is full of emoji. On a console or pipe whose encoding cannot
    # represent them (e.g. a cp1252 pipe on Windows), print() would raise
    # UnicodeEncodeError mid-report. Substitute unencodable characters instead.
    for stream in (sys.stdout, sys.stderr):
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is not None:
            reconfigure(errors="replace")

    diagnostics = CrossPlatformDiagnostics()
    diagnostics.run_diagnostics()
    return 1 if diagnostics.errors else 0


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status
    # (None, like falling off the end of the script, means status 0).
    sys.exit(main())
