# topological_analysis_tools.py
import posixpath
import re
from typing import Any, Literal
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator

import asyncssh
from pydantic import BaseModel, Field

# Core server components from the MCP SDK
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession

# Shared configuration and context dataclass from lifespan.py.
# Note: system_lifespan no longer needs to be imported.
from lifespan import (
    SharedAppContext,
    REMOTE_HOST,
    REMOTE_USER,
    PRIVATE_KEY_PATH,
    INITIAL_WORKING_DIRECTORY,
)

# ==============================================================================
# 1. Helper functions (copied here so this module is self-contained)
# ==============================================================================

def shell_quote(arg: str) -> str:
    """Safely escape a string as a single shell argument."""
    return "'" + arg.replace("'", "'\"'\"'") + "'"


def _safe_join(sandbox_root: str, relative_path: str) -> str:
    """Safely join a user-supplied relative path onto the sandbox root."""
    rel = (relative_path or ".").strip()
    if rel.startswith("/"):
        rel = rel.lstrip("/")
    combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
    root_norm = sandbox_root.rstrip("/")
    if combined != root_norm and not combined.startswith(root_norm + "/"):
        raise ValueError("Path escapes sandbox: only paths inside the sandbox directory are allowed")
    if ".." in combined.split("/"):
        raise ValueError("Illegal path: '..' traversal is not allowed")
    return combined


# ==============================================================================
# 2. Standalone lifespan manager
# ==============================================================================

@asynccontextmanager
async def analysis_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    Manage a dedicated SSH connection lifecycle for the topological analysis tools.
    """
    conn: asyncssh.SSHClientConnection | None = None
    print("Analysis tools: establishing dedicated SSH connection...")
    try:
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        print(f"Analysis tools: dedicated SSH connection to {REMOTE_HOST} established!")
        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        if conn:
            conn.close()
            await conn.wait_closed()
            print("Analysis tools: dedicated SSH connection closed.")
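
# ------------------------------------------------------------------------------
# Note (illustrative addition, not part of the original module): FastMCP hands
# the object yielded by analysis_lifespan to every tool call through
# ctx.request_context.lifespan_context. That is how the tools defined below
# obtain the shared SSH connection and sandbox path without module-level
# globals, e.g.:
#
#     app_ctx = ctx.request_context.lifespan_context   # -> SharedAppContext
#     conn = app_ctx.ssh_connection
# ------------------------------------------------------------------------------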

# ==============================================================================
# 3. Data models and parsing helpers
# ==============================================================================

VALID_CONFIGS = Literal["O.yaml", "S.yaml", "Cl.yaml", "Br.yaml"]


class SubmitTopologicalAnalysisParams(BaseModel):
    cif_file_path: str = Field(description="Relative path to the CIF file on the remote server.")
    config_files: list[VALID_CONFIGS] = Field(description="One or more YAML configuration files to use for the analysis.")
    output_file_path: str = Field(description="Relative path of the output file used to store the results.")


class AnalysisResult(BaseModel):
    percolation_diameter_a: float | None = Field(None, description="Percolation diameter, in Å.")
    connectivity_distance: float | None = Field(None, description="Connectivity distance (the log's 'the minium of d' value), in Å.")
    max_node_length_a: float | None = Field(None, description="Maximum node length, in Å.")
    min_cation_distance_a: float | None = Field(None, description="Minimum distance to cations, in Å.")
    total_time_s: float | None = Field(None, description="Total computation time, in seconds.")
    long_node_count: int | None = Field(None, description="Long node number.")
    short_node_count: int | None = Field(None, description="Short node number.")
    warnings: list[str] = Field(default_factory=list, description="Warnings raised during parsing, including any missing key values.")


def _parse_analysis_output(log_content: str) -> AnalysisResult:
    def find_float(pattern: str) -> float | None:
        match = re.search(pattern, log_content)
        return float(match.group(1)) if match else None

    def find_int(pattern: str) -> int | None:
        match = re.search(pattern, log_content)
        return int(match.group(1)) if match else None

    data = {
        "percolation_diameter_a": find_float(r"Percolation diameter \(A\):\s*([\d\.]+)"),
        "max_node_length_a": find_float(r"Maximum node length detected:\s*([\d\.]+)\s*A"),
        "min_cation_distance_a": find_float(r"minimum distance to cations is\s*([\d\.]+)\s*A"),
        "total_time_s": find_float(r"Total used time:\s*([\d\.]+)"),
        "long_node_count": find_int(r"Long node number:\s*(\d+)"),
        "short_node_count": find_int(r"Short node number:\s*(\d+)"),
    }
    # "minium" is the literal (sic) spelling emitted by the analysis tool's log.
    conn_dist_match = re.search(r"the minium of d\s*([\d\.]+)", log_content)
    data["connectivity_distance"] = float(conn_dist_match.group(1)) if conn_dist_match else None

    # These three values are optional in the log, so their absence is not a warning.
    optional_keys = {"min_cation_distance_a", "long_node_count", "short_node_count"}
    warnings = [f"{k} not found" for k, v in data.items() if v is None and k not in optional_keys]
    return AnalysisResult(**data, warnings=warnings)


# ==============================================================================
# 4. MCP server factory
# ==============================================================================

def create_topological_analysis_mcp() -> FastMCP:
    """
    Factory function intended to be mounted under Starlette.

    It now creates and uses its own standalone lifespan manager.
    """
    analysis_mcp = FastMCP(
        name="Topological Analysis Tools",
        instructions="Tools for submitting and analyzing topological computation jobs on a remote server.",
        streamable_http_path="/",
        lifespan=analysis_lifespan,  # Key point: use the standalone lifespan defined in this file
    )

    @analysis_mcp.tool()
    async def submit_topological_analysis(
        params: SubmitTopologicalAnalysisParams,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> dict[str, Any]:
        """[Step 1/2] Submit a topological analysis job asynchronously."""
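        # The job is launched with `nohup ... &` and all output is redirected to
        # the requested log file, so this call returns immediately; the second
        # tool, analyze_topological_results, reads that file once the job is done.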
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            home_dir = f"/cluster/home/{REMOTE_USER}"
            tool_dir = f"{home_dir}/tool/Topological_analyse"

            cif_abs_path = _safe_join(sandbox_root, params.cif_file_path)
            output_abs_path = _safe_join(sandbox_root, params.output_file_path)

            config_args = " ".join(shell_quote(f"{tool_dir}/{cfg}") for cfg in params.config_files)
            config_flags = f"-i {config_args}"

            # Build the remote script first, then quote it once as a whole.
            # Embedding shell_quote()'d paths directly inside an outer
            # single-quoted `sh -c '...'` string would break the quoting as
            # soon as a path needed escaping, so the entire script is passed
            # through shell_quote() instead.
            inner_script = (
                f"source {shell_quote(f'{home_dir}/.bashrc')} && "
                f"conda activate {shell_quote(f'{home_dir}/anaconda3/envs/zeo')} && "
                f"python {shell_quote(f'{tool_dir}/analyze_voronoi_nodes.py')} "
                f"{shell_quote(cif_abs_path)} {config_flags}"
            )
            command = (
                f"nohup sh -c {shell_quote(inner_script)} "
                f"> {shell_quote(output_abs_path)} 2>&1 &"
            )

            await ctx.info("Submitting background topological analysis job...")
            await conn.run(command, check=True)
            return {
                "success": True,
                "message": "Topological analysis job submitted successfully.",
                "output_file": params.output_file_path,
            }
        except Exception as e:
            msg = f"Failed to submit background job: {e}"
            await ctx.error(msg)
            return {"success": False, "error": msg}

    @analysis_mcp.tool()
    async def analyze_topological_results(
        output_file_path: str,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> AnalysisResult:
        """[Step 2/2] Read and analyze the output file of a topological analysis job."""
        try:
            await ctx.info(f"Analyzing result file: {output_file_path}")
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target_path = _safe_join(sandbox_root, output_file_path)

            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target_path, "r") as f:
                    log_content = await f.read()

            if not log_content:
                raise ValueError("Result file is empty.")

            return _parse_analysis_output(log_content)
        except (FileNotFoundError, asyncssh.SFTPNoSuchFile):
            # asyncssh surfaces a missing remote file as SFTPNoSuchFile rather
            # than the builtin FileNotFoundError, so both are handled here.
            msg = f"Analysis failed: result file '{output_file_path}' does not exist."
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])
        except Exception as e:
            msg = f"Error while analyzing results: {e}"
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])

    return analysis_mcp
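
# ==============================================================================
# 5. Usage sketch (illustrative addition, not part of the original module)
# ==============================================================================
# A minimal sketch of serving this factory under Starlette, assuming the MCP
# Python SDK's `streamable_http_app()` helper and its session-manager lifespan
# pattern. The mount point "/analysis", the host/port, and the use of uvicorn
# are all arbitrary choices for illustration.

if __name__ == "__main__":
    import uvicorn
    from starlette.applications import Starlette
    from starlette.routing import Mount

    mcp_app = create_topological_analysis_mcp()

    # streamable_http_path="/" above means the tools are served at the root of
    # wherever this ASGI app is mounted, i.e. under /analysis here.
    app = Starlette(
        routes=[Mount("/analysis", app=mcp_app.streamable_http_app())],
        # The session manager must be running for the streamable HTTP
        # transport; returning its run() context manager as the Starlette
        # lifespan follows the pattern shown in the MCP SDK examples.
        lifespan=lambda _app: mcp_app.session_manager.run(),
    )
    uvicorn.run(app, host="127.0.0.1", port=8000)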