mcp-gen-material

This commit is contained in:
2025-10-15 14:43:57 +08:00
parent 9b0f835575
commit b11bf2417b
2 changed files with 533 additions and 9 deletions

View File

@@ -1,31 +1,42 @@
# main.py # main.py
# 将 Materials CIF MCP 与 System Tools MCP 一起挂载到 Starlette。
# 关键点:
# - 在 lifespan 中启动每个 MCP 的 session_manager.run()(参考 SDK README 的 Starlette 挂载示例与 streamable_http_app 用法 [1]
# - 通过 Mount 指定各自的子路径(如 /system 与 /materials
import contextlib import contextlib
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount from starlette.routing import Mount
from test_tools import create_test_mcp from system_tools import create_system_mcp
from system_tools import create_system_mcp # 如果暂时不用可先不挂 from materialproject_mcp import create_materials_mcp
# 创建 MCP 实例 # 创建 MCP 实例
test_mcp = create_test_mcp()
system_mcp = create_system_mcp() system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()
# 关键:在 Starlette 的 lifespan 中启动 MCP 的 session manager # 在 Starlette 的 lifespan 中启动 MCP 的 session manager
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def lifespan(app: Starlette): async def lifespan(app: Starlette):
async with contextlib.AsyncExitStack() as stack: async with contextlib.AsyncExitStack() as stack:
# await stack.enter_async_context(test_mcp.session_manager.run())
await stack.enter_async_context(system_mcp.session_manager.run()) await stack.enter_async_context(system_mcp.session_manager.run())
await stack.enter_async_context(materials_mcp.session_manager.run())
yield # 服务器运行期间 yield # 服务器运行期间
# 退出时自动清理 # 退出时自动清理
# 挂载两个 MCP 的 Streamable HTTP App
app = Starlette( app = Starlette(
lifespan=lifespan, lifespan=lifespan,
routes=[ routes=[
# Mount("/test", app=test_mcp.streamable_http_app()),
Mount("/system", app=system_mcp.streamable_http_app()), Mount("/system", app=system_mcp.streamable_http_app()),
Mount("/materials", app=materials_mcp.streamable_http_app()),
], ],
) )
# 启动代码为uvicorn main:app --host 0.0.0.0 --port 8000 # 启动命令(终端执行):
# url为http://localhost:8000/system # uvicorn main:app --host 0.0.0.0 --port 8000
# 访问:
# http://localhost:8000/system
# http://localhost:8000/materials
#
# 如果需要浏览器客户端访问CORS 暴露 Mcp-Session-Id请参考 README 中的 CORS 配置示例 [1]

513
mcp/materialproject_mcp.py Normal file
View File

@@ -0,0 +1,513 @@
# materialproject_mcp.py
#
# 说明:
# - 这是一个基于 FastMCP 的“Materials Project CIF”服务器提供
# 1) search_cifs使用 Materials Project 的 mp-api 进行检索两步summary 过滤 + materials 获取结构)
# 2) get_cif按材料 ID 获取结构并导出 CIF 文本pymatgen
# 3) filter_cifs在本地对候选进行硬筛选
# 4) download_cif_bulk批量下载 CIF带并发与进度
# 5) 资源cif://mp/{id} 直接读取 CIF 文本
#
# - 依赖pip install "mcp[cli]" mp_api pymatgen
# - API Key
# 默认从环境变量 MP_API_KEY 读取;若未设置则回退为你提供的 Key仅用于演示不建议在生产中硬编码[3]
# - MCP/Context/Lifespan 的使用方法参考 SDK 文档(工具、结构化输出、进度等)[1]
#
# 引用:
# - MPRester/MP_API_KEY 的用法与推荐的会话管理方式见 Materials Project 文档 [3]
# - summary.search 与 materials.search 的查询与 fields 限制示例见文档 [4]
# - FastMCP 的工具、资源、Context、结构化输出见 README 示例 [1]
import os
import asyncio
from dataclasses import dataclass
from typing import Any, Literal
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import re
from pydantic import BaseModel, Field
from mp_api.client import MPRester # [3]
from pymatgen.core import Structure # 用于转 CIF
from mcp.server.fastmcp import FastMCP, Context # [1]
from mcp.server.session import ServerSession # [1]
# ========= Configuration =========
# The key is read from the MP_API_KEY environment variable; the second argument
# to os.getenv is a demo fallback that is used whenever that variable is unset.
# NOTE(review): a credential is hard-coded in source here. Rotate it and drop
# the fallback before this module is deployed or published.
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW")
# ========= 数据模型 =========
class LatticeParameters(BaseModel):
    # Unit-cell parameters of a crystal structure. Every field is optional and
    # defaults to None. Per the field descriptions, a/b/c are lengths in Å and
    # alpha/beta/gamma are angles in degrees.
    a: float | None = Field(default=None, description="a 轴长度 (Å)")
    b: float | None = Field(default=None, description="b 轴长度 (Å)")
    c: float | None = Field(default=None, description="c 轴长度 (Å)")
    alpha: float | None = Field(default=None, description="α 角 (度)")
    beta: float | None = Field(default=None, description="β 角 (度)")
    gamma: float | None = Field(default=None, description="γ 角 (度)")
class CIFCandidate(BaseModel):
    # One search hit: a Materials Project entry plus the lightweight properties
    # that the search and filter tools in this module read and write.
    id: str = Field(description="Materials Project ID如 mp-149")
    source: Literal["mp"] = Field(default="mp", description="来源库")
    chemical_formula: str | None = Field(default=None, description="化学式pretty")
    space_group: str | None = Field(default=None, description="空间群符号")
    cell_parameters: LatticeParameters | None = Field(default=None, description="晶格参数")
    bandgap: float | None = Field(default=None, description="带隙 (eV)")
    # Holds the energy_above_hull value copied from the summary document.
    stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
    magnetic_ordering: str | None = Field(default=None, description="磁性信息")
    has_cif: bool = Field(default=True, description="是否存在可导出的 CIF只要有结构即可")
    # Free-form dictionary of extra data; the search tool stores only plain values here.
    extras: dict[str, Any] | None = Field(default=None, description="原始文档片段(健壮解析)")
class SearchQuery(BaseModel):
    # Input model for the search tool.
    # NOTE(review): within this file search_cifs reads include_elements,
    # exclude_elements, formula, bandgap_min/max, e_above_hull_max, per_page,
    # sort_by and sort_order. It does not read page, space_group, the cell_*
    # ranges or magnetic_ordering, so those currently have no effect on a search.

    # Basic chemistry constraints.
    include_elements: list[str] | None = Field(default=None, description="包含元素集合(如 ['Si','O']")  # [4]
    exclude_elements: list[str] | None = Field(default=None, description="排除元素集合(如 ['Li']")
    formula: str | None = Field(default=None, description="化学式片段(本地二次过滤)")
    # Property constraints that are forwarded to the summary search
    # (band gap range and an upper bound on energy above hull). [4]
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    # Space group and lattice ranges, intended for a local second-pass filter.
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    magnetic_ordering: str | None = None  # intended for a local second-pass filter
    # Paging and ordering. per_page is validated to lie in 1..200 and page to be >= 1.
    page: int = Field(default=1, ge=1, description="页码仅本地分页MP 端点请自行扩展)")
    per_page: int = Field(default=50, ge=1, le=200, description="每页数量(本地分页)")
    sort_by: str | None = Field(default=None, description="排序字段bandgap/e_above_hull等本地排序")
    sort_order: Literal["asc", "desc"] | None = Field(default=None, description="排序方向")
class CIFText(BaseModel):
    # CIF export of a single material: its ID, the source database tag, the CIF
    # file content as text, and optional metadata about where it came from.
    id: str
    source: Literal["mp"]
    text: str = Field(description="CIF 文件文本内容")
    metadata: dict[str, Any] | None = Field(default=None, description="附加元数据(如结构来源)")
class FilterCriteria(BaseModel):
    # Criteria for the local filter tool. Every field defaults to None, and a
    # field left as None disables the corresponding check in filter_cifs.
    include_elements: list[str] | None = None
    exclude_elements: list[str] | None = None
    formula: str | None = None
    space_group: str | None = None
    # Inclusive lower/upper bounds on the lattice lengths a, b and c.
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    # Inclusive band-gap range and an upper bound on energy above hull.
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    magnetic_ordering: str | None = None
class BulkDownloadItem(BaseModel):
    # Outcome of one download in a bulk request. The bulk tool fills `text` on
    # success and `error` on failure; the other one stays None.
    id: str
    source: Literal["mp"] = "mp"
    success: bool
    text: str | None = None
    error: str | None = None
class BulkDownloadResult(BaseModel):
    # Aggregate result of a bulk download: the per-item outcomes plus counters.
    # The bulk tool sets total to len(items) and failed to total - succeeded.
    items: list[BulkDownloadItem]
    total: int
    succeeded: int
    failed: int
# ========= Lifespan =========
@dataclass
class AppContext:
    # State yielded by the server lifespan and read by tools through
    # ctx.request_context.lifespan_context.
    mpr: MPRester  # shared Materials Project API client
    api_key: str  # the key that client was constructed with
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Provide a shared Materials Project client for the lifetime of the server.

    Args:
        _server: The FastMCP instance this lifespan is attached to (unused).

    Yields:
        An AppContext holding the shared MPRester client and the API key it
        was built with.
    """
    # DEFAULT_MP_API_KEY already resolves the MP_API_KEY environment variable
    # (or its fallback) at import time.
    api_key = DEFAULT_MP_API_KEY
    # MPRester is a context manager: leaving the `with` block releases its HTTP
    # session on shutdown, including when the server exits through an error.
    with MPRester(api_key) as mpr:
        yield AppContext(mpr=mpr, api_key=api_key)
# ========= 服务器 =========
# Module-level FastMCP server; the tools and the resource below register on it
# through its decorators.
mcp = FastMCP(
    name="Materials Project CIF",
    instructions="通过 Materials Project API 检索并导出 CIF 文件。",
    lifespan=lifespan,
    streamable_http_path="/",  # makes it convenient to mount under a Starlette sub-path [1]
)
# ========= 工具实现 =========
def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
if text is None:
return not required and not excluded
s = text.upper()
if required:
for el in required:
if el.upper() not in s:
return False
if excluded:
for el in excluded:
if el.upper() in s:
return False
return True
def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
    """Copy the six lattice parameters of a structure into a LatticeParameters model.

    Args:
        struct: The structure to read, or None.

    Returns:
        A LatticeParameters built from ``struct.lattice`` with every value
        converted to float, or None when no structure is given.
    """
    if struct is None:
        return None
    lattice = struct.lattice
    names = ("a", "b", "c", "alpha", "beta", "gamma")
    return LatticeParameters(**{name: float(getattr(lattice, name)) for name in names})
def _sg_symbol_from_doc(doc: Any) -> str | None:
"""
尝试从 materials.search 返回的文档中提取空间群符号。
MP 文档中 symmetry 通常包含 spacegroup 信息(不同版本字段路径可能变化)[4]
"""
try:
sym = getattr(doc, "symmetry", None)
if sym and getattr(sym, "spacegroup", None):
sg = sym.spacegroup
# 常见字段symbol 或 international
return getattr(sg, "symbol", None) or getattr(sg, "international", None)
except Exception:
pass
return None
async def _to_thread(fn, *args, **kwargs):
    """Run the blocking callable ``fn`` in a worker thread and return its result.

    Positional and keyword arguments are forwarded to ``fn`` unchanged.
    """
    outcome = await asyncio.to_thread(fn, *args, **kwargs)
    return outcome
def _parse_elements_from_formula(formula: str) -> list[str]:
    """Return the distinct element symbols of a formula in first-seen order.

    A symbol is one upper-case letter optionally followed by one lower-case
    letter; digits and any other characters are skipped. An empty formula
    gives an empty list.
    """
    if not formula:
        return []
    symbols = re.findall(r"[A-Z][a-z]?", formula)
    # dict preserves insertion order, so this drops repeats and keeps the first occurrence.
    return list(dict.fromkeys(symbols))
@mcp.tool()
async def search_cifs(
    query: SearchQuery,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """Search Materials Project and return one page of CIF candidates.

    Two requests are made. A ``summary`` search applies the element, band-gap
    and energy-above-hull filters and is capped at ``query.per_page`` documents.
    A ``materials`` search then fetches structure and symmetry for just those
    IDs. Only plain, JSON-serialisable values are placed in the returned models.

    Args:
        query: Search constraints. At least one of include_elements (or a
            formula from which elements can be parsed), exclude_elements,
            bandgap_min/max or e_above_hull_max must be set.
        ctx: Request context giving access to the shared MPRester client.

    Returns:
        Up to ``query.per_page`` candidates, optionally sorted by ``sort_by``,
        with space group, lattice parameters and ``has_cif`` filled in from
        the materials documents.

    Raises:
        ValueError: If no server-side filter was supplied.
        RuntimeError: If the summary search raises TypeError, which is taken to
            mean the installed mp-api rejects ``chunk_size``/``num_chunks``.
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr

    # Derive the element list from the formula when none was given explicitly.
    include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")

    # Refuse unconstrained queries so the whole database is never scanned.
    has_constraint = (
        bool(include_elements)
        or bool(query.exclude_elements)
        or query.bandgap_min is not None
        or query.bandgap_max is not None
        or query.e_above_hull_max is not None
    )
    if not has_constraint:
        raise ValueError("请至少提供一个过滤条件,例如 include_elements、bandgap 或 energy_above_hull 上限,以避免超时。")

    # 1) Summary pre-filter: lightweight fields only, capped to a single chunk.
    summary_kwargs: dict[str, Any] = {}
    if include_elements:
        summary_kwargs["elements"] = include_elements
    if query.exclude_elements:
        summary_kwargs["exclude_elements"] = query.exclude_elements
    if query.bandgap_min is not None or query.bandgap_max is not None:
        summary_kwargs["band_gap"] = (
            float(query.bandgap_min) if query.bandgap_min is not None else None,
            float(query.bandgap_max) if query.bandgap_max is not None else None,
        )
    if query.e_above_hull_max is not None:
        summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))

    fields_summary = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull",
        "ordering",
    ]
    # NOTE(review): query.page is not used; the first chunk is always returned.
    # query.space_group, the cell_* ranges and query.magnetic_ordering are not
    # applied here either.
    try:
        summary_docs = await asyncio.to_thread(
            mpr.materials.summary.search,
            fields=fields_summary,
            chunk_size=query.per_page,
            num_chunks=1,
            **summary_kwargs,
        )
    except TypeError as err:
        # Keep the original TypeError as the cause so the real failure stays visible.
        raise RuntimeError(
            "当前 mp-api 版本不支持 chunk_size/num_chunks请升级 mp-apipip install -U mp_api"
        ) from err

    if not summary_docs:
        await ctx.debug("summary 检索结果为空")
        return []

    results_all: list[CIFCandidate] = []
    for sdoc in summary_docs:
        # Normalise the ID to a plain str: it is returned to the client and is
        # also the key used to join against the materials documents below.
        material_id = str(sdoc.material_id)
        formula_pretty = getattr(sdoc, "formula_pretty", None)
        raw_e_hull = getattr(sdoc, "energy_above_hull", None)
        raw_bandgap = getattr(sdoc, "band_gap", None)
        e_hull = float(raw_e_hull) if raw_e_hull is not None else None
        bandgap = float(raw_bandgap) if raw_bandgap is not None else None
        ordering = getattr(sdoc, "ordering", None)
        if ordering is not None and not isinstance(ordering, str):
            # Presumably an enum-like value in some client versions — TODO confirm.
            ordering = str(getattr(ordering, "value", ordering))

        results_all.append(
            CIFCandidate(
                id=material_id,
                source="mp",
                chemical_formula=formula_pretty,
                space_group=None,
                cell_parameters=None,
                bandgap=bandgap,
                stability_energy=e_hull,
                magnetic_ordering=ordering,
                has_cif=True,
                extras={
                    "summary": {
                        "material_id": material_id,
                        "formula_pretty": formula_pretty,
                        "band_gap": bandgap,
                        "energy_above_hull": e_hull,
                        "ordering": ordering,
                    }
                },
            )
        )

    # Local sort. Candidates missing the sort value go last in both directions;
    # an unknown sort_by leaves the order untouched.
    if query.sort_by:
        reverse = query.sort_order == "desc"
        missing = float("-inf") if reverse else float("inf")
        key_map = {
            "bandgap": lambda c: c.bandgap if c.bandgap is not None else missing,
            "e_above_hull": lambda c: c.stability_energy if c.stability_energy is not None else missing,
            "energy_above_hull": lambda c: c.stability_energy if c.stability_energy is not None else missing,
        }
        if query.sort_by in key_map:
            results_all.sort(key=key_map[query.sort_by], reverse=reverse)

    # Safety truncation; the summary request already limits the count.
    page_slice = results_all[: query.per_page]
    page_ids = [c.id for c in page_slice]
    if not page_ids:
        return []

    # 2) Fetch structure and symmetry only for the materials on this page.
    fields_materials = ["material_id", "structure", "symmetry"]
    materials_docs = await asyncio.to_thread(
        mpr.materials.search,
        material_ids=page_ids,
        fields=fields_materials,
    )
    mat_by_id = {str(doc.material_id): doc for doc in materials_docs}

    merged: list[CIFCandidate] = []
    for c in page_slice:
        mdoc = mat_by_id.get(c.id)
        struct = getattr(mdoc, "structure", None) if mdoc is not None else None

        sg = _sg_symbol_from_doc(mdoc)
        if sg is None:
            # Also accept a symbol stored directly on the symmetry object
            # rather than under symmetry.spacegroup.
            sg = getattr(getattr(mdoc, "symmetry", None), "symbol", None)

        c.space_group = sg
        c.cell_parameters = _lattice_from_structure(struct)
        c.has_cif = struct is not None
        # extras stays lightweight JSON: a flag, never the structure object.
        c.extras = {
            **(c.extras or {}),
            "materials": {"has_structure": struct is not None},
        }
        merged.append(c)

    await ctx.debug(f"返回 {len(merged)} 条(已限制每页 {query.per_page} 条)")
    return merged
@mcp.tool()
async def get_cif(
    id: str,
    ctx: Context[ServerSession, AppContext],
) -> CIFText:
    """Fetch one material's structure from Materials Project and return it as CIF text.

    The lookup requests only the ``material_id`` and ``structure`` fields and
    runs in a worker thread; the structure is serialised with
    ``Structure.to(fmt="cif")``.

    Args:
        id: Materials Project ID of the material to export.
        ctx: Request context giving access to the shared MPRester client.

    Returns:
        A CIFText carrying the ID, the CIF text and a metadata note naming the
        endpoint the structure came from.

    Raises:
        ValueError: If no document is returned for the ID, or the document has
            no structure.
    """
    client = ctx.request_context.lifespan_context.mpr
    await ctx.info(f"获取结构并导出 CIF{id}")

    matches = await _to_thread(
        client.materials.search,
        material_ids=[id],
        fields=["material_id", "structure"],
    )
    if not matches:
        raise ValueError(f"未找到材料:{id}")

    structure: Structure | None = getattr(matches[0], "structure", None)
    if structure is None:
        raise ValueError(f"材料 {id} 无结构数据,无法导出 CIF")

    return CIFText(
        id=id,
        source="mp",
        text=structure.to(fmt="cif"),
        metadata={"from": "materials.search"},
    )
@mcp.tool()
async def filter_cifs(
    candidates: list[CIFCandidate],
    criteria: FilterCriteria,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """Filter candidates locally against hard criteria.

    A candidate is kept only if it passes every criterion that is set. A
    candidate that lacks a value for a constrained property is rejected.

    Args:
        candidates: Candidates to filter, typically the output of a search.
        criteria: Constraints; fields left as None are not checked.
        ctx: Request context, used only for debug logging.

    Returns:
        The candidates that satisfy all criteria, in their original order.
    """

    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        # A missing value passes only when no bound was requested.
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        if vmax is not None and v > vmax:
            return False
        return True

    filtered: list[CIFCandidate] = []
    for c in candidates:
        # Element constraints, evaluated against the formula string.
        if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
            continue

        # Formula fragment: case-insensitive substring match.
        if criteria.formula:
            if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper():
                continue

        # Space group: case-insensitive substring match.
        if criteria.space_group:
            spg = (c.space_group or "").lower()
            if criteria.space_group.lower() not in spg:
                continue

        # Lattice length ranges.
        cp = c.cell_parameters or LatticeParameters()
        if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max):
            continue
        if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max):
            continue
        if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max):
            continue

        # Band gap and stability.
        if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max):
            continue
        if criteria.e_above_hull_max is not None:
            if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max:
                continue

        # Magnetic ordering: compare against the candidate's own value.
        # Equality rather than substring, so that "FM" does not also select "AFM".
        if criteria.magnetic_ordering:
            wanted = criteria.magnetic_ordering.strip().upper()
            actual = (c.magnetic_ordering or "").strip().upper()
            if wanted != actual:
                continue

        filtered.append(c)

    await ctx.debug(f"筛选后:{len(filtered)} / 原始 {len(candidates)}")
    return filtered
@mcp.tool()
async def download_cif_bulk(
    ids: list[str],
    concurrency: int = 4,
    ctx: Context[ServerSession, AppContext] | None = None,
) -> BulkDownloadResult:
    """Download CIF text for several materials concurrently.

    Each ID is fetched through ``get_cif``. A failure for one ID is recorded on
    that item and does not abort the batch.

    Args:
        ids: Materials Project IDs to download.
        concurrency: Maximum number of downloads in flight; values below 1 are
            treated as 1.
        ctx: Request context. Required to fetch anything; when it is None every
            item is reported as failed.

    Returns:
        A BulkDownloadResult whose ``items`` are in the same order as ``ids``,
        together with total/succeeded/failed counts.
    """
    total = len(ids)
    semaphore = asyncio.Semaphore(max(1, concurrency))
    completed = 0

    async def worker(mid: str) -> BulkDownloadItem:
        nonlocal completed
        try:
            if ctx is None:
                # An explicit error instead of `assert`, which is stripped under -O
                # and gives an empty message.
                raise RuntimeError("缺少 MCP Context,无法下载 CIF")
            async with semaphore:
                cif = await get_cif(id=mid, ctx=ctx)
            item = BulkDownloadItem(id=mid, success=True, text=cif.text)
        except Exception as exc:
            # Per-item isolation: record the failure and keep the batch going.
            item = BulkDownloadItem(id=mid, success=False, error=str(exc) or type(exc).__name__)

        if ctx is not None:
            # Count finished downloads, not input positions, so the reported
            # progress only ever increases.
            completed += 1
            await ctx.report_progress(
                progress=completed / total,
                total=1.0,
                message=f"{completed}/{total}",
            )  # [1]
        return item

    # gather() returns results in the order the awaitables were passed in,
    # so items line up with ids.
    items = list(await asyncio.gather(*(worker(mid) for mid in ids)))

    succeeded = sum(1 for it in items if it.success)
    failed = len(items) - succeeded
    return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=failed)
# ========= 资源cif://mp/{id} =========
@mcp.resource("cif://mp/{id}")
async def cif_resource(id: str) -> str:
    """Return the CIF text for a material addressed by URI.

    This function does not receive the request context, so it opens a
    short-lived MPRester client of its own instead of reusing the shared one.

    Args:
        id: Materials Project ID taken from the resource URI.

    Returns:
        The structure serialised as CIF text.

    Raises:
        ValueError: If no document is returned or the document has no structure.
    """

    def _fetch() -> str:
        # The context manager closes the one-off client once the lookup is done.
        with MPRester(DEFAULT_MP_API_KEY) as mpr:  # [3]
            docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"])  # [4]
            if not docs or getattr(docs[0], "structure", None) is None:
                raise ValueError(f"材料 {id} 无结构数据")
            struct: Structure = docs[0].structure
            return struct.to(fmt="cif")

    # The lookup is blocking network I/O; run it in a worker thread so the
    # event loop keeps serving other requests.
    return await asyncio.to_thread(_fetch)
def create_materials_mcp() -> FastMCP:
    """Return the module-level FastMCP server so a Starlette app can mount it."""
    server = mcp
    return server
if __name__ == "__main__":
    # Standalone mode (not mounted in another app): serve this MCP server over
    # the Streamable HTTP transport. [1]
    mcp.run(transport="streamable-http")