Files
solidstate-tools/mcp/materialproject_mcp.py
2025-10-15 14:43:57 +08:00

513 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# materials_mp_cif_mcp.py
#
# 说明:
# - 这是一个基于 FastMCP 的“Materials Project CIF”服务器, 提供:
#   1) search_cifs: 使用 Materials Project 的 mp-api 进行检索(两步: summary 过滤 + materials 获取结构)
#   2) get_cif: 按材料 ID 获取结构并导出 CIF 文本(pymatgen)
#   3) filter_cifs: 在本地对候选进行硬筛选
#   4) download_cif_bulk: 批量下载 CIF(带并发与进度)
#   5) 资源 cif://mp/{id}: 直接读取 CIF 文本
#
# - 依赖: pip install "mcp[cli]" mp_api pymatgen
# - API Key
# 默认从环境变量 MP_API_KEY 读取;若未设置则回退为你提供的 Key仅用于演示不建议在生产中硬编码[3]
# - MCP/Context/Lifespan 的使用方法参考 SDK 文档(工具、结构化输出、进度等)[1]
#
# 引用:
# - MPRester/MP_API_KEY 的用法与推荐的会话管理方式见 Materials Project 文档 [3]
# - summary.search 与 materials.search 的查询与 fields 限制示例见文档 [4]
# - FastMCP 的工具、资源、Context、结构化输出见 README 示例 [1]
import os
import asyncio
from dataclasses import dataclass
from typing import Any, Literal
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import re
from pydantic import BaseModel, Field
from mp_api.client import MPRester # [3]
from pymatgen.core import Structure # 用于转 CIF
from mcp.server.fastmcp import FastMCP, Context # [1]
from mcp.server.session import ServerSession # [1]
# ========= 配置 =========
# 生产中请通过环境变量 MP_API_KEY 配置;此处为演示而提供回退到你给的 API Key不建议硬编码[3]
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW")
# ========= Data models =========
class LatticeParameters(BaseModel):
    """Unit-cell lattice parameters: lengths a, b, c in Å and angles alpha, beta, gamma in degrees.

    Every field is optional and defaults to None (unknown).
    """

    a: float | None = Field(default=None, description="a 轴长度 (Å)")
    b: float | None = Field(default=None, description="b 轴长度 (Å)")
    c: float | None = Field(default=None, description="c 轴长度 (Å)")
    alpha: float | None = Field(default=None, description="α 角 (度)")
    beta: float | None = Field(default=None, description="β 角 (度)")
    gamma: float | None = Field(default=None, description="γ 角 (度)")
class CIFCandidate(BaseModel):
    """One Materials Project search hit.

    Returned by search_cifs and accepted (and returned) by filter_cifs.
    """

    id: str = Field(description="Materials Project ID如 mp-149")
    source: Literal["mp"] = Field(default="mp", description="来源库")
    chemical_formula: str | None = Field(default=None, description="化学式pretty")
    space_group: str | None = Field(default=None, description="空间群符号")
    cell_parameters: LatticeParameters | None = Field(default=None, description="晶格参数")
    bandgap: float | None = Field(default=None, description="带隙 (eV)")
    # NOTE(review): Materials Project reports energy_above_hull in eV/atom;
    # the description below says eV. Confirm and align the description.
    stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
    magnetic_ordering: str | None = Field(default=None, description="磁性信息")
    # True when a structure is available, i.e. a CIF can be exported.
    has_cif: bool = Field(default=True, description="是否存在可导出的 CIF只要有结构即可")
    # Lightweight, JSON-serializable fragments of the source documents.
    extras: dict[str, Any] | None = Field(default=None, description="原始文档片段(健壮解析)")
class SearchQuery(BaseModel):
    """Input of search_cifs.

    Element and band-gap / E_above_hull bounds are sent to the Materials Project
    summary endpoint. The fields marked "local" are meant to be evaluated in
    this process after the documents are fetched.
    """

    # Basic chemical constraints
    include_elements: list[str] | None = Field(default=None, description="包含元素集合(如 ['Si','O']")  # [4]
    exclude_elements: list[str] | None = Field(default=None, description="排除元素集合(如 ['Li']")
    formula: str | None = Field(default=None, description="化学式片段(本地二次过滤)")
    # Structural / property constraints (the summary endpoint can filter band gap and E_above_hull) [4]
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    # Space group and lattice ranges (local secondary filtering)
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    magnetic_ordering: str | None = None  # local secondary filter
    # 1-based page index and page size (page size capped at 200).
    page: int = Field(default=1, ge=1, description="页码仅本地分页MP 端点请自行扩展)")
    per_page: int = Field(default=50, ge=1, le=200, description="每页数量(本地分页)")
    # search_cifs recognizes "bandgap", "e_above_hull" and "energy_above_hull";
    # any other value leaves the order unchanged.
    sort_by: str | None = Field(default=None, description="排序字段bandgap/e_above_hull等本地排序")
    sort_order: Literal["asc", "desc"] | None = Field(default=None, description="排序方向")
class CIFText(BaseModel):
    """CIF text for one material, as returned by get_cif."""

    id: str
    source: Literal["mp"]
    text: str = Field(description="CIF 文件文本内容")
    metadata: dict[str, Any] | None = Field(default=None, description="附加元数据(如结构来源)")
class FilterCriteria(BaseModel):
    """Local filter criteria consumed by filter_cifs.

    A field left as None imposes no constraint. *_min / *_max pairs are
    inclusive bounds.
    """

    include_elements: list[str] | None = None
    exclude_elements: list[str] | None = None
    formula: str | None = None  # case-insensitive substring of the formula
    space_group: str | None = None  # case-insensitive substring of the space-group symbol
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    magnetic_ordering: str | None = None
class BulkDownloadItem(BaseModel):
    """Per-ID outcome of download_cif_bulk.

    On success `text` holds the CIF; on failure `error` holds the message.
    """

    id: str
    source: Literal["mp"] = "mp"
    success: bool
    text: str | None = None
    error: str | None = None
class BulkDownloadResult(BaseModel):
    """Aggregate result of download_cif_bulk: all items plus success/failure counts."""

    items: list[BulkDownloadItem]
    total: int  # number of items
    succeeded: int  # items with success=True
    failed: int  # total - succeeded
# ========= Lifespan =========
@dataclass
class AppContext:
    """Objects shared by every tool call for the lifetime of the server."""

    mpr: MPRester  # single Materials Project client reused by all tools
    # Key handed to MPRester; may be None when MP_API_KEY is unset, in which
    # case mp-api resolves a key itself (NOTE(review): confirm mp-api fallback).
    api_key: str | None


@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create one MPRester for the server's lifetime and release it on shutdown.

    MPRester is used as a context manager, the pattern shown in the Materials
    Project docs [3]. Its HTTP session is therefore closed when the server
    stops, or if startup fails after the client was created.

    Yields:
        AppContext holding the shared client and the key it was created with.
    """
    api_key = DEFAULT_MP_API_KEY
    with MPRester(api_key) as mpr:
        yield AppContext(mpr=mpr, api_key=api_key)
# ========= Server =========
# Module-level FastMCP server. The @mcp.tool() / @mcp.resource() decorators
# below register their functions on this instance.
mcp = FastMCP(
    name="Materials Project CIF",
    instructions="通过 Materials Project API 检索并导出 CIF 文件。",
    lifespan=lifespan,  # builds the shared AppContext (MPRester) used by the tools
    streamable_http_path="/",  # serve at the root so the app can be mounted under a sub-path in Starlette [1]
)
# ========= 工具实现 =========
def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
if text is None:
return not required and not excluded
s = text.upper()
if required:
for el in required:
if el.upper() not in s:
return False
if excluded:
for el in excluded:
if el.upper() in s:
return False
return True
def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
    """Convert a pymatgen Structure's lattice into LatticeParameters.

    Returns None when no structure is given. Otherwise returns a, b, c and
    alpha, beta, gamma, each coerced to a plain float.
    """
    if struct is not None:
        cell = struct.lattice
        params = {name: float(getattr(cell, name)) for name in ("a", "b", "c", "alpha", "beta", "gamma")}
        return LatticeParameters(**params)
    return None
def _sg_symbol_from_doc(doc: Any) -> str | None:
"""
尝试从 materials.search 返回的文档中提取空间群符号。
MP 文档中 symmetry 通常包含 spacegroup 信息(不同版本字段路径可能变化)[4]
"""
try:
sym = getattr(doc, "symmetry", None)
if sym and getattr(sym, "spacegroup", None):
sg = sym.spacegroup
# 常见字段symbol 或 international
return getattr(sg, "symbol", None) or getattr(sg, "international", None)
except Exception:
pass
return None
async def _to_thread(fn, *args, **kwargs):
return await asyncio.to_thread(fn, *args, **kwargs)
def _parse_elements_from_formula(formula: str) -> list[str]:
"""从化学式中解析元素符号(简易版)。"""
if not formula:
return []
elems = re.findall(r"[A-Z][a-z]?", formula)
seen, ordered = set(), []
for e in elems:
if e not in seen:
seen.add(e)
ordered.append(e)
return ordered
def _summary_to_candidate(sdoc: Any) -> CIFCandidate:
    """Build a JSON-safe CIFCandidate (no structure data yet) from one summary document."""
    # NOTE(review): material_id may be an emmet MPID object rather than str;
    # str() normalizes both so model validation and dict lookups agree.
    material_id = str(sdoc.material_id)
    formula = getattr(sdoc, "formula_pretty", None)
    raw_gap = getattr(sdoc, "band_gap", None)
    raw_hull = getattr(sdoc, "energy_above_hull", None)
    raw_ordering = getattr(sdoc, "ordering", None)
    bandgap = float(raw_gap) if raw_gap is not None else None
    e_hull = float(raw_hull) if raw_hull is not None else None
    # ordering may be an Enum member; keep only its plain value so the model
    # and `extras` stay serializable.
    ordering_value = getattr(raw_ordering, "value", raw_ordering)
    ordering = str(ordering_value) if ordering_value is not None else None
    return CIFCandidate(
        id=material_id,
        source="mp",
        chemical_formula=formula,
        space_group=None,
        cell_parameters=None,
        bandgap=bandgap,
        stability_energy=e_hull,
        magnetic_ordering=ordering,
        has_cif=True,
        extras={
            "summary": {
                "material_id": material_id,
                "formula_pretty": formula,
                "band_gap": bandgap,
                "energy_above_hull": e_hull,
                "ordering": ordering,
            }
        },
    )


def _passes_search_filters(c: CIFCandidate, query: SearchQuery) -> bool:
    """Apply SearchQuery's locally-evaluated filters.

    Filters: formula fragment, space group, lattice ranges, magnetic ordering.
    Semantics match filter_cifs: substring matching for formula and space
    group, and inclusive bounds for ranges. A range with any bound set rejects
    an unknown value.
    """

    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        return vmax is None or v <= vmax

    if query.formula:
        if not c.chemical_formula or query.formula.upper() not in c.chemical_formula.upper():
            return False
    if query.space_group:
        if query.space_group.lower() not in (c.space_group or "").lower():
            return False
    cp = c.cell_parameters or LatticeParameters()
    for value, vmin, vmax in (
        (cp.a, query.cell_a_min, query.cell_a_max),
        (cp.b, query.cell_b_min, query.cell_b_max),
        (cp.c, query.cell_c_min, query.cell_c_max),
    ):
        if not in_range(value, vmin, vmax):
            return False
    if query.magnetic_ordering:
        # Exact (case-insensitive) match so that "FM" does not also match "AFM".
        if (c.magnetic_ordering or "").upper() != query.magnetic_ordering.upper():
            return False
    return True


@mcp.tool()
async def search_cifs(
    query: SearchQuery,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """
    Search Materials Project in two steps and return one page of candidates.

    1. ``summary.search`` applies the server-side filters (elements, band gap,
       E_above_hull). It fetches only the documents needed to reach page
       ``query.page`` of size ``query.per_page``.
    2. ``materials.search`` fetches structure and symmetry for that page only.

    The local filters declared on SearchQuery (formula fragment, space group,
    lattice ranges, magnetic ordering) are then applied, so a page can hold
    fewer than ``per_page`` items. Sorting (``sort_by``/``sort_order``) is
    applied within the returned page. Only JSON-serializable values go into
    the returned models; mp-api documents and Structure objects never leave
    this function.

    Raises:
        ValueError: when no server-side constraint is given (would scan the whole database).
        RuntimeError: when the installed mp-api rejects chunk_size/num_chunks.
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr

    # Derive an element set from `formula` when the caller gave none explicitly.
    include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")
    # Require at least one server-side constraint to avoid a full-database scan.
    if not include_elements and not query.exclude_elements and query.bandgap_min is None \
            and query.bandgap_max is None and query.e_above_hull_max is None:
        raise ValueError("请至少提供一个过滤条件,例如 include_elements、bandgap 或 energy_above_hull 上限,以避免超时。")

    # 1) Server-side pre-filter on the summary endpoint (light fields only).
    summary_kwargs: dict[str, Any] = {}
    if include_elements:
        summary_kwargs["elements"] = include_elements
    if query.exclude_elements:
        summary_kwargs["exclude_elements"] = query.exclude_elements
    if query.bandgap_min is not None or query.bandgap_max is not None:
        summary_kwargs["band_gap"] = (
            float(query.bandgap_min) if query.bandgap_min is not None else None,
            float(query.bandgap_max) if query.bandgap_max is not None else None,
        )
    if query.e_above_hull_max is not None:
        summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))
    fields_summary = ["material_id", "formula_pretty", "band_gap", "energy_above_hull", "ordering"]

    page, per_page = query.page, query.per_page
    try:
        # NOTE(review): relies on mp-api returning up to num_chunks * chunk_size
        # documents from the start of the result set, in a stable order. Page N
        # therefore costs N chunks of download.
        summary_docs = await asyncio.to_thread(
            mpr.materials.summary.search,
            fields=fields_summary,
            chunk_size=per_page,
            num_chunks=page,
            **summary_kwargs,
        )
    except TypeError as err:
        raise RuntimeError(
            "当前 mp-api 版本不支持 chunk_size/num_chunks, 请升级 mp-api: pip install -U mp_api"
        ) from err

    start = (page - 1) * per_page
    page_docs = list(summary_docs or [])[start:start + per_page]
    if not page_docs:
        await ctx.debug("summary 检索结果为空")
        return []

    candidates = [_summary_to_candidate(sdoc) for sdoc in page_docs]

    # 2) Structure and symmetry for the current page only.
    materials_docs = await asyncio.to_thread(
        mpr.materials.search,
        material_ids=[c.id for c in candidates],
        fields=["material_id", "structure", "symmetry"],
    )
    mat_by_id = {str(doc.material_id): doc for doc in materials_docs or []}

    merged: list[CIFCandidate] = []
    for c in candidates:
        mdoc = mat_by_id.get(c.id)
        struct = getattr(mdoc, "structure", None) if mdoc is not None else None
        c.space_group = _sg_symbol_from_doc(mdoc)
        c.cell_parameters = _lattice_from_structure(struct)
        c.has_cif = struct is not None
        # Keep `extras` as lightweight JSON; never attach the Structure itself.
        c.extras = {**(c.extras or {}), "materials": {"has_structure": struct is not None}}
        if _passes_search_filters(c, query):
            merged.append(c)

    if query.sort_by:
        key_map = {
            "bandgap": lambda c: (c.bandgap if c.bandgap is not None else float("inf")),
            "e_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
            "energy_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
        }
        if query.sort_by in key_map:
            merged.sort(key=key_map[query.sort_by], reverse=query.sort_order == "desc")

    await ctx.debug(f"返回 {len(merged)} 条(已限制每页 {query.per_page} 条)")
    return merged
@mcp.tool()
async def get_cif(
    id: str,
    ctx: Context[ServerSession, AppContext],
) -> CIFText:
    """
    Fetch one material's structure from Materials Project and return it as CIF text.

    materials.search is asked only for ``material_id`` and ``structure`` to keep
    the response small [4]. The Structure is serialized with pymatgen's
    ``Structure.to(fmt="cif")``.

    Raises:
        ValueError: when the ID is not found or its document carries no structure.
    """
    rester = ctx.request_context.lifespan_context.mpr
    await ctx.info(f"获取结构并导出 CIF{id}")
    requested_fields = ["material_id", "structure"]
    found = await _to_thread(rester.materials.search, material_ids=[id], fields=requested_fields)  # [4]
    if not found:
        raise ValueError(f"未找到材料:{id}")
    structure: Structure | None = getattr(found[0], "structure", None)
    if structure is None:
        raise ValueError(f"材料 {id} 无结构数据,无法导出 CIF")
    return CIFText(
        id=id,
        source="mp",
        text=structure.to(fmt="cif"),
        metadata={"from": "materials.search"},
    )
@mcp.tool()
async def filter_cifs(
    candidates: list[CIFCandidate],
    criteria: FilterCriteria,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """
    Locally filter candidates (e.g. the output of search_cifs) by ``criteria``.

    A criterion left as None/empty is ignored. Range bounds are inclusive. A
    candidate whose value is unknown (None) is rejected by any range that has
    a bound set. Formula and space group use case-insensitive substring
    matching. Magnetic ordering uses exact case-insensitive matching, so "FM"
    does not match "AFM".

    Returns:
        The candidates that satisfy every criterion, in their original order.
    """

    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        if vmax is not None and v > vmax:
            return False
        return True

    filtered: list[CIFCandidate] = []
    for c in candidates:
        # Element constraints, parsed from the formula string.
        if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
            continue
        # Formula fragment.
        if criteria.formula:
            if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper():
                continue
        # Space group.
        if criteria.space_group:
            spg = (c.space_group or "").lower()
            if criteria.space_group.lower() not in spg:
                continue
        # Lattice ranges.
        cp = c.cell_parameters or LatticeParameters()
        if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max):
            continue
        if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max):
            continue
        if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max):
            continue
        # Band gap and stability.
        if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max):
            continue
        if criteria.e_above_hull_max is not None:
            if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max:
                continue
        # Magnetic ordering: compare against the candidate's own value.
        if criteria.magnetic_ordering:
            mo = c.magnetic_ordering or ""
            if mo.upper() != criteria.magnetic_ordering.upper():
                continue
        filtered.append(c)
    await ctx.debug(f"筛选后:{len(filtered)} / 原始 {len(candidates)}")
    return filtered
@mcp.tool()
async def download_cif_bulk(
    ids: list[str],
    concurrency: int = 4,
    ctx: Context[ServerSession, AppContext] | None = None,
) -> BulkDownloadResult:
    """
    Download CIF text for many material IDs concurrently, via get_cif.

    At most ``max(1, concurrency)`` downloads run at once. A failure for one
    ID is recorded in its BulkDownloadItem and does not abort the others.
    Items are returned in the same order as ``ids``. Progress is reported as
    the fraction of downloads finished so far.
    """
    semaphore = asyncio.Semaphore(max(1, concurrency))
    total = len(ids)
    # Pre-sized slots keep results in input order regardless of completion order.
    slots: list[BulkDownloadItem | None] = [None] * total
    completed = 0

    async def worker(idx: int, mid: str) -> None:
        nonlocal completed
        try:
            if ctx is None:
                # Explicit check instead of `assert` (asserts vanish under `python -O`).
                raise RuntimeError("缺少请求上下文 (ctx),无法获取 CIF")
            async with semaphore:
                cif = await get_cif(id=mid, ctx=ctx)
            slots[idx] = BulkDownloadItem(id=mid, success=True, text=cif.text)
        except Exception as e:
            slots[idx] = BulkDownloadItem(id=mid, success=False, error=str(e))
        # Count completions, not indices, so progress increases monotonically.
        completed += 1
        if ctx is not None:
            await ctx.report_progress(progress=completed / total, total=1.0, message=f"{completed}/{total}")  # [1]

    await asyncio.gather(*(worker(i, mid) for i, mid in enumerate(ids)))
    items = [item for item in slots if item is not None]
    succeeded = sum(1 for it in items if it.success)
    return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=len(items) - succeeded)
# ========= Resource: cif://mp/{id} =========
@mcp.resource("cif://mp/{id}")
async def cif_resource(id: str) -> str:
    """
    Return the CIF text for material ``id``, addressed by the URI cif://mp/{id}.

    Resource functions do not receive the tool Context, so a short-lived
    MPRester is created per read instead of reusing the lifespan client [1].
    The blocking mp-api call runs in a worker thread so it does not stall the
    event loop, and the client is closed via its context manager [3].

    Raises:
        ValueError: when the material has no structure data.
    """

    def _fetch_cif() -> str:
        with MPRester(DEFAULT_MP_API_KEY) as mpr:  # [3]
            docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"])  # [4]
        if not docs or getattr(docs[0], "structure", None) is None:
            raise ValueError(f"材料 {id} 无结构数据")
        struct: Structure = docs[0].structure
        return struct.to(fmt="cif")

    return await asyncio.to_thread(_fetch_cif)
def create_materials_mcp() -> FastMCP:
    """Return the module-level FastMCP server so a host app (e.g. Starlette) can mount it.

    The same instance is returned on every call; nothing new is constructed.
    """
    server: FastMCP = mcp
    return server
if __name__ == "__main__":
    # Standalone (non-mounted) mode: serve this server over the Streamable HTTP transport [1]
    mcp.run(transport="streamable-http")