#!/usr/bin/env python
# encoding: utf-8
"""
@file: mcp_client.py
@time: 2025/7/10 18:34
@project: mcp-lite-dev
@desc: 6.3.3 构建MCP Client
"""

import asyncio
import json
import os
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv, find_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import OpenAI

# 加载 .env 文件，确保 API Key 受到保护
load_dotenv(find_dotenv(), override=True)


class MCPClient:
    def __init__(self):
        """初始化 MCP 客户端"""
        self.write = None
        self.stdio = None
        self.exit_stack = AsyncExitStack()
        self.api_key = os.getenv("API_KEY")  # 读取 API Key
        self.base_url = os.getenv("BASE_URL")  # 读取 BASE URL
        self.model = os.getenv("MODEL")  # 读取 model
        if not self.api_key:
            raise ValueError("❌ 未找到 OpenAI API Key，请在 .env 文件中设置 OPENAI_API_KEY")
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)  # 创建OpenAI client
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()

    async def connect_to_server(self, server_script_path: str):
        """连接到 MCP 服务器并列出可用工具"""
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            raise ValueError("服务器脚本必须是 .py 或 .js 文件")
        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None
        )
        # 启动 MCP 服务器并建立通信
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()
        # 列出 MCP 服务器上的工具
        response = await self.session.list_tools()
        tools = response.tools
        print("\n已连接到服务器，支持以下工具:", [tool.name for tool in tools])

    async def process_query(self, query: str) -> str:
        """
        使用大模型处理查询并调用可用的 MCP 工具 (Function Calling)
        """
        messages = [{"role": "user", "content": query}]

        response = await self.session.list_tools()
        available_tools = [{
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema
            }
        } for tool in response.tools]

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=available_tools
        )

        # 处理返回的内容
        content = response.choices[0]
        if content.finish_reason == "tool_calls":
            # 如果需要使用工具，就解析工具
            tool_call = content.message.tool_calls[0]
            tool_name = tool_call.function.name
            tool_args = json.loads(tool_call.function.arguments)

            # 执行工具
            try:
                result = await self.session.call_tool(tool_name, tool_args)
                print(f"\n\n[Calling tool {tool_name} with args {tool_args}]\n\n")
                
                # 检查工具调用结果
                if not result.content or len(result.content) == 0:
                    raise ValueError("工具调用返回空结果")
                
                tool_result_text = result.content[0].text if hasattr(result.content[0], 'text') else str(result.content[0])
                
            except Exception as e:
                print(f"\n⚠ 工具调用失败: {e}")
                tool_result_text = f"工具调用失败: {str(e)}"
            
            # 将模型返回的调用哪个工具数据和工具执行完成后的数据都存入messages中
            messages.append(content.message.model_dump())
            messages.append({
                "role": "tool",
                "content": tool_result_text,
                "tool_call_id": tool_call.id,
            })

            # 将上面的结果再返回给大模型用于生产最终的结果
            # 注意：第二次请求时不传递 tools 参数，避免 LLM 再次调用工具
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                tools=None,  # 明确设置为 None，防止再次调用工具
            )
            
            # 检查响应，确保不是工具调用
            final_content = response.choices[0]
            if final_content.finish_reason == "tool_calls":
                # 如果 LLM 仍然试图调用工具，直接返回工具调用的结果
                print("\n⚠ 警告: LLM 试图再次调用工具，直接返回工具结果")
                return tool_result_text
            
            return final_content.message.content

        return content.message.content

    async def chat_loop(self):
        """运行交互式聊天循环"""
        print("\n🤖 MCP 客户端已启动！输入 'quit' 退出")
        while True:
            try:
                query = input("\n你: ").strip()
                if query.lower() == 'quit':
                    break

                response = await self.process_query(query)  # 发送用户输入到 OpenAI API
                print(f"\n🤖 DeepSeek: {response}")
            except Exception as e:
                error_msg = str(e)
                print(f"\n⚠ 发生错误: {error_msg}")
                
                # 针对 403 错误提供更明确的提示
                if "403" in error_msg or "paid balance" in error_msg.lower() or "insufficient" in error_msg.lower():
                    print("\n" + "="*60)
                    print("💡 解决方案：")
                    print("   当前模型需要付费，请尝试以下方法：")
                    print("   1. 修改 .env 文件中的 MODEL 配置")
                    print("   2. 尝试使用免费模型，例如：")
                    print("      MODEL=deepseek-ai/DeepSeek-V2.5")
                    print("      或")
                    print("      MODEL=deepseek-ai/DeepSeek-Chat")
                    print("   3. 保存文件后重新运行程序")
                    print("="*60)

    async def cleanup(self):
        """清理资源"""
        await self.exit_stack.aclose()


async def main():
    if len(sys.argv) < 2:
        print("Usage: python client.py <path_to_server_script>")
        sys.exit(1)
    client = MCPClient()
    try:
        await client.connect_to_server(sys.argv[1])
        await client.chat_loop()
    finally:
        await client.cleanup()


if __name__ == "__main__":
    import sys

    asyncio.run(main())
