diff --git a/mcp/system_tools.py b/mcp/system_tools.py index bad854a..a9c6746 100644 --- a/mcp/system_tools.py +++ b/mcp/system_tools.py @@ -1,73 +1,116 @@ # system_tools.py -import os import posixpath import stat as pystat +from shlex import shlex from typing import Any +from contextlib import asynccontextmanager +from collections.abc import AsyncIterator import asyncssh from mcp.server.fastmcp import FastMCP, Context from mcp.server.session import ServerSession -from lifespan import SharedAppContext # 保持你的现有导入 + +# 在 system_tools.py 顶部添加 +def shell_quote(arg: str) -> str: + """ + 安全地把字符串作为单个 shell 参数: + - 外层用单引号包裹 + - 内部的单引号 ' 替换为 '\'' 序列 + 适用于远端 Linux shell 命令拼接 + """ + return "'" + arg.replace("'", "'\"'\"'") + "'" + +# 复用你的配置与数据类 +from lifespan import ( + SharedAppContext, + REMOTE_HOST, + REMOTE_USER, + PRIVATE_KEY_PATH, + INITIAL_WORKING_DIRECTORY, +) + +# —— 1) 定义 FastMCP 的生命周期,在启动时建立 SSH 连接,关闭时断开 —— +@asynccontextmanager +async def system_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]: + """ + FastMCP 生命周期:建立并注入共享的 SSH 连接与沙箱根路径。 + 说明:这是 MCP 自己的生命周期,工具里通过 ctx.request_context.lifespan_context 访问。 + """ + conn: asyncssh.SSHClientConnection | None = None + try: + # 建立 SSH 连接 + conn = await asyncssh.connect( + REMOTE_HOST, + username=REMOTE_USER, + client_keys=[PRIVATE_KEY_PATH], + ) + # 将类型安全的共享上下文注入 MCP 生命周期 + yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY) + finally: + # 关闭连接 + if conn: + conn.close() + await conn.wait_closed() + def create_system_mcp() -> FastMCP: """创建一个包含系统操作工具的 MCP 实例。""" - system_mcp = FastMCP( name="System Tools", instructions="用于在远程服务器上进行基本文件和目录操作的工具集。", - # 将 MCP 实例的 HTTP 路径设置为根,以便挂载在 /system 下 [1] - streamable_http_path="/" + streamable_http_path="/", + lifespan=system_lifespan, # 关键:把生命周期传给 FastMCP [1] ) def _safe_join(sandbox_root: str, relative_path: str) -> str: """ 将用户提供的相对路径映射到沙箱根目录内的规范化绝对路径。 - - 统一使用 POSIX 语义(远端 Linux) + - 统一使用 POSIX 语义(远端 Linux) - 禁止使用以 '/' 开头的绝对路径 - - 禁止 '..' 越界,确保最终路径仍在沙箱内 + - 禁止 '..' 越界,确保最终路径仍在沙箱内 """ rel = (relative_path or ".").strip() - # 禁止绝对路径,转为相对 + # 禁止绝对路径,转为相对 if rel.startswith("/"): rel = rel.lstrip("/") # 规范化拼接 combined = posixpath.normpath(posixpath.join(sandbox_root, rel)) - # 统一尾部斜杠处理,避免边界判断遗漏 + # 统一尾部斜杠处理,避免边界判断遗漏 root_norm = sandbox_root.rstrip("/") # 确保仍在沙箱内 if combined != root_norm and not combined.startswith(root_norm + "/"): - raise ValueError("路径越界:仅允许访问沙箱目录内部") + raise ValueError("路径越界:仅允许访问沙箱目录内部") - # 禁止路径中出现 '..'(进一步加固) + # 禁止路径中出现 '..'(进一步加固) parts = [p for p in combined.split("/") if p] if ".." in parts: - raise ValueError("非法路径:不允许使用 '..' 跨目录") + raise ValueError("非法路径:不允许使用 '..' 跨目录") return combined - # —— 重构后的三个工具:list_files / read_file / write_file —— + # —— 重构后的各工具:统一用类型安全的 ctx 与 app_ctx 访问共享资源 —— + @system_mcp.tool() - async def list_files(ctx: Context[ServerSession, dict], path: str = ".") -> list[dict[str, Any]]: + async def list_files(ctx: Context[ServerSession, SharedAppContext], path: str = ".") -> list[dict[str, Any]]: """ - 列出远程沙箱目录中的文件和子目录(结构化输出)。 + 列出远程沙箱目录中的文件和子目录(结构化输出)。 Returns: list[dict]: [{name, path, is_dir, size, permissions, mtime}] """ try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path target = _safe_join(sandbox_root, path) - await ctx.info(f"列出目录:{target}") + await ctx.info(f"列出目录:{target}") items: list[dict[str, Any]] = [] async with conn.start_sftp_client() as sftp: - # 先列出文件名,再逐个 stat 获取属性 names = await sftp.listdir(target) # list[str] for name in names: item_abs = posixpath.join(target, name) @@ -84,62 +127,53 @@ def create_system_mcp() -> FastMCP: items.append({ "name": name, - # 返回相对路径,便于后续工具继续在沙箱内操作 "path": posixpath.join(path.rstrip("/"), name) if path != "/" else name, "is_dir": is_dir, "size": size, - "permissions": perms, # 九进制权限位,例如 0o755(以整型传递) - "mtime": mtime, # 秒 + "permissions": perms, + "mtime": mtime, }) - await ctx.debug(f"目录项数量:{len(items)}") + await ctx.debug(f"目录项数量:{len(items)}") return items except FileNotFoundError: - msg = f"目录不存在或不可访问:{path}" + msg = f"目录不存在或不可访问:{path}" await ctx.error(msg) return [{"error": msg}] except Exception as e: - msg = f"list_files 失败:{e}" + msg = f"list_files 失败:{e}" await ctx.error(msg) return [{"error": msg}] @system_mcp.tool() - async def read_file(ctx: Context[ServerSession, dict], file_path: str, encoding: str = "utf-8") -> str: + async def read_file(ctx: Context[ServerSession, SharedAppContext], file_path: str, encoding: str = "utf-8") -> str: """ 读取远程沙箱内指定文件内容。 - Args: - file_path: 相对于沙箱根的文件路径 - encoding: 文本编码(默认 utf-8) - Returns: - 文件文本内容(字符串) """ try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path target = _safe_join(sandbox_root, file_path) - await ctx.info(f"读取文件:{target}") + await ctx.info(f"读取文件:{target}") async with conn.start_sftp_client() as sftp: - # 以二进制读取,按 encoding 解码为文本 async with sftp.open(target, "rb") as f: data = await f.read() try: content = data.decode(encoding) except Exception: - # 如果解码失败,回退为原始字节的 repr content = data.decode(encoding, errors="replace") - return content except FileNotFoundError: - msg = f"读取失败:文件不存在 '{file_path}'" + msg = f"读取失败:文件不存在 '{file_path}'" await ctx.error(msg) return msg except Exception as e: - msg = f"read_file 失败:{e}" + msg = f"read_file 失败:{e}" await ctx.error(msg) return msg @@ -152,14 +186,7 @@ def create_system_mcp() -> FastMCP: create_parents: bool = True, ) -> dict[str, Any]: """ - 写入远程沙箱文件(默认按需创建父目录)。 - Args: - file_path: 相对于沙箱根的文件路径 - content: 要写入的文本内容 - encoding: 写入编码(默认 utf-8) - create_parents: True 则自动创建父目录 - Returns: - dict: { path, bytes_written } + 写入远程沙箱文件(默认按需创建父目录)。 """ try: shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] @@ -167,14 +194,12 @@ def create_system_mcp() -> FastMCP: sandbox_root = shared_ctx.sandbox_path target = _safe_join(sandbox_root, file_path) - await ctx.info(f"写入文件:{target}") + await ctx.info(f"写入文件:{target}") - # 可选:创建父目录 if create_parents: parent = posixpath.dirname(target) if parent and parent != sandbox_root: - # mkdir -p,更快且不易出错 - await conn.run(f"mkdir -p {conn.escape(parent)}", check=True) + await conn.run(f"mkdir -p {shell_quote(parent)}", check=True) data = content.encode(encoding) @@ -182,33 +207,25 @@ def create_system_mcp() -> FastMCP: async with sftp.open(target, "wb") as f: await f.write(data) - await ctx.debug(f"写入完成:{len(data)} 字节") + await ctx.debug(f"写入完成:{len(data)} 字节") return {"path": file_path, "bytes_written": len(data)} - except Exception as e: - msg = f"write_file 失败:{e}" + msg = f"write_file 失败:{e}" await ctx.error(msg) return {"error": msg} - - - @system_mcp.tool() - async def make_dir(ctx: Context[ServerSession, dict], path: str, parents: bool = True) -> str: - """ - 创建目录(默认支持递归创建父级 -p)。 - Args: - path: 相对于沙箱根的目录路径 - parents: True 表示使用 `mkdir -p`,False 则只创建最后一级 - """ + async def make_dir(ctx: Context[ServerSession, SharedAppContext], path: str, parents: bool = True) -> str: try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path target = _safe_join(sandbox_root, path) + if parents: - await conn.run(f"mkdir -p {conn.escape(target)}", check=True) + # 单行命令:mkdir -p + await conn.run(f"mkdir -p {shell_quote(target)}", check=True) else: async with conn.start_sftp_client() as sftp: await sftp.mkdir(target) @@ -220,16 +237,14 @@ def create_system_mcp() -> FastMCP: return f"错误: {e}" @system_mcp.tool() - async def delete_file(ctx: Context[ServerSession, dict], file_path: str) -> str: + async def delete_file(ctx: Context[ServerSession, SharedAppContext], file_path: str) -> str: """ - 删除文件(非目录)。 - Args: - file_path: 相对于沙箱根的文件路径 + 删除文件(非目录)。 """ try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path target = _safe_join(sandbox_root, file_path) async with conn.start_sftp_client() as sftp: @@ -238,7 +253,7 @@ def create_system_mcp() -> FastMCP: await ctx.info(f"文件已删除: {target}") return f"文件已删除: {file_path}" except FileNotFoundError: - msg = f"删除失败:文件不存在 '{file_path}'" + msg = f"删除失败:文件不存在 '{file_path}'" await ctx.error(msg) return msg except Exception as e: @@ -249,9 +264,6 @@ def create_system_mcp() -> FastMCP: async def delete_dir(ctx: Context[ServerSession, dict], dir_path: str, recursive: bool = False) -> str: """ 删除目录。 - Args: - dir_path: 相对于沙箱根的目录路径 - recursive: True 则递归删除(rm -rf),False 仅删除空目录(rmdir) """ try: shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] @@ -261,7 +273,7 @@ def create_system_mcp() -> FastMCP: target = _safe_join(sandbox_root, dir_path) if recursive: - await conn.run(f"rm -rf {conn.escape(target)}", check=True) + await conn.run(f"rm -rf {shell_quote(target)}", check=True) else: async with conn.start_sftp_client() as sftp: await sftp.rmdir(target) @@ -269,37 +281,34 @@ def create_system_mcp() -> FastMCP: await ctx.info(f"目录已删除: {target}") return f"目录已删除: {dir_path}" except FileNotFoundError: - msg = f"删除失败:目录不存在 '{dir_path}' 或非空" + msg = f"删除失败:目录不存在 '{dir_path}' 或非空" await ctx.error(msg) return msg + except Exception as e: await ctx.error(f"删除目录失败: {e}") return f"错误: {e}" @system_mcp.tool() - async def move_path(ctx: Context[ServerSession, dict], src: str, dst: str, overwrite: bool = True) -> str: - """ - 移动/重命名文件或目录。 - Args: - src: 源路径(相对于沙箱根) - dst: 目标路径(相对于沙箱根) - overwrite: True 使用 mv -f 覆盖同名目标 - """ + async def move_path(ctx: Context[ServerSession, SharedAppContext], src: str, dst: str, + overwrite: bool = True) -> str: try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path src_abs = _safe_join(sandbox_root, src) dst_abs = _safe_join(sandbox_root, dst) - flags = "-f" if overwrite else "" - await conn.run(f"mv {flags} {conn.escape(src_abs)} {conn.escape(dst_abs)}", check=True) + flag = "-f" if overwrite else "" + # 单行命令:mv + cmd = f"mv {flag} {shell_quote(src_abs)} {shell_quote(dst_abs)}".strip() + await conn.run(cmd, check=True) await ctx.info(f"已移动: {src_abs} -> {dst_abs}") return f"已移动: {src} -> {dst}" except FileNotFoundError: - msg = f"移动失败:源不存在 '{src}'" + msg = f"移动失败:源不存在 '{src}'" await ctx.error(msg) return msg except Exception as e: @@ -308,24 +317,16 @@ def create_system_mcp() -> FastMCP: @system_mcp.tool() async def copy_path( - ctx: Context[ServerSession, dict], - src: str, - dst: str, - recursive: bool = True, - overwrite: bool = True, + ctx: Context[ServerSession, SharedAppContext], + src: str, + dst: str, + recursive: bool = True, + overwrite: bool = True, ) -> str: - """ - 复制文件或目录。 - Args: - src: 源路径(相对于沙箱根) - dst: 目标路径(相对于沙箱根) - recursive: True 则使用 -r 递归复制目录 - overwrite: True 使用 -f 覆盖同名目标 - """ try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path src_abs = _safe_join(sandbox_root, src) dst_abs = _safe_join(sandbox_root, dst) @@ -336,31 +337,30 @@ def create_system_mcp() -> FastMCP: if overwrite: flags.append("-f") - await conn.run(f"cp {' '.join(flags)} {conn.escape(src_abs)} {conn.escape(dst_abs)}", check=True) + # 单行命令:cp + cmd = " ".join(["cp"] + flags + [shell_quote(src_abs), shell_quote(dst_abs)]) + await conn.run(cmd, check=True) await ctx.info(f"已复制: {src_abs} -> {dst_abs}") return f"已复制: {src} -> {dst}" except FileNotFoundError: - msg = f"复制失败:源不存在 '{src}'" + msg = f"复制失败:源不存在 '{src}'" await ctx.error(msg) return msg except Exception as e: await ctx.error(f"复制失败: {e}") return f"错误: {e}" + @system_mcp.tool() - async def exists(ctx: Context[ServerSession, dict], path: str) -> bool: + async def exists(ctx: Context[ServerSession, SharedAppContext], path: str) -> bool: """ - 判断路径(文件/目录)是否存在。 - Args: - path: 相对于沙箱根的路径 - Returns: - True/False + 判断路径(文件/目录)是否存在。 """ try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path target = _safe_join(sandbox_root, path) async with conn.start_sftp_client() as sftp: @@ -373,18 +373,14 @@ def create_system_mcp() -> FastMCP: return False @system_mcp.tool() - async def stat_path(ctx: Context[ServerSession, dict], path: str) -> dict: + async def stat_path(ctx: Context[ServerSession, SharedAppContext], path: str) -> dict: """ - 查看远程路径属性(结构化输出)。 - Args: - path: 相对于沙箱根的路径 - Returns: - dict: { path, size, is_dir, permissions, mtime } + 查看远程路径属性(结构化输出)。 """ try: - shared_ctx: SharedAppContext = ctx.request_context.lifespan_context["shared_context"] - conn = shared_ctx.ssh_connection - sandbox_root = shared_ctx.sandbox_path + app_ctx = ctx.request_context.lifespan_context + conn = app_ctx.ssh_connection + sandbox_root = app_ctx.sandbox_path target = _safe_join(sandbox_root, path) async with conn.start_sftp_client() as sftp: @@ -395,7 +391,7 @@ def create_system_mcp() -> FastMCP: "path": target, "size": int(attrs.size or 0), "is_dir": bool(pystat.S_ISDIR(perms)), - "permissions": perms, # 九进制权限位,例:0o755 + "permissions": perms, # 九进制权限位,例:0o755 "mtime": int(attrs.mtime or 0), # 秒 } except FileNotFoundError: @@ -405,4 +401,5 @@ def create_system_mcp() -> FastMCP: except Exception as e: await ctx.error(f"stat 失败: {e}") return {"error": str(e)} - return system_mcp + + return system_mcp \ No newline at end of file