From b11bf2417b2b9f19a7f5938aff0e1be020ab7b56 Mon Sep 17 00:00:00 2001 From: koko <1429659362@qq.com> Date: Wed, 15 Oct 2025 14:43:57 +0800 Subject: [PATCH] mcp-gen-material --- mcp/main.py | 29 ++- mcp/materialproject_mcp.py | 513 +++++++++++++++++++++++++++++++++++++ 2 files changed, 533 insertions(+), 9 deletions(-) create mode 100644 mcp/materialproject_mcp.py diff --git a/mcp/main.py b/mcp/main.py index 7321ad2..1360784 100644 --- a/mcp/main.py +++ b/mcp/main.py @@ -1,31 +1,42 @@ # main.py +# 将 Materials CIF MCP 与 System Tools MCP 一起挂载到 Starlette。 +# 关键点: +# - 在 lifespan 中启动每个 MCP 的 session_manager.run()(参考 SDK README 的 Starlette 挂载示例与 streamable_http_app 用法 [1]) +# - 通过 Mount 指定各自的子路径(如 /system 与 /materials) + import contextlib from starlette.applications import Starlette from starlette.routing import Mount -from test_tools import create_test_mcp -from system_tools import create_system_mcp # 如果暂时不用可先不挂 +from system_tools import create_system_mcp +from materialproject_mcp import create_materials_mcp -# 先创建 MCP 实例 -test_mcp = create_test_mcp() +# 创建 MCP 实例 system_mcp = create_system_mcp() +materials_mcp = create_materials_mcp() -# 关键:在 Starlette 的 lifespan 中启动 MCP 的 session manager +# 在 Starlette 的 lifespan 中启动 MCP 的 session manager @contextlib.asynccontextmanager async def lifespan(app: Starlette): async with contextlib.AsyncExitStack() as stack: - # await stack.enter_async_context(test_mcp.session_manager.run()) await stack.enter_async_context(system_mcp.session_manager.run()) + await stack.enter_async_context(materials_mcp.session_manager.run()) yield # 服务器运行期间 # 退出时自动清理 +# 挂载两个 MCP 的 Streamable HTTP App app = Starlette( lifespan=lifespan, routes=[ - # Mount("/test", app=test_mcp.streamable_http_app()), Mount("/system", app=system_mcp.streamable_http_app()), + Mount("/materials", app=materials_mcp.streamable_http_app()), ], ) -# 启动代码为uvicorn main:app --host 0.0.0.0 --port 8000 -# url为http://localhost:8000/system \ No newline at end of file +# 启动命令(终端执行): +# uvicorn main:app --host 0.0.0.0 --port 8000 +# 访问: +# http://localhost:8000/system +# http://localhost:8000/materials +# +# 如果需要浏览器客户端访问(CORS 暴露 Mcp-Session-Id),请参考 README 中的 CORS 配置示例 [1] \ No newline at end of file diff --git a/mcp/materialproject_mcp.py b/mcp/materialproject_mcp.py new file mode 100644 index 0000000..9ee18aa --- /dev/null +++ b/mcp/materialproject_mcp.py @@ -0,0 +1,513 @@ +# materials_mp_cif_mcp.py +# +# 说明: +# - 这是一个基于 FastMCP 的“Materials Project CIF”服务器,提供: +# 1) search_cifs:使用 Materials Project 的 mp-api 进行检索(两步:summary 过滤 + materials 获取结构) +# 2) get_cif:按材料 ID 获取结构并导出 CIF 文本(pymatgen) +# 3) filter_cifs:在本地对候选进行硬筛选 +# 4) download_cif_bulk:批量下载 CIF,带并发与进度 +# 5) 资源:cif://mp/{id} 直接读取 CIF 文本 +# +# - 依赖:pip install "mcp[cli]" mp_api pymatgen +# - API Key: +# 默认从环境变量 MP_API_KEY 读取;若未设置则回退为你提供的 Key(仅用于演示,不建议在生产中硬编码)[3] +# - MCP/Context/Lifespan 的使用方法参考 SDK 文档(工具、结构化输出、进度等)[1] +# +# 引用: +# - MPRester/MP_API_KEY 的用法与推荐的会话管理方式见 Materials Project 文档 [3] +# - summary.search 与 materials.search 的查询与 fields 限制示例见文档 [4] +# - FastMCP 的工具、资源、Context、结构化输出见 README 示例 [1] + +import os +import asyncio +from dataclasses import dataclass +from typing import Any, Literal +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +import re + +from pydantic import BaseModel, Field +from mp_api.client import MPRester # [3] +from pymatgen.core import Structure # 用于转 CIF +from mcp.server.fastmcp import FastMCP, Context # [1] +from mcp.server.session import ServerSession # [1] + +# ========= 配置 ========= + +# 生产中请通过环境变量 MP_API_KEY 配置;此处为演示而提供回退到你给的 API Key(不建议硬编码)[3] +DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW") + +# ========= 数据模型 ========= + +class LatticeParameters(BaseModel): + a: float | None = Field(default=None, description="a 轴长度 (Å)") + b: float | None = Field(default=None, description="b 轴长度 (Å)") + c: float | None = Field(default=None, description="c 轴长度 (Å)") + alpha: float | None = Field(default=None, description="α 角 (度)") + beta: float | None = Field(default=None, description="β 角 (度)") + gamma: float | None = Field(default=None, description="γ 角 (度)") + + +class CIFCandidate(BaseModel): + id: str = Field(description="Materials Project ID,如 mp-149") + source: Literal["mp"] = Field(default="mp", description="来源库") + chemical_formula: str | None = Field(default=None, description="化学式(pretty)") + space_group: str | None = Field(default=None, description="空间群符号") + cell_parameters: LatticeParameters | None = Field(default=None, description="晶格参数") + bandgap: float | None = Field(default=None, description="带隙 (eV)") + stability_energy: float | None = Field(default=None, description="E_above_hull (eV)") + magnetic_ordering: str | None = Field(default=None, description="磁性信息") + has_cif: bool = Field(default=True, description="是否存在可导出的 CIF(只要有结构即可)") + extras: dict[str, Any] | None = Field(default=None, description="原始文档片段(健壮解析)") + + +class SearchQuery(BaseModel): + # 基础化学约束 + include_elements: list[str] | None = Field(default=None, description="包含元素集合(如 ['Si','O'])") # [4] + exclude_elements: list[str] | None = Field(default=None, description="排除元素集合(如 ['Li'])") + formula: str | None = Field(default=None, description="化学式片段(本地二次过滤)") + + # 结构与性能约束(summary 可过滤带隙与 E_above_hull)[4] + bandgap_min: float | None = None + bandgap_max: float | None = None + e_above_hull_max: float | None = None + + # 空间群、晶格范围(在本地二次过滤) + space_group: str | None = None + cell_a_min: float | None = None + cell_a_max: float | None = None + cell_b_min: float | None = None + cell_b_max: float | None = None + cell_c_min: float | None = None + cell_c_max: float | None = None + + magnetic_ordering: str | None = None # 本地二次过滤 + + page: int = Field(default=1, ge=1, description="页码(仅本地分页;MP 端点请自行扩展)") + per_page: int = Field(default=50, ge=1, le=200, description="每页数量(本地分页)") + sort_by: str | None = Field(default=None, description="排序字段(bandgap/e_above_hull等,本地排序)") + sort_order: Literal["asc", "desc"] | None = Field(default=None, description="排序方向") + + +class CIFText(BaseModel): + id: str + source: Literal["mp"] + text: str = Field(description="CIF 文件文本内容") + metadata: dict[str, Any] | None = Field(default=None, description="附加元数据(如结构来源)") + + +class FilterCriteria(BaseModel): + include_elements: list[str] | None = None + exclude_elements: list[str] | None = None + formula: str | None = None + space_group: str | None = None + cell_a_min: float | None = None + cell_a_max: float | None = None + cell_b_min: float | None = None + cell_b_max: float | None = None + cell_c_min: float | None = None + cell_c_max: float | None = None + bandgap_min: float | None = None + bandgap_max: float | None = None + e_above_hull_max: float | None = None + magnetic_ordering: str | None = None + + +class BulkDownloadItem(BaseModel): + id: str + source: Literal["mp"] = "mp" + success: bool + text: str | None = None + error: str | None = None + + +class BulkDownloadResult(BaseModel): + items: list[BulkDownloadItem] + total: int + succeeded: int + failed: int + + +# ========= Lifespan ========= + +@dataclass +class AppContext: + mpr: MPRester + api_key: str + + +@asynccontextmanager +async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]: + # 推荐使用 MP_API_KEY 环境变量;此处回退到提供的 Key 字符串 [3] + api_key = DEFAULT_MP_API_KEY + mpr = MPRester(api_key) # 会话管理在 mp-api 内部 [3] + try: + yield AppContext(mpr=mpr, api_key=api_key) + finally: + # mp-api 未强制要求手动关闭;若未来提供关闭方法,可在此清理 [3] + pass + + +# ========= 服务器 ========= + +mcp = FastMCP( + name="Materials Project CIF", + instructions="通过 Materials Project API 检索并导出 CIF 文件。", + lifespan=lifespan, + streamable_http_path="/", # 便于在 Starlette 下挂载到子路径 [1] +) + + +# ========= 工具实现 ========= + +def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool: + if text is None: + return not required and not excluded + s = text.upper() + if required: + for el in required: + if el.upper() not in s: + return False + if excluded: + for el in excluded: + if el.upper() in s: + return False + return True + + +def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None: + if struct is None: + return None + lat = struct.lattice + return LatticeParameters( + a=float(lat.a), + b=float(lat.b), + c=float(lat.c), + alpha=float(lat.alpha), + beta=float(lat.beta), + gamma=float(lat.gamma), + ) + + +def _sg_symbol_from_doc(doc: Any) -> str | None: + """ + 尝试从 materials.search 返回的文档中提取空间群符号。 + MP 文档中 symmetry 通常包含 spacegroup 信息(不同版本字段路径可能变化)[4] + """ + try: + sym = getattr(doc, "symmetry", None) + if sym and getattr(sym, "spacegroup", None): + sg = sym.spacegroup + # 常见字段:symbol 或 international + return getattr(sg, "symbol", None) or getattr(sg, "international", None) + except Exception: + pass + return None + + +async def _to_thread(fn, *args, **kwargs): + return await asyncio.to_thread(fn, *args, **kwargs) + + +def _parse_elements_from_formula(formula: str) -> list[str]: + """从化学式中解析元素符号(简易版)。""" + if not formula: + return [] + elems = re.findall(r"[A-Z][a-z]?", formula) + seen, ordered = set(), [] + for e in elems: + if e not in seen: + seen.add(e) + ordered.append(e) + return ordered + +@mcp.tool() +async def search_cifs( + query: SearchQuery, + ctx: Context[ServerSession, AppContext], +) -> list[CIFCandidate]: + """ + 通过 summary 做初筛(限制返回数量),再用 materials 拉取当前页的结构与对称性。 + 修复点: + - 仅请求当前页(chunk_size/per_page + num_chunks=1),避免全库下载超时 + - 绝不在返回值中包含 mp-api 文档对象或 Structure,extras 只放可序列化轻量字段 + """ + app = ctx.request_context.lifespan_context + mpr = app.mpr + + # 将 formula 自动转为元素集合(若用户未显式提供) + include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "") + + # 至少要有一个约束,避免全库扫描 + if not include_elements and not query.exclude_elements and query.bandgap_min is None \ + and query.bandgap_max is None and query.e_above_hull_max is None: + raise ValueError("请至少提供一个过滤条件,例如 include_elements、bandgap 或 energy_above_hull 上限,以避免超时。") + + # 1) summary 初筛:只取“当前页”的文档(轻字段) + summary_kwargs: dict[str, Any] = {} + if include_elements: + summary_kwargs["elements"] = include_elements # MP 支持按元素集合检索 + if query.exclude_elements: + summary_kwargs["exclude_elements"] = query.exclude_elements + if query.bandgap_min is not None or query.bandgap_max is not None: + summary_kwargs["band_gap"] = ( + float(query.bandgap_min) if query.bandgap_min is not None else None, + float(query.bandgap_max) if query.bandgap_max is not None else None, + ) + if query.e_above_hull_max is not None: + summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max)) + + fields_summary = [ + "material_id", + "formula_pretty", + "band_gap", + "energy_above_hull", + "ordering", + ] + + # 关键:限制 summary 返回规模 + try: + summary_docs = await asyncio.to_thread( + mpr.materials.summary.search, + fields=fields_summary, + chunk_size=query.per_page, # 每块(页)大小 + num_chunks=1, # 只取一块 => 只取当前页数量 + **summary_kwargs, + ) + except TypeError: + # 如果你的 mp-api 版本不支持上述参数,提示升级,避免一次性抓回十几万条导致超时 + raise RuntimeError( + "当前 mp-api 版本不支持 chunk_size/num_chunks,请升级 mp-api(pip install -U mp_api)。" + ) + + if not summary_docs: + await ctx.debug("summary 检索结果为空") + return [] + + # 本地排序(如需) + results_all: list[CIFCandidate] = [] + for sdoc in summary_docs: + e_hull = getattr(sdoc, "energy_above_hull", None) + bandgap = getattr(sdoc, "band_gap", None) + ordering = getattr(sdoc, "ordering", None) + + # 暂不放 structure/doc 到 extras,避免序列化错误 + results_all.append( + CIFCandidate( + id=sdoc.material_id, + source="mp", + chemical_formula=getattr(sdoc, "formula_pretty", None), + space_group=None, + cell_parameters=None, + bandgap=float(bandgap) if bandgap is not None else None, + stability_energy=float(e_hull) if e_hull is not None else None, + magnetic_ordering=ordering, + has_cif=True, + extras={ + "summary": { + "material_id": sdoc.material_id, + "formula_pretty": getattr(sdoc, "formula_pretty", None), + "band_gap": float(bandgap) if bandgap is not None else None, + "energy_above_hull": float(e_hull) if e_hull is not None else None, + "ordering": ordering, + } + }, + ) + ) + + # 本地分页:summary 已限制返回条数,这里只做安全截断和排序 + if query.sort_by: + reverse = query.sort_order == "desc" + key_map = { + "bandgap": lambda c: (c.bandgap if c.bandgap is not None else float("inf")), + "e_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")), + "energy_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")), + } + if query.sort_by in key_map: + results_all.sort(key=key_map[query.sort_by], reverse=reverse) + + page_slice = results_all[: query.per_page] + page_ids = [c.id for c in page_slice] + + # 2) 只为当前页材料获取结构与对称性 + if not page_ids: + return [] + + fields_materials = ["material_id", "structure", "symmetry"] + materials_docs = await asyncio.to_thread( + mpr.materials.search, + material_ids=page_ids, + fields=fields_materials, + ) + mat_by_id = {doc.material_id: doc for doc in materials_docs} + + # 合并结构/空间群,仍然不返回不可序列化对象 + merged: list[CIFCandidate] = [] + for c in page_slice: + mdoc = mat_by_id.get(c.id) + struct = getattr(mdoc, "structure", None) if mdoc else None + + # 空间群符号 + sg = None + try: + sym = getattr(mdoc, "symmetry", None) + if sym and getattr(sym, "spacegroup", None): + sgobj = sym.spacegroup + sg = getattr(sgobj, "symbol", None) or getattr(sgobj, "international", None) + except Exception: + pass + + c.space_group = sg + c.cell_parameters = _lattice_from_structure(struct) + c.has_cif = struct is not None + # 保持 extras 为轻量 JSON + c.extras = { + **(c.extras or {}), + "materials": {"has_structure": struct is not None}, + } + merged.append(c) + + await ctx.debug(f"返回 {len(merged)} 条(已限制每页 {query.per_page} 条)") + return merged + +@mcp.tool() +async def get_cif( + id: str, + ctx: Context[ServerSession, AppContext], +) -> CIFText: + """ + 使用 materials.search 获取结构并导出 CIF 文本。 + - fields=["structure"] 以提升速度 [4] + - 结构到 CIF 通过 pymatgen 的 Structure.to(fmt="cif") + """ + app = ctx.request_context.lifespan_context + mpr = app.mpr + + await ctx.info(f"获取结构并导出 CIF:{id}") + docs = await _to_thread(mpr.materials.search, material_ids=[id], fields=["material_id", "structure"]) # [4] + if not docs: + raise ValueError(f"未找到材料:{id}") + + doc = docs[0] + struct: Structure | None = getattr(doc, "structure", None) + if struct is None: + raise ValueError(f"材料 {id} 无结构数据,无法导出 CIF") + + cif_text = struct.to(fmt="cif") # 直接输出 CIF 文本 + return CIFText(id=id, source="mp", text=cif_text, metadata={"from": "materials.search"}) + + +@mcp.tool() +async def filter_cifs( + candidates: list[CIFCandidate], + criteria: FilterCriteria, + ctx: Context[ServerSession, AppContext], +) -> list[CIFCandidate]: + def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool: + if v is None: + return vmin is None and vmax is None + if vmin is not None and v < vmin: + return False + if vmax is not None and v > vmax: + return False + return True + + filtered: list[CIFCandidate] = [] + for c in candidates: + # 元素集合近似匹配(基于化学式字符串) + if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements): + continue + # 化学式包含判断 + if criteria.formula: + if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper(): + continue + # 空间群 + if criteria.space_group: + spg = (c.space_group or "").lower() + if criteria.space_group.lower() not in spg: + continue + # 晶格范围 + cp = c.cell_parameters or LatticeParameters() + if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max): + continue + if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max): + continue + if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max): + continue + # 带隙与稳定性 + if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max): + continue + if criteria.e_above_hull_max is not None: + if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max: + continue + # 磁性 + if criteria.magnetic_ordering: + mo = "" # 当前未填充,保留兼容 + if criteria.magnetic_ordering.upper() not in mo.upper(): + continue + + filtered.append(c) + + await ctx.debug(f"筛选后:{len(filtered)} / 原始 {len(candidates)}") + return filtered + + +@mcp.tool() +async def download_cif_bulk( + ids: list[str], + concurrency: int = 4, + ctx: Context[ServerSession, AppContext] | None = None, +) -> BulkDownloadResult: + semaphore = asyncio.Semaphore(max(1, concurrency)) + items: list[BulkDownloadItem] = [] + + async def worker(mid: str, idx: int) -> None: + nonlocal items + try: + async with semaphore: + # 复用当前 ctx + assert ctx is not None + cif = await get_cif(id=mid, ctx=ctx) + items.append(BulkDownloadItem(id=mid, success=True, text=cif.text)) + except Exception as e: + items.append(BulkDownloadItem(id=mid, success=False, error=str(e))) + if ctx is not None: + progress = (idx + 1) / len(ids) if len(ids) > 0 else 1.0 + await ctx.report_progress(progress=progress, total=1.0, message=f"{idx + 1}/{len(ids)}") # [1] + + tasks = [worker(mid, i) for i, mid in enumerate(ids)] + await asyncio.gather(*tasks) + + succeeded = sum(1 for it in items if it.success) + failed = len(items) - succeeded + return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=failed) + + +# ========= 资源:cif://mp/{id} ========= + +@mcp.resource("cif://mp/{id}") +async def cif_resource(id: str) -> str: + """ + 基于 URI 的 CIF 内容读取。资源函数无法直接拿到 Context, + 这里简单地新的 MPRester 会话读取结构并转 CIF(仅资源路径用途)。 + """ + # 资源中使用一次性会话(不复用 lifespan),适合偶发读取 [1] + mpr = MPRester(DEFAULT_MP_API_KEY) # [3] + try: + docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"]) # [4] + if not docs or getattr(docs[0], "structure", None) is None: + raise ValueError(f"材料 {id} 无结构数据") + struct: Structure = docs[0].structure + return struct.to(fmt="cif") + finally: + pass + + +def create_materials_mcp() -> FastMCP: + """供 Starlette 挂载的工厂函数。""" + return mcp + + +if __name__ == "__main__": + # 独立运行(非挂载模式),使用 Streamable HTTP 传输 [1] + mcp.run(transport="streamable-http") \ No newline at end of file