513 lines
19 KiB
Python
513 lines
19 KiB
Python
# materials_mp_cif_mcp.py
|
||
#
|
||
# 说明:
|
||
# - 这是一个基于 FastMCP 的“Materials Project CIF”服务器,提供:
|
||
# 1) search_cifs:使用 Materials Project 的 mp-api 进行检索(两步:summary 过滤 + materials 获取结构)
|
||
# 2) get_cif:按材料 ID 获取结构并导出 CIF 文本(pymatgen)
|
||
# 3) filter_cifs:在本地对候选进行硬筛选
|
||
# 4) download_cif_bulk:批量下载 CIF,带并发与进度
|
||
# 5) 资源:cif://mp/{id} 直接读取 CIF 文本
|
||
#
|
||
# - 依赖:pip install "mcp[cli]" mp_api pymatgen
|
||
# - API Key:
|
||
# 默认从环境变量 MP_API_KEY 读取;若未设置则回退为你提供的 Key(仅用于演示,不建议在生产中硬编码)[3]
|
||
# - MCP/Context/Lifespan 的使用方法参考 SDK 文档(工具、结构化输出、进度等)[1]
|
||
#
|
||
# 引用:
|
||
# - MPRester/MP_API_KEY 的用法与推荐的会话管理方式见 Materials Project 文档 [3]
|
||
# - summary.search 与 materials.search 的查询与 fields 限制示例见文档 [4]
|
||
# - FastMCP 的工具、资源、Context、结构化输出见 README 示例 [1]
|
||
|
||
import os
|
||
import asyncio
|
||
from dataclasses import dataclass
|
||
from typing import Any, Literal
|
||
from collections.abc import AsyncIterator
|
||
from contextlib import asynccontextmanager
|
||
import re
|
||
|
||
from pydantic import BaseModel, Field
|
||
from mp_api.client import MPRester # [3]
|
||
from pymatgen.core import Structure # 用于转 CIF
|
||
from mcp.server.fastmcp import FastMCP, Context # [1]
|
||
from mcp.server.session import ServerSession # [1]
|
||
|
||
# ========= 配置 =========
|
||
|
||
# Materials Project API key, read from the MP_API_KEY environment variable.
# SECURITY NOTE(review): when MP_API_KEY is unset, the literal key below is used
# as a demo fallback. A key committed to source control must be treated as
# leaked -- revoke it and configure MP_API_KEY instead before any shared or
# production deployment.
_FALLBACK_MP_API_KEY = "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW"
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", _FALLBACK_MP_API_KEY)
|
||
|
||
# ========= 数据模型 =========
|
||
|
||
# Unit-cell parameters of a structure. Every field is optional; values that are
# unknown stay None. _lattice_from_structure fills all six from a pymatgen lattice.
class LatticeParameters(BaseModel):
    a: float | None = Field(default=None, description="a 轴长度 (Å)")
    b: float | None = Field(default=None, description="b 轴长度 (Å)")
    c: float | None = Field(default=None, description="c 轴长度 (Å)")
    alpha: float | None = Field(default=None, description="α 角 (度)")
    beta: float | None = Field(default=None, description="β 角 (度)")
    gamma: float | None = Field(default=None, description="γ 角 (度)")
|
||
|
||
|
||
# One search hit returned by search_cifs and consumed by filter_cifs.
# Summary properties (formula, band gap, E_above_hull, ordering) come from the
# summary query; space group / lattice / has_cif are filled in afterwards from
# the structure query. search_cifs only places plain JSON values in `extras`
# (never raw mp-api documents or Structure objects).
class CIFCandidate(BaseModel):
    id: str = Field(description="Materials Project ID,如 mp-149")
    source: Literal["mp"] = Field(default="mp", description="来源库")
    chemical_formula: str | None = Field(default=None, description="化学式(pretty)")
    space_group: str | None = Field(default=None, description="空间群符号")
    cell_parameters: LatticeParameters | None = Field(default=None, description="晶格参数")
    bandgap: float | None = Field(default=None, description="带隙 (eV)")
    stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
    magnetic_ordering: str | None = Field(default=None, description="磁性信息")
    # Default True; search_cifs overwrites it with "a structure was returned".
    has_cif: bool = Field(default=True, description="是否存在可导出的 CIF(只要有结构即可)")
    extras: dict[str, Any] | None = Field(default=None, description="原始文档片段(健壮解析)")
|
||
|
||
|
||
# Input model for the search_cifs tool.
# The element / band-gap / E_above_hull constraints are forwarded to the
# Materials Project summary query; search_cifs also derives include_elements
# from `formula` when include_elements is not given. The remaining fields are
# described (see each Field description) as local post-processing options.
class SearchQuery(BaseModel):
    # Basic chemical constraints
    include_elements: list[str] | None = Field(default=None, description="包含元素集合(如 ['Si','O'])")
    exclude_elements: list[str] | None = Field(default=None, description="排除元素集合(如 ['Li'])")
    formula: str | None = Field(default=None, description="化学式片段(本地二次过滤)")

    # Property constraints (the summary endpoint can filter band gap and E_above_hull)
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None

    # Space group and lattice ranges (described as local post-filters)
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None

    magnetic_ordering: str | None = None  # described as a local post-filter

    # Pagination / ordering. per_page is capped at 200 by validation (le=200).
    page: int = Field(default=1, ge=1, description="页码(仅本地分页;MP 端点请自行扩展)")
    per_page: int = Field(default=50, ge=1, le=200, description="每页数量(本地分页)")
    sort_by: str | None = Field(default=None, description="排序字段(bandgap/e_above_hull等,本地排序)")
    sort_order: Literal["asc", "desc"] | None = Field(default=None, description="排序方向")
|
||
|
||
|
||
# CIF payload returned by get_cif: the material id, its source database, the CIF
# text itself and free-form metadata about how the text was produced.
class CIFText(BaseModel):
    id: str
    source: Literal["mp"]
    text: str = Field(description="CIF 文件文本内容")
    metadata: dict[str, Any] | None = Field(default=None, description="附加元数据(如结构来源)")
|
||
|
||
|
||
# Criteria for the filter_cifs tool. All given criteria must hold at once; a
# field left as None imposes no constraint. The *_min / *_max pairs are
# inclusive bounds.
class FilterCriteria(BaseModel):
    include_elements: list[str] | None = None
    exclude_elements: list[str] | None = None
    formula: str | None = None
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    magnetic_ordering: str | None = None
|
||
|
||
|
||
# Outcome for a single id in download_cif_bulk: `text` is set on success,
# `error` carries the exception message on failure.
class BulkDownloadItem(BaseModel):
    id: str
    source: Literal["mp"] = "mp"
    success: bool
    text: str | None = None
    error: str | None = None
|
||
|
||
|
||
# Aggregate result of download_cif_bulk. download_cif_bulk sets
# total = len(items) and succeeded + failed = total.
class BulkDownloadResult(BaseModel):
    items: list[BulkDownloadItem]
    total: int
    succeeded: int
    failed: int
|
||
|
||
|
||
# ========= Lifespan =========
|
||
|
||
@dataclass
class AppContext:
    """Server-wide state produced by ``lifespan`` and shared by tool calls.

    Tools read it through ``ctx.request_context.lifespan_context``.
    """

    mpr: MPRester  # shared Materials Project API client
    api_key: str  # the key `mpr` was constructed with
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create one MPRester for the server's lifetime and close it on shutdown.

    The client is exposed to tools as ``ctx.request_context.lifespan_context``.
    ``MPRester`` is used as a context manager so its HTTP session is closed
    when the server stops, instead of being left open until garbage collection.

    Yields:
        AppContext holding the shared client and the key it was built with.
    """
    # Key resolution (MP_API_KEY env var, then the demo fallback) already
    # happened when DEFAULT_MP_API_KEY was computed at import time.
    api_key = DEFAULT_MP_API_KEY
    with MPRester(api_key) as mpr:
        yield AppContext(mpr=mpr, api_key=api_key)
|
||
|
||
|
||
# ========= 服务器 =========
|
||
|
||
# Module-level server singleton. The @mcp.tool() / @mcp.resource() decorators
# below register onto it, and create_materials_mcp() returns it for mounting.
mcp = FastMCP(
    name="Materials Project CIF",
    instructions="通过 Materials Project API 检索并导出 CIF 文件。",
    lifespan=lifespan,
    # Serve the Streamable HTTP endpoint at "/" so that, when the app is mounted
    # under a Starlette sub-path, the endpoint is the mount path itself.
    streamable_http_path="/",
)
|
||
|
||
|
||
# ========= 工具实现 =========
|
||
|
||
def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
|
||
if text is None:
|
||
return not required and not excluded
|
||
s = text.upper()
|
||
if required:
|
||
for el in required:
|
||
if el.upper() not in s:
|
||
return False
|
||
if excluded:
|
||
for el in excluded:
|
||
if el.upper() in s:
|
||
return False
|
||
return True
|
||
|
||
|
||
def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
    """Summarise a structure's unit cell as ``LatticeParameters``.

    Returns ``None`` when no structure is available. Otherwise the cell lengths
    a, b, c and angles alpha, beta, gamma are copied from ``struct.lattice``
    and coerced to plain ``float`` (JSON-friendly, no numpy scalars).
    """
    if struct is not None:
        lattice = struct.lattice
        names = ("a", "b", "c", "alpha", "beta", "gamma")
        return LatticeParameters(**{name: float(getattr(lattice, name)) for name in names})
    return None
|
||
|
||
|
||
def _sg_symbol_from_doc(doc: Any) -> str | None:
|
||
"""
|
||
尝试从 materials.search 返回的文档中提取空间群符号。
|
||
MP 文档中 symmetry 通常包含 spacegroup 信息(不同版本字段路径可能变化)[4]
|
||
"""
|
||
try:
|
||
sym = getattr(doc, "symmetry", None)
|
||
if sym and getattr(sym, "spacegroup", None):
|
||
sg = sym.spacegroup
|
||
# 常见字段:symbol 或 international
|
||
return getattr(sg, "symbol", None) or getattr(sg, "international", None)
|
||
except Exception:
|
||
pass
|
||
return None
|
||
|
||
|
||
async def _to_thread(fn, *args, **kwargs):
|
||
return await asyncio.to_thread(fn, *args, **kwargs)
|
||
|
||
|
||
def _parse_elements_from_formula(formula: str) -> list[str]:
|
||
"""从化学式中解析元素符号(简易版)。"""
|
||
if not formula:
|
||
return []
|
||
elems = re.findall(r"[A-Z][a-z]?", formula)
|
||
seen, ordered = set(), []
|
||
for e in elems:
|
||
if e not in seen:
|
||
seen.add(e)
|
||
ordered.append(e)
|
||
return ordered
|
||
|
||
def _summary_kwargs_for(query: SearchQuery, include_elements: list[str]) -> dict[str, Any]:
    """Translate the server-side constraints of *query* into summary.search kwargs."""
    kwargs: dict[str, Any] = {}
    if include_elements:
        kwargs["elements"] = include_elements
    if query.exclude_elements:
        kwargs["exclude_elements"] = query.exclude_elements
    if query.bandgap_min is not None or query.bandgap_max is not None:
        # (min, max) tuple; a None end is intended as an open bound.
        kwargs["band_gap"] = (
            float(query.bandgap_min) if query.bandgap_min is not None else None,
            float(query.bandgap_max) if query.bandgap_max is not None else None,
        )
    if query.e_above_hull_max is not None:
        kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))
    return kwargs


def _ordering_to_str(value: Any) -> str | None:
    """Normalise an ordering value to a plain string (enum member -> its ``.value``)."""
    if value is None:
        return None
    return str(getattr(value, "value", value))


def _candidate_from_summary(sdoc: Any) -> CIFCandidate:
    """Build a JSON-safe CIFCandidate from one summary document (structure not yet merged)."""
    # mp-api ids may be MPID objects; use the plain string form everywhere.
    material_id = str(sdoc.material_id)
    formula = getattr(sdoc, "formula_pretty", None)
    bandgap = getattr(sdoc, "band_gap", None)
    e_hull = getattr(sdoc, "energy_above_hull", None)
    ordering = _ordering_to_str(getattr(sdoc, "ordering", None))
    bandgap_f = float(bandgap) if bandgap is not None else None
    e_hull_f = float(e_hull) if e_hull is not None else None
    return CIFCandidate(
        id=material_id,
        source="mp",
        chemical_formula=formula,
        space_group=None,
        cell_parameters=None,
        bandgap=bandgap_f,
        stability_energy=e_hull_f,
        magnetic_ordering=ordering,
        has_cif=True,
        extras={
            "summary": {
                "material_id": material_id,
                "formula_pretty": formula,
                "band_gap": bandgap_f,
                "energy_above_hull": e_hull_f,
                "ordering": ordering,
            }
        },
    )


def _value_in_range(value: float | None, vmin: float | None, vmax: float | None) -> bool:
    """True if *value* lies in the inclusive [vmin, vmax]; a missing value only passes unbounded ranges."""
    if value is None:
        return vmin is None and vmax is None
    return (vmin is None or value >= vmin) and (vmax is None or value <= vmax)


def _passes_local_filters(c: CIFCandidate, query: SearchQuery) -> bool:
    """Apply the SearchQuery constraints that the summary endpoint does not handle.

    Space group: case-insensitive substring match. Lattice a/b/c: inclusive
    ranges. Magnetic ordering: case-insensitive exact match (so "FM" does not
    match "AFM").
    """
    if query.space_group and query.space_group.lower() not in (c.space_group or "").lower():
        return False
    cp = c.cell_parameters or LatticeParameters()
    if not (
        _value_in_range(cp.a, query.cell_a_min, query.cell_a_max)
        and _value_in_range(cp.b, query.cell_b_min, query.cell_b_max)
        and _value_in_range(cp.c, query.cell_c_min, query.cell_c_max)
    ):
        return False
    if query.magnetic_ordering:
        wanted = query.magnetic_ordering.strip().casefold()
        if (c.magnetic_ordering or "").strip().casefold() != wanted:
            return False
    return True


@mcp.tool()
async def search_cifs(
    query: SearchQuery,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """Search Materials Project and return one page of candidate structures.

    Server-side filters: include/exclude elements (derived from ``formula``
    when ``include_elements`` is not given), band gap range and an
    E_above_hull upper bound; at least one of these is required.
    ``page``/``per_page`` select which page of the server's result order is
    returned. Space group, lattice ranges and magnetic ordering are then
    applied to that page locally, so a page may hold fewer than ``per_page``
    items. ``sort_by`` (bandgap / e_above_hull / energy_above_hull) re-orders
    the returned page; items missing the sort value are placed last.
    Each result carries space group and lattice parameters when structure data
    exists; only JSON-serialisable values are returned.

    Raises:
        ValueError: if no server-side constraint is given.
        RuntimeError: if the installed mp-api rejects chunk_size/num_chunks.
    """
    mpr = ctx.request_context.lifespan_context.mpr

    # Derive an element set from `formula` when the caller gave none.
    include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")

    # Require at least one server-side constraint to avoid a full-database scan.
    if (
        not include_elements
        and not query.exclude_elements
        and query.bandgap_min is None
        and query.bandgap_max is None
        and query.e_above_hull_max is None
    ):
        raise ValueError("请至少提供一个过滤条件,例如 include_elements、bandgap 或 energy_above_hull 上限,以避免超时。")

    summary_kwargs = _summary_kwargs_for(query, include_elements)

    # 1) Summary pre-filter. mp-api fetches `num_chunks` chunks of `chunk_size`
    #    docs, so this returns at most page * per_page docs; the last per_page
    #    of them form the requested page. Assumes the server's result order is
    #    stable between calls.
    try:
        summary_docs = await asyncio.to_thread(
            mpr.materials.summary.search,
            fields=["material_id", "formula_pretty", "band_gap", "energy_above_hull", "ordering"],
            chunk_size=query.per_page,
            num_chunks=query.page,
            **summary_kwargs,
        )
    except TypeError as err:
        # Older mp-api versions do not accept these arguments; refuse rather than
        # silently downloading the whole matching set.
        raise RuntimeError(
            "当前 mp-api 版本不支持 chunk_size/num_chunks,请升级 mp-api(pip install -U mp_api)。"
        ) from err

    start = (query.page - 1) * query.per_page
    page_docs = list(summary_docs or [])[start : start + query.per_page]
    if not page_docs:
        await ctx.debug("summary 检索结果为空")
        return []

    candidates = [_candidate_from_summary(doc) for doc in page_docs]

    # 2) Fetch structure + symmetry only for this page's ids.
    materials_docs = await asyncio.to_thread(
        mpr.materials.search,
        material_ids=[c.id for c in candidates],
        fields=["material_id", "structure", "symmetry"],
    )
    # Key by str so lookups match CIFCandidate.id even if mp-api returns MPID objects.
    mat_by_id = {str(doc.material_id): doc for doc in (materials_docs or [])}

    for c in candidates:
        mdoc = mat_by_id.get(c.id)
        struct = getattr(mdoc, "structure", None)
        c.space_group = _sg_symbol_from_doc(mdoc)
        c.cell_parameters = _lattice_from_structure(struct)
        c.has_cif = struct is not None
        # Keep extras lightweight JSON (no Structure / document objects).
        c.extras = {**(c.extras or {}), "materials": {"has_structure": struct is not None}}

    results = [c for c in candidates if _passes_local_filters(c, query)]

    if query.sort_by:
        sort_attr = {
            "bandgap": "bandgap",
            "e_above_hull": "stability_energy",
            "energy_above_hull": "stability_energy",
        }.get(query.sort_by)
        if sort_attr is not None:
            # Missing values always go last, regardless of sort direction.
            present = [c for c in results if getattr(c, sort_attr) is not None]
            missing = [c for c in results if getattr(c, sort_attr) is None]
            present.sort(key=lambda c: getattr(c, sort_attr), reverse=query.sort_order == "desc")
            results = present + missing

    await ctx.debug(f"返回 {len(results)} 条(已限制每页 {query.per_page} 条)")
    return results
|
||
|
||
@mcp.tool()
async def get_cif(
    id: str,
    ctx: Context[ServerSession, AppContext],
) -> CIFText:
    """Fetch a Materials Project material's structure and return it as CIF text.

    Only the ``material_id`` and ``structure`` fields are requested to keep the
    query small; the structure is serialised with pymatgen's
    ``Structure.to(fmt="cif")``.

    Raises:
        ValueError: if the id is not found or the material has no structure.
    """
    mpr = ctx.request_context.lifespan_context.mpr
    await ctx.info(f"获取结构并导出 CIF:{id}")

    found = await _to_thread(mpr.materials.search, material_ids=[id], fields=["material_id", "structure"])
    if not found:
        raise ValueError(f"未找到材料:{id}")

    structure: Structure | None = getattr(found[0], "structure", None)
    if structure is None:
        raise ValueError(f"材料 {id} 无结构数据,无法导出 CIF")

    return CIFText(
        id=id,
        source="mp",
        text=structure.to(fmt="cif"),
        metadata={"from": "materials.search"},
    )
|
||
|
||
|
||
@mcp.tool()
async def filter_cifs(
    candidates: list[CIFCandidate],
    criteria: FilterCriteria,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """Filter candidates locally (no API calls) and return those meeting every criterion.

    - include/exclude elements: checked against ``chemical_formula``.
    - formula: case-insensitive substring of ``chemical_formula``.
    - space_group: case-insensitive substring of ``space_group``.
    - lattice a/b/c and band gap: inclusive ranges; a candidate lacking the
      value fails whenever a bound is given.
    - e_above_hull_max: ``stability_energy`` must be known and <= the bound.
    - magnetic_ordering: case-insensitive exact match on ``magnetic_ordering``.
    Criteria left as None are ignored. Input order is preserved.
    """

    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        # A missing value only satisfies a completely unbounded range.
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        if vmax is not None and v > vmax:
            return False
        return True

    def matches(c: CIFCandidate) -> bool:
        if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
            return False
        if criteria.formula and (
            not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper()
        ):
            return False
        if criteria.space_group and criteria.space_group.lower() not in (c.space_group or "").lower():
            return False
        cp = c.cell_parameters or LatticeParameters()
        if not (
            in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max)
            and in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max)
            and in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max)
            and in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max)
        ):
            return False
        if criteria.e_above_hull_max is not None and (
            c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max
        ):
            return False
        if criteria.magnetic_ordering:
            # Compare against the candidate's own ordering; exact match so that
            # e.g. "FM" does not also select "AFM".
            wanted = criteria.magnetic_ordering.strip().casefold()
            if (c.magnetic_ordering or "").strip().casefold() != wanted:
                return False
        return True

    filtered = [c for c in candidates if matches(c)]
    await ctx.debug(f"筛选后:{len(filtered)} / 原始 {len(candidates)}")
    return filtered
|
||
|
||
|
||
@mcp.tool()
async def download_cif_bulk(
    ids: list[str],
    concurrency: int = 4,
    ctx: Context[ServerSession, AppContext] | None = None,
) -> BulkDownloadResult:
    """Download CIF text for many Materials Project ids concurrently.

    At most ``concurrency`` downloads (minimum 1) run at a time. A failure
    for one id does not abort the batch; it is recorded as an item with
    ``success=False`` and the error message. Items are returned in the same
    order as ``ids``. When a request context is available, progress is reported
    as the fraction of ids completed so far.
    """
    semaphore = asyncio.Semaphore(max(1, concurrency))
    total = len(ids)
    # One slot per input id so results keep input order, not completion order.
    slots: list[BulkDownloadItem | None] = [None] * total
    completed = 0

    async def fetch_one(idx: int, mid: str) -> None:
        nonlocal completed
        try:
            if ctx is None:
                # Explicit check instead of `assert`, which is stripped under -O.
                raise RuntimeError("request context unavailable; cannot fetch CIF")
            async with semaphore:
                cif = await get_cif(id=mid, ctx=ctx)
            slots[idx] = BulkDownloadItem(id=mid, success=True, text=cif.text)
        except Exception as e:  # deliberate per-item isolation: record and continue
            slots[idx] = BulkDownloadItem(id=mid, success=False, error=str(e))
        # Count completions (not indices) so reported progress is monotonic.
        completed += 1
        if ctx is not None:
            await ctx.report_progress(progress=completed / total, total=1.0, message=f"{completed}/{total}")

    await asyncio.gather(*(fetch_one(i, mid) for i, mid in enumerate(ids)))

    items = [it for it in slots if it is not None]
    succeeded = sum(1 for it in items if it.success)
    return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=len(items) - succeeded)
|
||
|
||
|
||
# ========= 资源:cif://mp/{id} =========
|
||
|
||
@mcp.resource("cif://mp/{id}")
async def cif_resource(id: str) -> str:
    """Return the CIF text for Materials Project material ``id``.

    Raises:
        ValueError: if the material is not found or has no structure data.
    """
    # This handler does not take a Context, so the lifespan MPRester is not
    # reachable here; a short-lived client is opened per read instead. The
    # blocking mp-api / pymatgen work runs in a worker thread so it does not
    # stall the event loop, and the `with` block closes the client's session.
    def _fetch_cif() -> str:
        with MPRester(DEFAULT_MP_API_KEY) as mpr:
            docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"])
        if not docs or getattr(docs[0], "structure", None) is None:
            raise ValueError(f"材料 {id} 无结构数据")
        struct: Structure = docs[0].structure
        return struct.to(fmt="cif")

    return await asyncio.to_thread(_fetch_cif)
|
||
|
||
|
||
def create_materials_mcp() -> FastMCP:
    """Factory used when mounting this server under Starlette.

    Returns the module-level ``mcp`` singleton (the same object on every
    call), with all tools and resources registered above already attached.
    """
    server: FastMCP = mcp
    return server
|
||
|
||
|
||
if __name__ == "__main__":
    # Standalone (non-mounted) mode: serve over the Streamable HTTP transport.
    mcp.run(transport="streamable-http")