# paper_search_mcp.py (重构版) import os import asyncio from dataclasses import dataclass from typing import Any from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pydantic import BaseModel, Field # 假设您的 search.py 提供了 PaperSearcher 类 from SearchPaperByEmbedding.search import PaperSearcher from mcp.server.fastmcp import FastMCP, Context from mcp.server.session import ServerSession # ========= 1. 配置与常量 (修正版) ========= # 获取当前脚本文件所在的目录的绝对路径 _CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) # 基于项目根目录构造资源文件的绝对路径 DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR,"SearchPaperByEmbedding") PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json") # 打印路径用于调试,确保它是正确的 print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}") MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local") # ========= 2. 数据模型 ========= class Paper(BaseModel): """单篇论文的详细信息""" title: str authors: list[str] abstract: str pdf_url: str | None = None similarity_score: float = Field(description="与查询的相似度分数,越高越相关") class SearchQuery(BaseModel): """论文搜索的查询结构""" title: str abstract: str | None = Field(None, description="可选的摘要,以提供更丰富的上下文") # ========= 3. Lifespan (资源加载与管理) ========= @dataclass class PaperSearchContext: """在服务器生命周期内共享的、已初始化的 PaperSearcher 实例""" searcher: PaperSearcher @asynccontextmanager async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]: """ FastMCP 生命周期:在服务器启动时初始化 PaperSearcher 并计算 Embeddings。 这确保了最耗时的部分只在启动时执行一次。 """ print(f"正在使用 '{MODEL_TYPE}' 模型初始化论文搜索引擎...") # 使用 to_thread 运行同步的初始化代码 def _initialize_searcher(): # 1. 初始化 searcher = PaperSearcher( PAPERS_JSON_PATH, model_type=MODEL_TYPE ) # 2. 计算/加载 Embeddings print("正在计算或加载论文 Embeddings...") searcher.compute_embeddings() print("Embeddings 加载完成。") return searcher searcher = await asyncio.to_thread(_initialize_searcher) print("论文搜索引擎已准备就绪!") try: # 使用 yield 将创建好的共享资源上下文传递给 MCP yield PaperSearchContext(searcher=searcher) finally: print("论文搜索引擎服务关闭。") # ========= 4. MCP 服务器实例 ========= mcp = FastMCP( name="Paper Search", instructions="通过 embedding 相似度检索 ICLR 2026 论文的工具。", lifespan=paper_search_lifespan, streamable_http_path="/", ) # ========= 5. 工具实现 ========= @mcp.tool() async def search_papers( queries: list[SearchQuery], top_k: int = 3, ctx: Context[ServerSession, PaperSearchContext] | None = None, ) -> list[Paper]: """ 根据给定的一个或多个查询(包含标题和可选的摘要),使用语义相似度搜索论文。 返回最相关的 top_k 个结果。 """ if ctx is None: raise ValueError("Context is required for this operation.") if not queries: return [] app_ctx = ctx.request_context.lifespan_context searcher = app_ctx.searcher # 预处理查询,确保 title 和 abstract 至少有一个存在 examples_to_search = [] for q in queries: query_dict = q.model_dump(exclude_none=True) if "title" not in query_dict and "abstract" in query_dict: query_dict["title"] = query_dict["abstract"] elif "abstract" not in query_dict and "title" in query_dict: query_dict["abstract"] = query_dict["title"] examples_to_search.append(query_dict) if not examples_to_search: return [] query_display = queries[0].title or queries[0].abstract or "Empty Query" await ctx.info(f"正在以查询 '{query_display}' 等内容搜索排名前 {top_k} 的论文...") # 调用底层的 search 方法 results = await asyncio.to_thread( searcher.search, examples=examples_to_search, top_k=top_k ) # --- 关键修正:按照 search.py 返回的正确结构进行解析 --- papers = [] for res in results: paper_data = res.get('paper', {}) # 获取嵌套的论文信息字典 similarity = res.get('similarity', 0.0) # 获取相似度分数 papers.append( Paper( title=paper_data.get("title", "N/A"), authors=paper_data.get("authors", []), abstract=paper_data.get("abstract", ""), # 注意:原始数据中 pdf url 的键是 'forum_url' pdf_url=paper_data.get("forum_url"), similarity_score=similarity # 使用正确的相似度分数 ) ) await ctx.debug(f"找到 {len(papers)} 篇相关论文。") return papers # ========= 6. 工厂函数 (用于 Starlette 集成) ========= def create_paper_search_mcp() -> FastMCP: """供 Starlette 挂载的工厂函数。""" return mcp if __name__ == "__main__": # 如果直接运行此文件,则启动一个独立的 MCP 服务器 mcp.run(transport="streamable-http")