import os
import pymysql
from mcp.server.fastmcp import FastMCP

# Module-level FastMCP server instance named "keevor-mysql-mcp-server"; the
# functions in this file decorated with @mcp.tool() are attached to it, and
# main() starts it via mcp.run().
mcp = FastMCP("keevor-mysql-mcp-server")


def get_connection():
    """Open a new MySQL connection configured from environment variables.

    The following variables are read, each with a fallback:

    * ``DB_HOST`` -- server host, default ``"localhost"``
    * ``DB_PORT`` -- server port, default ``3306`` (converted with ``int``)
    * ``DB_USER`` -- user name, default ``"root"``
    * ``DB_PASSWORD`` -- password, default ``""``
    * ``DB_NAME`` -- database to select, default ``""``

    Returns:
        A fresh ``pymysql`` connection using the ``utf8mb4`` charset and
        ``DictCursor`` as its cursor class. The caller is responsible for
        closing it.

    Raises:
        ValueError: If ``DB_PORT`` is set to something ``int`` cannot parse.
    """
    env = os.getenv
    settings = {
        "host": env("DB_HOST", "localhost"),
        "port": int(env("DB_PORT", 3306)),
        "user": env("DB_USER", "root"),
        "password": env("DB_PASSWORD", ""),
        "database": env("DB_NAME", ""),
        "charset": "utf8mb4",
        "cursorclass": pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**settings)


@mcp.tool()
def list_tables() -> str:
    """查询数据库中的所有表"""
    # NOTE: the docstring above is kept verbatim on purpose -- FastMCP exposes
    # a tool's docstring to clients as the tool description, so it is
    # runtime-visible text rather than a plain comment.
    #
    # Purpose: list every table in the configured database.
    # Returns: the table names joined by newlines, "没有找到任何表" when
    #          SHOW TABLES yields no rows, or "错误: <message>" when any
    #          exception is raised (errors are reported as text, not raised).
    try:
        conn = get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute("SHOW TABLES")
                # Each row is a dict with a single column; take that column's
                # value without depending on its name.
                tables = [next(iter(row.values())) for row in cursor.fetchall()]
        finally:
            # Close even when the query fails so the connection is not leaked.
            conn.close()
        return "\n".join(tables) if tables else "没有找到任何表"
    except Exception as e:
        return f"错误: {str(e)}"


@mcp.tool()
def describe_table(table_name: str) -> str:
    """查询指定表的结构

    Args:
        table_name: 要查询结构的表名
    """
    # NOTE: the docstring above is kept verbatim on purpose -- FastMCP exposes
    # a tool's docstring to clients as the tool description, so it is
    # runtime-visible text rather than a plain comment.
    #
    # Purpose: describe the columns of one table.
    # Args:    table_name -- name of the table to inspect.
    # Returns: a header line, a 60-dash separator and one line per column
    #          (Field, Type, Null, Key, Default); a "not found / no columns"
    #          message when DESCRIBE yields no rows; or "错误: <message>" when
    #          any exception is raised (errors are reported as text).
    try:
        # An identifier cannot be sent as a bound parameter, so quote it by
        # hand: doubling embedded backticks keeps the whole value inside the
        # quoted identifier instead of letting it terminate the quoting and
        # inject SQL.
        quoted_name = "`" + table_name.replace("`", "``") + "`"

        conn = get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(f"DESCRIBE {quoted_name}")
                columns = cursor.fetchall()
        finally:
            # Close even when the query fails so the connection is not leaked.
            conn.close()

        if not columns:
            return f"表 {table_name} 不存在或没有列"

        lines = [f"表 {table_name} 的结构:", "-" * 60]
        lines.extend(
            f"字段: {col['Field']}, 类型: {col['Type']}, 允许空: {col['Null']}, 键: {col['Key']}, 默认值: {col['Default']}"
            for col in columns
        )
        # Every line, including the last, is newline-terminated.
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"错误: {str(e)}"


@mcp.tool()
def execute_sql(sql: str) -> str:
    """执行SQL语句

    Args:
        sql: 要执行的SQL语句
    """
    # NOTE: the docstring above is kept verbatim on purpose -- FastMCP exposes
    # a tool's docstring to clients as the tool description, so it is
    # runtime-visible text rather than a plain comment.
    #
    # Purpose: run one caller-supplied SQL statement on a fresh connection.
    # Args:    sql -- the statement text, executed as given (no parameters).
    # Returns: for statements that produce a result set, str() of the fetched
    #          rows or "查询结果为空" when there are none; for other statements,
    #          a success message with the affected row count; on any
    #          exception, "错误: <message>" (errors are reported as text).
    try:
        conn = get_connection()
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql)

                # Per DB-API, description is None when the statement produced
                # no result set. Checking it (rather than a "SELECT" prefix)
                # also returns rows for SHOW / DESCRIBE / EXPLAIN / WITH and
                # for SELECTs preceded by a comment.
                has_result_set = cursor.description is not None
                rows = cursor.fetchall() if has_result_set else None
                affected = cursor.rowcount

                # Commit unconditionally: it is harmless after a pure read and
                # keeps writes durable for row-returning statements that also
                # modify data (e.g. CALL), which were committed before too.
                conn.commit()

                if not has_result_set:
                    return f"执行成功，影响行数: {affected}"
                if not rows:
                    return "查询结果为空"
                return str(rows)
        finally:
            # Previously the close sat after the `with` block, where every
            # path had already returned, so it never ran. Closing here covers
            # both the success and the failure paths.
            conn.close()
    except Exception as e:
        return f"错误: {str(e)}"


def main() -> None:
    """Run the module-level ``mcp`` server by calling its ``run()`` method."""
    mcp.run()


# Start the server only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()
