From 1f8667ae51a82cd5954d5af401480b2ca7299170 Mon Sep 17 00:00:00 2001
From: koko <1429659362@qq.com>
Date: Thu, 23 Oct 2025 17:08:57 +0800
Subject: [PATCH] Add the zeo++ part
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mcp/SearchPaperByEmbedding/demo.py |   4 +-
 mcp/main.py                        |  11 ++
 mcp/paper_search_mcp.py            | 172 +++++++++++++++++++++++++++
 mcp/topological_analysis_models.py | 185 +++++++++++++++++++++++++++++
 4 files changed, 370 insertions(+), 2 deletions(-)
 create mode 100644 mcp/paper_search_mcp.py
 create mode 100644 mcp/topological_analysis_models.py

diff --git a/mcp/SearchPaperByEmbedding/demo.py b/mcp/SearchPaperByEmbedding/demo.py
index 789ac97..e8753fa 100644
--- a/mcp/SearchPaperByEmbedding/demo.py
+++ b/mcp/SearchPaperByEmbedding/demo.py
@@ -10,8 +10,8 @@ searcher.compute_embeddings()
 
 examples = [
     {
-        "title": "Improving Developer Emotion Classification via LLM-Based Augmentation",
-        "abstract": "Detecting developer emotion in the informative data stream of technical commit messages..."
+        "title": "Solid-State battery",
+        "abstract": "Solid-State battery"
     },
 ]
 
diff --git a/mcp/main.py b/mcp/main.py
index 498292d..b882f00 100644
--- a/mcp/main.py
+++ b/mcp/main.py
@@ -11,10 +11,15 @@ from starlette.routing import Mount
 from system_tools import create_system_mcp
 from materialproject_mcp import create_materials_mcp
 from softBV_remake import create_softbv_mcp
+from paper_search_mcp import create_paper_search_mcp
+from topological_analysis_models import create_topological_analysis_mcp
+
 # Create the MCP instances
 system_mcp = create_system_mcp()
 materials_mcp = create_materials_mcp()
 softbv_mcp = create_softbv_mcp()
+paper_search_mcp = create_paper_search_mcp()
+topological_analysis_mcp = create_topological_analysis_mcp()
 # Start each MCP's session manager inside Starlette's lifespan
 @contextlib.asynccontextmanager
 async def lifespan(app: Starlette):
@@ -22,6 +27,8 @@ async def lifespan(app: Starlette):
     await stack.enter_async_context(system_mcp.session_manager.run())
     await stack.enter_async_context(materials_mcp.session_manager.run())
     await stack.enter_async_context(softbv_mcp.session_manager.run())
+    await stack.enter_async_context(paper_search_mcp.session_manager.run())
+    await stack.enter_async_context(topological_analysis_mcp.session_manager.run())
     yield  # while the server is running
 
     # automatic cleanup on exit
@@ -32,6 +39,8 @@ app = Starlette(
         Mount("/system", app=system_mcp.streamable_http_app()),
         Mount("/materials", app=materials_mcp.streamable_http_app()),
         Mount("/softBV", app=softbv_mcp.streamable_http_app()),
+        Mount("/papersearch", app=paper_search_mcp.streamable_http_app()),
+        Mount("/topologicalAnalysis", app=topological_analysis_mcp.streamable_http_app()),
     ],
 )
 
@@ -41,4 +50,6 @@ app = Starlette(
 # http://localhost:8000/system
 # http://localhost:8000/materials
 # http://localhost:8000/softBV
+# http://localhost:8000/papersearch
+# http://localhost:8000/topologicalAnalysis
 # For browser clients (which need Mcp-Session-Id exposed via CORS), see the CORS configuration example in the README [1]
\ No newline at end of file
diff --git a/mcp/paper_search_mcp.py b/mcp/paper_search_mcp.py
new file mode 100644
index 0000000..c8fd182
--- /dev/null
+++ b/mcp/paper_search_mcp.py
@@ -0,0 +1,172 @@
+# paper_search_mcp.py (refactored)
+
+import os
+import asyncio
+from dataclasses import dataclass
+from typing import Any
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+
+from pydantic import BaseModel, Field
+# Assumes your search.py provides the PaperSearcher class
+from SearchPaperByEmbedding.search import PaperSearcher
+
+from mcp.server.fastmcp import FastMCP, Context
+from mcp.server.session import ServerSession
+
+# ========= 1. Configuration and constants (corrected) =========
+
+# Absolute path of the directory containing this script
+_CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+# Build absolute paths to the resource files relative to this directory
+DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR, "SearchPaperByEmbedding")
+PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json")
+
+# Print the path for debugging, to confirm it is correct
+print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}")
+
+MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local")
+
+
+# ========= 2. Data models =========
+
+class Paper(BaseModel):
+    """Details of a single paper."""
+    title: str
+    authors: list[str]
+    abstract: str
+    pdf_url: str | None = None
+    similarity_score: float = Field(description="Similarity score against the query; higher means more relevant")
+
+
+class SearchQuery(BaseModel):
+    """Query structure for a paper search."""
+    title: str
+    abstract: str | None = Field(None, description="Optional abstract to provide richer context")
+
+
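+# Illustrative, assumed shape of one record in iclr2026_papers.json, inferred
+# from the fields read in search_papers() below; the authoritative schema is
+# whatever SearchPaperByEmbedding/search.py actually loads:
+#
+#   {
+#       "title": "...",
+#       "authors": ["...", "..."],
+#       "abstract": "...",
+#       "forum_url": "..."
+#   }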
+
+
+# ========= 3. Lifespan (resource loading and management) =========
+
+@dataclass
+class PaperSearchContext:
+    """The initialized PaperSearcher instance shared across the server's lifetime."""
+    searcher: PaperSearcher
+
+
+@asynccontextmanager
+async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]:
+    """
+    FastMCP lifespan: initialize the PaperSearcher and compute embeddings at server startup.
+    This ensures the most expensive step runs only once, at startup.
+    """
+    print(f"Initializing the paper search engine with the '{MODEL_TYPE}' model...")
+
+    # Run the synchronous initialization code via to_thread
+    def _initialize_searcher():
+        # 1. Initialize
+        searcher = PaperSearcher(
+            PAPERS_JSON_PATH,
+            model_type=MODEL_TYPE
+        )
+        # 2. Compute/load the embeddings
+        print("Computing or loading paper embeddings...")
+        searcher.compute_embeddings()
+        print("Embeddings loaded.")
+        return searcher
+
+    searcher = await asyncio.to_thread(_initialize_searcher)
+    print("Paper search engine ready!")
+
+    try:
+        # Yield the initialized shared resource context to MCP
+        yield PaperSearchContext(searcher=searcher)
+    finally:
+        print("Paper search engine service shut down.")
+
+
+# ========= 4. MCP server instance =========
+
+mcp = FastMCP(
+    name="Paper Search",
+    instructions="Tools for retrieving ICLR 2026 papers by embedding similarity.",
+    lifespan=paper_search_lifespan,
+    streamable_http_path="/",
+)
+
+
+# ========= 5. Tool implementation =========
+@mcp.tool()
+async def search_papers(
+    queries: list[SearchQuery],
+    top_k: int = 3,
+    ctx: Context[ServerSession, PaperSearchContext] | None = None,
+) -> list[Paper]:
+    """
+    Search for papers by semantic similarity, given one or more queries
+    (each with a title and an optional abstract).
+    Returns the top_k most relevant results.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+    if not queries:
+        return []
+
+    app_ctx = ctx.request_context.lifespan_context
+    searcher = app_ctx.searcher
+
+    # Preprocess the queries so that title and abstract are both present,
+    # falling back to whichever one was provided
+    examples_to_search = []
+    for q in queries:
+        query_dict = q.model_dump(exclude_none=True)
+        if "title" not in query_dict and "abstract" in query_dict:
+            query_dict["title"] = query_dict["abstract"]
+        elif "abstract" not in query_dict and "title" in query_dict:
+            query_dict["abstract"] = query_dict["title"]
+        examples_to_search.append(query_dict)
+
+    if not examples_to_search:
+        return []
+
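+    # For example, SearchQuery(title="Solid-State battery") becomes
+    # {"title": "Solid-State battery", "abstract": "Solid-State battery"},
+    # the same dict shape used in SearchPaperByEmbedding/demo.py above.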
+
+    query_display = queries[0].title or queries[0].abstract or "Empty Query"
+    await ctx.info(f"Searching for the top {top_k} papers with query '{query_display}'...")
+
+    # Call the underlying search method
+    results = await asyncio.to_thread(
+        searcher.search,
+        examples=examples_to_search,
+        top_k=top_k
+    )
+
+    # --- Key fix: parse according to the structure actually returned by search.py ---
+    papers = []
+    for res in results:
+        paper_data = res.get('paper', {})        # the nested paper-info dict
+        similarity = res.get('similarity', 0.0)  # the similarity score
+
+        papers.append(
+            Paper(
+                title=paper_data.get("title", "N/A"),
+                authors=paper_data.get("authors", []),
+                abstract=paper_data.get("abstract", ""),
+                # Note: in the raw data, the PDF URL is stored under the 'forum_url' key
+                pdf_url=paper_data.get("forum_url"),
+                similarity_score=similarity  # use the correct similarity score
+            )
+        )
+
+    await ctx.debug(f"Found {len(papers)} relevant papers.")
+    return papers
+
+
+# ========= 6. Factory function (for Starlette integration) =========
+
+def create_paper_search_mcp() -> FastMCP:
+    """Factory function for mounting under Starlette."""
+    return mcp
+
+
+if __name__ == "__main__":
+    # When this file is run directly, start a standalone MCP server
+    mcp.run(transport="streamable-http")
\ No newline at end of file
diff --git a/mcp/topological_analysis_models.py b/mcp/topological_analysis_models.py
new file mode 100644
index 0000000..8c89a86
--- /dev/null
+++ b/mcp/topological_analysis_models.py
@@ -0,0 +1,185 @@
+# topological_analysis_models.py
+
+import posixpath
+import re
+from typing import Any, Literal
+from contextlib import asynccontextmanager
+from collections.abc import AsyncIterator
+
+import asyncssh
+from pydantic import BaseModel, Field
+
+# Import core components from the MCP SDK [1]
+from mcp.server.fastmcp import FastMCP, Context
+from mcp.server.session import ServerSession
+
+# Import shared configuration and dataclasses from lifespan.py [3]
+# Note: importing system_lifespan is no longer needed
+from lifespan import (
+    SharedAppContext,
+    REMOTE_HOST,
+    REMOTE_USER,
+    PRIVATE_KEY_PATH,
+    INITIAL_WORKING_DIRECTORY,
+)
+
+# ==============================================================================
+# 1. Helper functions (duplicated here to keep this module standalone)
+# ==============================================================================
+
+def shell_quote(arg: str) -> str:
+    """Safely escape a string as a single shell argument [5]."""
+    return "'" + arg.replace("'", "'\"'\"'") + "'"
+
+def _safe_join(sandbox_root: str, relative_path: str) -> str:
+    """Safely join a user-supplied relative path onto the sandbox root [5]."""
+    rel = (relative_path or ".").strip()
+    if rel.startswith("/"):
+        rel = rel.lstrip("/")
+    combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
+    root_norm = sandbox_root.rstrip("/")
+    if combined != root_norm and not combined.startswith(root_norm + "/"):
+        raise ValueError("Path escapes sandbox: only paths inside the sandbox directory are allowed")
+    if ".." in combined.split("/"):
+        raise ValueError("Illegal path: '..' traversal is not allowed")
+    return combined
+
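+# Quick sanity examples (illustrative, not executed) for the two helpers above:
+#   shell_quote("a'b")                       -> 'a'"'"'b'
+#   _safe_join("/sandbox", "runs/out.log")   -> "/sandbox/runs/out.log"
+#   _safe_join("/sandbox", "../etc/passwd")  -> raises ValueError (escapes the sandbox)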
+
+# ==============================================================================
+# 2. Standalone lifespan manager
+# ==============================================================================
+
+@asynccontextmanager
+async def analysis_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
+    """
+    Independently manage the SSH connection lifecycle for the topological analysis tools.
+    """
+    conn: asyncssh.SSHClientConnection | None = None
+    print("Analysis tools: establishing a dedicated SSH connection...")
+    try:
+        conn = await asyncssh.connect(
+            REMOTE_HOST,
+            username=REMOTE_USER,
+            client_keys=[PRIVATE_KEY_PATH],
+        )
+        print(f"Analysis tools: dedicated SSH connection to {REMOTE_HOST} established!")
+        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
+    finally:
+        if conn:
+            conn.close()
+            await conn.wait_closed()
+            print("Analysis tools: dedicated SSH connection closed.")
+
+
+# ==============================================================================
+# 3. Data models and parsing (unchanged from the previous version)
+# ==============================================================================
+
+VALID_CONFIGS = Literal["O.yaml", "S.yaml", "Cl.yaml", "Br.yaml"]
+
+class SubmitTopologicalAnalysisParams(BaseModel):
+    cif_file_path: str = Field(description="Relative path of the CIF file on the remote server.")
+    config_files: list[VALID_CONFIGS] = Field(description="One or more YAML configuration files to use for the analysis.")
+    output_file_path: str = Field(description="Relative path of the output file in which to save the results.")
+
+class AnalysisResult(BaseModel):
+    percolation_diameter_a: float | None = Field(None, description="Percolation diameter, in Å.")
+    connectivity_distance: float | None = Field(None, description="Connectivity distance ('the minium of d' in the log), in Å.")
+    max_node_length_a: float | None = Field(None, description="Maximum node length, in Å.")
+    min_cation_distance_a: float | None = Field(None, description="Minimum distance to cations, in Å.")
+    total_time_s: float | None = Field(None, description="Total computation time, in seconds.")
+    long_node_count: int | None = Field(None, description="Long node number.")
+    short_node_count: int | None = Field(None, description="Short node number.")
+    warnings: list[str] = Field(default_factory=list, description="Warnings encountered while parsing, including missing key values.")
+
+def _parse_analysis_output(log_content: str) -> AnalysisResult:
+    def find_float(pattern: str) -> float | None:
+        match = re.search(pattern, log_content)
+        return float(match.group(1)) if match else None
+    def find_int(pattern: str) -> int | None:
+        match = re.search(pattern, log_content)
+        return int(match.group(1)) if match else None
+    data = {
+        "percolation_diameter_a": find_float(r"Percolation diameter \(A\):\s*([\d\.]+)"),
+        "max_node_length_a": find_float(r"Maximum node length detected:\s*([\d\.]+)\s*A"),
+        "min_cation_distance_a": find_float(r"minimum distance to cations is\s*([\d\.]+)\s*A"),
+        "total_time_s": find_float(r"Total used time:\s*([\d\.]+)"),
+        "long_node_count": find_int(r"Long node number:\s*(\d+)"),
+        "short_node_count": find_int(r"Short node number:\s*(\d+)"),
+    }
+    # "minium" is the analysis tool's own spelling in its log output
+    conn_dist_match = re.search(r"the minium of d\s*([\d\.]+)", log_content, re.MULTILINE)
+    data["connectivity_distance"] = float(conn_dist_match.group(1)) if conn_dist_match else None
+    warnings = [k + " not found" for k, v in data.items() if v is None and k not in ["min_cation_distance_a", "long_node_count", "short_node_count"]]
+    return AnalysisResult(**data, warnings=warnings)
+
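+# Illustration (not executed): given log lines such as
+#     Percolation diameter (A): 2.31
+#     the minium of d 1.85
+#     Total used time: 42.7
+# _parse_analysis_output returns AnalysisResult(percolation_diameter_a=2.31,
+# connectivity_distance=1.85, total_time_s=42.7, ...) and lists any required
+# values it could not find in warnings. The numbers above are made-up placeholders.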
+
+
+# ==============================================================================
+# 4. MCP server factory function
+# ==============================================================================
+
+def create_topological_analysis_mcp() -> FastMCP:
+    """
+    Factory function for mounting under Starlette.
+    It now creates and uses its own standalone lifespan manager.
+    """
+    analysis_mcp = FastMCP(
+        name="Topological Analysis Tools",
+        instructions="A toolset for submitting and analyzing topological computation jobs on a remote server.",
+        streamable_http_path="/",
+        lifespan=analysis_lifespan,  # Key point: use the standalone lifespan defined in this file
+    )
+
+    @analysis_mcp.tool()
+    async def submit_topological_analysis(
+        params: SubmitTopologicalAnalysisParams,
+        ctx: Context[ServerSession, SharedAppContext],
+    ) -> dict[str, Any]:
+        """[Step 1/2] Asynchronously submit a topological analysis job."""
+        # (implementation logic unchanged from the previous version)
+        try:
+            app_ctx = ctx.request_context.lifespan_context
+            conn = app_ctx.ssh_connection
+            sandbox_root = app_ctx.sandbox_path
+            home_dir = f"/cluster/home/{REMOTE_USER}"
+            tool_dir = f"{home_dir}/tool/Topological_analyse"
+            cif_abs_path = _safe_join(sandbox_root, params.cif_file_path)
+            output_abs_path = _safe_join(sandbox_root, params.output_file_path)
+            config_args = " ".join([shell_quote(f"{tool_dir}/{cfg}") for cfg in params.config_files])
+            config_flags = f"-i {config_args}"
+            # Caution: the shell_quote()d paths are nested inside an outer
+            # single-quoted sh -c string; this only parses as intended because
+            # none of the paths contain spaces or quote characters.
+            command = f"""
+            nohup sh -c '
+            source {shell_quote(f"{home_dir}/.bashrc")} && \\
+            conda activate {shell_quote(f"{home_dir}/anaconda3/envs/zeo")} && \\
+            python {shell_quote(f"{tool_dir}/analyze_voronoi_nodes.py")} \\
+            {shell_quote(cif_abs_path)} \\
+            {config_flags}
+            ' > {shell_quote(output_abs_path)} 2>&1 &
+            """.strip()
+            await ctx.info("Submitting the background topological analysis job...")
+            await conn.run(command, check=True)
+            return {"success": True, "message": "Topological analysis job submitted successfully.", "output_file": params.output_file_path}
+        except Exception as e:
+            msg = f"Failed to submit the background job: {e}"
+            await ctx.error(msg)
+            return {"success": False, "error": msg}
+
+    @analysis_mcp.tool()
+    async def analyze_topological_results(
+        output_file_path: str,
+        ctx: Context[ServerSession, SharedAppContext],
+    ) -> AnalysisResult:
+        """[Step 2/2] Read and analyze the output file of a topological analysis job."""
+        # (implementation logic unchanged from the previous version)
+        try:
+            await ctx.info(f"Analyzing result file: {output_file_path}")
+            app_ctx = ctx.request_context.lifespan_context
+            conn = app_ctx.ssh_connection
+            sandbox_root = app_ctx.sandbox_path
+            target_path = _safe_join(sandbox_root, output_file_path)
+            async with conn.start_sftp_client() as sftp:
+                async with sftp.open(target_path, "r") as f:
+                    log_content = await f.read()
+            if not log_content:
+                raise ValueError("The result file is empty.")
+            analysis_data = _parse_analysis_output(log_content)
+            return analysis_data
+        except (FileNotFoundError, asyncssh.SFTPError):
+            # asyncssh surfaces a missing remote file as an SFTP error rather
+            # than a plain FileNotFoundError, so catch both
+            msg = f"Analysis failed: result file '{output_file_path}' does not exist or could not be read."
+            await ctx.error(msg)
+            return AnalysisResult(warnings=[msg])
+        except Exception as e:
+            msg = f"Error while analyzing the results: {e}"
+            await ctx.error(msg)
+            return AnalysisResult(warnings=[msg])
+
+    return analysis_mcp
\ No newline at end of file
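
A minimal sketch of how a client could exercise the new endpoints once main.py
is running. This assumes the official MCP Python SDK's streamable HTTP client
(streamablehttp_client and ClientSession); the URL and tool arguments below are
illustrative and should be adjusted to your deployment:

    import asyncio

    from mcp import ClientSession
    from mcp.client.streamable_http import streamablehttp_client

    async def main() -> None:
        # Connect to the paper-search server mounted at /papersearch in main.py
        async with streamablehttp_client("http://localhost:8000/papersearch/") as (read, write, _):
            async with ClientSession(read, write) as session:
                await session.initialize()
                result = await session.call_tool(
                    "search_papers",
                    {"queries": [{"title": "Solid-State battery"}], "top_k": 3},
                )
                print(result)

    asyncio.run(main())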