#!/usr/bin/env python3
"""
MCP Diagnostic Script for final-test-mcp
Automatically diagnoses common MCP loading issues
Generated by KEN-MCP
"""

import sys
import os
import subprocess
import json
import ast
from pathlib import Path
import re

# Route log records to stderr so they never interleave with the report printed on stdout.
import logging
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__name__)

class MCPDiagnostics:
    """Diagnose common reasons an MCP server (server.py) fails to load.

    Every check records its outcome through ``add_result`` / ``add_warning``:

    * ``results``  -- printable lines, one per check plus optional detail lines
    * ``errors``   -- ``"[CATEGORY] test: details"`` for each failed check
    * ``warnings`` -- advisory messages that do not count as failures

    ``generate_report`` prints all three collections at the end of a run.
    """

    # Leading project name of a requirements.txt line (PEP 508 name characters).
    _REQUIREMENT_NAME = re.compile(r"[A-Za-z0-9][A-Za-z0-9._-]*")
    # Bare URL requirements (e.g. ``git+https://...``) carry no plain project name.
    _URL_REQUIREMENT = re.compile(r"[A-Za-z][A-Za-z0-9+.-]*://")

    def __init__(self):
        base_dir = Path(__file__).parent
        self.server_path = base_dir / "server.py"
        self.requirements_path = base_dir / "requirements.txt"
        self.results = []
        self.errors = []
        self.warnings = []

    def add_result(self, category: str, test: str, passed: bool, details: str = ""):
        """Record one check outcome; failed checks are also appended to ``errors``."""
        status = "✅" if passed else "❌"
        self.results.append(f"{status} [{category}] {test}")
        if details:
            self.results.append(f"   → {details}")
        if not passed:
            self.errors.append(f"[{category}] {test}: {details}")

    def add_warning(self, message: str):
        """Record an advisory message that does not count as a failure."""
        self.warnings.append(f"⚠️  {message}")

    def run_diagnostics(self) -> bool:
        """Run every check, print the report, and return True if no check failed."""
        print("🔍 MCP Diagnostics for final-test-mcp")
        print("=" * 60)

        if not self.server_path.exists():
            self.add_result("PATH", "server.py exists", False, "server.py not found")
            # Every remaining check needs server.py, but this failure must still be reported.
            self.generate_report()
            return False

        self.add_result("PATH", "server.py exists", True)

        # run_server.py is the wrapper the NEXT STEPS tell the user to register with Claude.
        wrapper_path = self.server_path.parent / "run_server.py"
        if wrapper_path.exists():
            self.add_result("PATH", "run_server.py exists", True)
        else:
            self.add_result("PATH", "run_server.py exists", False,
                            "Wrapper script missing - MCP may fail to connect")

        self.check_python_version()
        self.check_dependencies()
        self.check_server_syntax()
        self.check_stdout_pollution()
        self.check_logging_config()
        self.check_import_issues()
        self.test_server_execution()
        self.test_json_rpc_compliance()

        self.generate_report()
        return not self.errors

    def check_python_version(self):
        """Require Python 3.10 or newer."""
        version_str = "{}.{}.{}".format(*sys.version_info[:3])
        # Tuple comparison: testing major and minor separately would reject e.g. 4.0.
        if sys.version_info >= (3, 10):
            self.add_result("PYTHON", "Python version", True, f"Python {version_str}")
        else:
            self.add_result("PYTHON", "Python version", False,
                            f"Python {version_str} (requires 3.10+)")

    def check_dependencies(self):
        """Check that FastMCP and every project listed in requirements.txt is installed."""
        try:
            import fastmcp
            self.add_result("DEPS", "FastMCP installed", True,
                            f"Version: {getattr(fastmcp, '__version__', 'unknown')}")
        except ImportError:
            self.add_result("DEPS", "FastMCP installed", False,
                            "Install with: pip install fastmcp")

        if not self.requirements_path.exists():
            return
        try:
            # utf-8-sig drops a BOM that would otherwise hide the first requirement.
            lines = self.requirements_path.read_text(encoding="utf-8-sig").splitlines()
        except (OSError, UnicodeDecodeError) as e:
            self.add_warning(f"Failed to parse requirements.txt: {e}")
            return

        for line in lines:
            dep_name = self._requirement_name(line)
            # Skip comments, pip options, URLs/paths, and fastmcp (checked above).
            if dep_name is None or dep_name.lower() == "fastmcp":
                continue
            if self._is_installed(dep_name):
                self.add_result("DEPS", f"Dependency: {dep_name}", True)
            else:
                self.add_result("DEPS", f"Dependency: {dep_name}", False, "Not installed")

    @classmethod
    def _requirement_name(cls, line: str):
        """Return the project name declared by one requirements.txt line, or None.

        None is returned for blank lines, comments, pip options (``-r``, ``-e``,
        ``--index-url`` ...), bare URLs and local paths. Version specifiers,
        extras and environment markers after the name are ignored.
        """
        spec = line.split("#", 1)[0].strip()
        if not spec or spec.startswith("-") or cls._URL_REQUIREMENT.match(spec):
            return None
        match = cls._REQUIREMENT_NAME.match(spec)
        return match.group(0) if match else None

    @staticmethod
    def _is_installed(dep_name: str) -> bool:
        """Return True if *dep_name* is an installed distribution or a findable module.

        Distribution metadata is consulted first because project names often
        differ from import names. The ``find_spec`` fallback covers stdlib and
        other non-distribution modules; for a top-level name it locates the
        module without importing it.
        """
        import importlib.metadata
        import importlib.util

        try:
            importlib.metadata.distribution(dep_name)
            return True
        except importlib.metadata.PackageNotFoundError:
            pass
        try:
            return importlib.util.find_spec(dep_name.replace("-", "_")) is not None
        except (ImportError, ValueError):
            return False

    def _read_server_text(self) -> str:
        """Return server.py decoded as UTF-8, replacing undecodable bytes."""
        return self.server_path.read_text(encoding="utf-8", errors="replace")

    def check_server_syntax(self):
        """Compile server.py, then scan its AST for stdout-polluting print() calls."""
        try:
            # Compile raw bytes so the source encoding (UTF-8 default, PEP 263
            # coding cookies) is resolved exactly as the interpreter would.
            source = self.server_path.read_bytes()
            compile(source, str(self.server_path), "exec")
            tree = ast.parse(source, filename=str(self.server_path))
        except SyntaxError as e:
            self.add_result("SYNTAX", "server.py syntax", False, f"Line {e.lineno}: {e.msg}")
            return
        except Exception as e:  # e.g. unreadable file, null bytes on older Pythons
            self.add_result("SYNTAX", "server.py syntax", False, str(e))
            return

        self.add_result("SYNTAX", "server.py syntax", True)
        self.analyze_ast(tree)

    def analyze_ast(self, tree):
        """Flag print() calls that write to stdout, which corrupts the MCP stdio stream.

        A call with an explicit ``file=`` argument is only flagged when that
        target is ``None``, a bare ``stdout`` name, or an ``<expr>.stdout`` attribute.
        """
        def writes_to_stdout(call: ast.Call) -> bool:
            for keyword in call.keywords:
                if keyword.arg == "file":
                    target = keyword.value
                    if isinstance(target, ast.Constant):
                        return target.value is None  # file=None means sys.stdout
                    if isinstance(target, ast.Name):
                        return target.id == "stdout"
                    return isinstance(target, ast.Attribute) and target.attr == "stdout"
            return True

        print_lines = sorted({
            node.lineno
            for node in ast.walk(tree)
            if isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "print"
            and writes_to_stdout(node)
        })

        if print_lines:
            self.add_result("STDOUT", "No print() statements", False,
                            f"Found print() at lines: {', '.join(map(str, print_lines))}")
            self.add_warning("print() statements will break MCP protocol! Use logger.info() instead")
        else:
            self.add_result("STDOUT", "No print() statements", True)

    def check_stdout_pollution(self):
        """Look for an explicit ``sys.stdout = sys.stderr`` redirect in server.py."""
        content = self._read_server_text()
        if "sys.stdout = sys.stderr" in content:
            self.add_result("STDOUT", "Stdout redirection present", True)
        else:
            self.add_warning("Consider adding stdout redirection during imports")

    def check_logging_config(self):
        """Check (textually) that server.py configures logging onto stderr."""
        content = self._read_server_text()
        if "logging.basicConfig" in content and "stream=sys.stderr" in content:
            self.add_result("LOGGING", "Logging to stderr", True)
        else:
            self.add_result("LOGGING", "Logging to stderr", False,
                            "Ensure logging is configured to stderr")

    def check_import_issues(self):
        """Import server.py in-process (stdout redirected) to surface import-time errors.

        A ``SystemExit`` raised while importing is reported as a failure instead
        of ending the diagnostics run, and ``sys.stdout`` is always restored.
        """
        import contextlib
        import importlib.util

        try:
            spec = importlib.util.spec_from_file_location("test_server", self.server_path)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load {self.server_path} as a module")
            module = importlib.util.module_from_spec(spec)
            # Anything the server prints while importing goes to stderr, not the report.
            with contextlib.redirect_stdout(sys.stderr):
                spec.loader.exec_module(module)
        except (Exception, SystemExit) as e:
            self.add_result("IMPORT", "Server imports successfully", False, str(e))
        else:
            self.add_result("IMPORT", "Server imports successfully", True)

    def test_server_execution(self):
        """Launch server.py directly and check that it does not crash on startup.

        The server gets an empty stdin (so it cannot consume input from the
        terminal running the diagnostics) and ``MCP_TEST_MODE=1`` in its
        environment. Exit status 0, "Server started" on stderr, or still running
        after 2 seconds all count as a pass.
        """
        test_name = "Server starts without errors"
        try:
            result = subprocess.run(
                [sys.executable, str(self.server_path)],
                stdin=subprocess.DEVNULL,
                capture_output=True,
                text=True,
                errors="replace",  # undecodable server output must not abort the check
                timeout=2,
                env={**os.environ, "MCP_TEST_MODE": "1"},
            )
        except subprocess.TimeoutExpired:
            # Timeout is expected for a running server
            self.add_result("EXEC", test_name, True, "Server running (timeout expected)")
            return
        except Exception as e:  # e.g. the interpreter could not be launched
            self.add_result("EXEC", test_name, False, str(e))
            return

        if result.returncode == 0 or "Server started" in result.stderr:
            self.add_result("EXEC", test_name, True)
            return

        self.add_result("EXEC", test_name, False, f"Exit code: {result.returncode}")
        if result.stderr:
            # A traceback ends with the actual exception, so show the tail.
            self.add_warning(f"Stderr: {result.stderr.rstrip()[-200:]}")

    def test_json_rpc_compliance(self, response_timeout: float = 5.0):
        """Send an MCP ``initialize`` request and check the server answers in JSON-RPC.

        Only the first non-blank stdout line is inspected. Any banner or stray
        print() output ahead of the response is exactly the stdout pollution
        that breaks MCP clients.

        Args:
            response_timeout: Seconds to wait for the server's first stdout
                line; the wait ends as soon as a line arrives.
        """
        test_name = "Valid JSON-RPC response"
        request = {
            "jsonrpc": "2.0",
            "method": "initialize",
            # protocolVersion and clientInfo are required initialize params;
            # without them a server may answer with an error instead.
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "mcp-diagnostics", "version": "1.0.0"},
            },
            "id": 1,
        }
        try:
            first_line = self._first_stdout_line(
                [sys.executable, str(self.server_path)],
                json.dumps(request) + "\n",
                response_timeout,
            )
        except Exception as e:
            self.add_result("JSON-RPC", test_name, False, str(e))
            return

        if not first_line.strip():
            self.add_result("JSON-RPC", test_name, False, "No response received")
            return
        try:
            response = json.loads(first_line)
        except json.JSONDecodeError:
            self.add_result("JSON-RPC", test_name, False, "Invalid JSON in stdout")
            return

        # Valid JSON that is not an object (a list, number, string) is not a response.
        if isinstance(response, dict) and "jsonrpc" in response:
            self.add_result("JSON-RPC", test_name, True)
        else:
            self.add_result("JSON-RPC", test_name, False, "Missing jsonrpc field")

    @staticmethod
    def _first_stdout_line(command, stdin_text: str, timeout: float) -> str:
        """Start *command*, write *stdin_text*, and return its first non-blank stdout line.

        Returns "" when no such line appears within *timeout* seconds. The child
        is always stopped (terminate, then kill after 1 s) before returning.
        """
        import contextlib
        import threading

        proc = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            # stderr is never read; an undrained stderr pipe could stall the child.
            stderr=subprocess.DEVNULL,
            text=True,
            encoding="utf-8",
            errors="replace",
        )
        lines = []

        def read_first_line():
            for line in iter(proc.stdout.readline, ""):
                if line.strip():
                    lines.append(line)
                    return

        # A reader thread lets us wait with a timeout portably (select() does
        # not work on pipes under Windows).
        reader = threading.Thread(target=read_first_line, daemon=True)
        reader.start()
        try:
            # If the child already exited the write fails; its output is still examined.
            with contextlib.suppress(OSError):
                proc.stdin.write(stdin_text)
                proc.stdin.flush()
            reader.join(timeout)
        finally:
            proc.terminate()
            try:
                proc.wait(timeout=1)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
            reader.join(1)
            with contextlib.suppress(OSError):
                proc.stdin.close()
            # Closing stdout under a still-blocked reader could deadlock; skip it then.
            if not reader.is_alive():
                proc.stdout.close()
        return lines[0] if lines else ""

    def _recommended_fixes(self):
        """Map recorded failures to actionable fixes, in a fixed priority order."""
        fixes = []
        if any("print()" in e for e in self.errors):
            fixes.append("Replace all print() with logger.info()")
        # Any dependency failure, or an import that died on a missing module.
        if any(e.startswith("[DEPS]") or "No module named" in e for e in self.errors):
            fixes.append("Install dependencies: pip install -r requirements.txt")
        if any("stderr" in e for e in self.errors):
            fixes.append("Check logging configuration in server.py")
        if not fixes:
            fixes.append("Review the failures listed above")
        return fixes

    def generate_report(self):
        """Print results, warnings, failures with numbered fixes, and next steps."""
        print("\n📊 DIAGNOSTIC RESULTS")
        print("=" * 60)

        for result in self.results:
            print(result)

        if self.warnings:
            print("\n⚠️  WARNINGS")
            print("-" * 60)
            for warning in self.warnings:
                print(warning)

        if self.errors:
            print("\n❌ FAILURES SUMMARY")
            print("-" * 60)
            for error in self.errors:
                print(f"  • {error}")

            print("\n🔧 RECOMMENDED FIXES")
            print("-" * 60)
            # Number only the fixes that apply, so the list always starts at 1.
            for number, fix in enumerate(self._recommended_fixes(), start=1):
                print(f"  {number}. {fix}")
        else:
            print("\n✅ All checks passed! MCP should load successfully.")

        print("\n📋 NEXT STEPS")
        print("-" * 60)
        if self.errors:
            print("1. Fix the issues listed above")
            print("2. Re-run: python diagnose.py")
            print("3. Test: python server.py")
            print("4. Add to Claude: claude mcp add final-test-mcp \"$(pwd)/run_server.py\"")
        else:
            print("1. Install if needed: pip install -r requirements.txt")
            print("2. Add to Claude: claude mcp add final-test-mcp \"$(pwd)/run_server.py\"")
            print("3. Exit and restart Claude Code")


def main() -> int:
    """Run the diagnostics and return the process exit status.

    Returns:
        0 when every check passed, 1 when at least one check failed, so shell
        scripts and CI jobs can detect a broken setup from the exit code.
    """
    diagnostics = MCPDiagnostics()
    diagnostics.run_diagnostics()
    return 1 if diagnostics.errors else 0


if __name__ == "__main__":
    sys.exit(main())
