新增了zeo++的部分

This commit is contained in:
2025-10-23 17:08:57 +08:00
parent c0b2ec5983
commit 1f8667ae51
4 changed files with 370 additions and 2 deletions

172
mcp/paper_search_mcp.py Normal file
View File

@@ -0,0 +1,172 @@
# paper_search_mcp.py (重构版)
import os
import asyncio
from dataclasses import dataclass
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
# 假设您的 search.py 提供了 PaperSearcher 类
from SearchPaperByEmbedding.search import PaperSearcher
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# ========= 1. 配置与常量 (修正版) =========
# 获取当前脚本文件所在的目录的绝对路径
_CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# 基于项目根目录构造资源文件的绝对路径
DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR,"SearchPaperByEmbedding")
PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json")
# 打印路径用于调试,确保它是正确的
print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}")
MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local")
# ========= 2. 数据模型 =========
class Paper(BaseModel):
"""单篇论文的详细信息"""
title: str
authors: list[str]
abstract: str
pdf_url: str | None = None
similarity_score: float = Field(description="与查询的相似度分数,越高越相关")
class SearchQuery(BaseModel):
"""论文搜索的查询结构"""
title: str
abstract: str | None = Field(None, description="可选的摘要,以提供更丰富的上下文")
# ========= 3. Lifespan (资源加载与管理) =========
@dataclass
class PaperSearchContext:
"""在服务器生命周期内共享的、已初始化的 PaperSearcher 实例"""
searcher: PaperSearcher
@asynccontextmanager
async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]:
"""
FastMCP 生命周期:在服务器启动时初始化 PaperSearcher 并计算 Embeddings。
这确保了最耗时的部分只在启动时执行一次。
"""
print(f"正在使用 '{MODEL_TYPE}' 模型初始化论文搜索引擎...")
# 使用 to_thread 运行同步的初始化代码
def _initialize_searcher():
# 1. 初始化
searcher = PaperSearcher(
PAPERS_JSON_PATH,
model_type=MODEL_TYPE
)
# 2. 计算/加载 Embeddings
print("正在计算或加载论文 Embeddings...")
searcher.compute_embeddings()
print("Embeddings 加载完成。")
return searcher
searcher = await asyncio.to_thread(_initialize_searcher)
print("论文搜索引擎已准备就绪!")
try:
# 使用 yield 将创建好的共享资源上下文传递给 MCP
yield PaperSearchContext(searcher=searcher)
finally:
print("论文搜索引擎服务关闭。")
# ========= 4. MCP 服务器实例 =========
mcp = FastMCP(
name="Paper Search",
instructions="通过 embedding 相似度检索 ICLR 2026 论文的工具。",
lifespan=paper_search_lifespan,
streamable_http_path="/",
)
# ========= 5. 工具实现 =========
@mcp.tool()
async def search_papers(
queries: list[SearchQuery],
top_k: int = 3,
ctx: Context[ServerSession, PaperSearchContext] | None = None,
) -> list[Paper]:
"""
根据给定的一个或多个查询(包含标题和可选的摘要),使用语义相似度搜索论文。
返回最相关的 top_k 个结果。
"""
if ctx is None:
raise ValueError("Context is required for this operation.")
if not queries:
return []
app_ctx = ctx.request_context.lifespan_context
searcher = app_ctx.searcher
# 预处理查询,确保 title 和 abstract 至少有一个存在
examples_to_search = []
for q in queries:
query_dict = q.model_dump(exclude_none=True)
if "title" not in query_dict and "abstract" in query_dict:
query_dict["title"] = query_dict["abstract"]
elif "abstract" not in query_dict and "title" in query_dict:
query_dict["abstract"] = query_dict["title"]
examples_to_search.append(query_dict)
if not examples_to_search:
return []
query_display = queries[0].title or queries[0].abstract or "Empty Query"
await ctx.info(f"正在以查询 '{query_display}' 等内容搜索排名前 {top_k} 的论文...")
# 调用底层的 search 方法
results = await asyncio.to_thread(
searcher.search,
examples=examples_to_search,
top_k=top_k
)
# --- 关键修正:按照 search.py 返回的正确结构进行解析 ---
papers = []
for res in results:
paper_data = res.get('paper', {}) # 获取嵌套的论文信息字典
similarity = res.get('similarity', 0.0) # 获取相似度分数
papers.append(
Paper(
title=paper_data.get("title", "N/A"),
authors=paper_data.get("authors", []),
abstract=paper_data.get("abstract", ""),
# 注意:原始数据中 pdf url 的键是 'forum_url'
pdf_url=paper_data.get("forum_url"),
similarity_score=similarity # 使用正确的相似度分数
)
)
await ctx.debug(f"找到 {len(papers)} 篇相关论文。")
return papers
# ========= 6. 工厂函数 (用于 Starlette 集成) =========
def create_paper_search_mcp() -> FastMCP:
"""供 Starlette 挂载的工厂函数。"""
return mcp
if __name__ == "__main__":
# 如果直接运行此文件,则启动一个独立的 MCP 服务器
mcp.run(transport="streamable-http")