Files
solidstate-tools/mcp/softBV.py
2025-10-16 14:57:27 +08:00

464 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# softbv_mcp.py
#
# 在远程服务器上激活 softBV 环境并执行计算(支持 --md 与 --gen-cube 两个专用工具)。
# - 生命周期:建立 SSH 连接并注入上下文
# - 激活环境:source /cluster/home/koko125/script/softBV.sh
# - 工作目录:/cluster/home/koko125/sandbox
# - 可执行文件:/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x
#
# 依赖:
# pip install "mcp[cli]" asyncssh pydantic
#
# 用法(Starlette 挂载示例见你现有 main.py,导入 create_softbv_mcp 即可):
# from softbv_mcp import create_softbv_mcp
# softbv_mcp = create_softbv_mcp()
# Mount("/softbv", app=softbv_mcp.streamable_http_app())
#
# 可通过环境变量覆盖连接信息:
# REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH, SOFTBV_PROFILE, SOFTBV_BIN, DEFAULT_WORKDIR
import os
import posixpath
import asyncio
from dataclasses import dataclass
from socket import socket
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import asyncssh
from pydantic import BaseModel, Field
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
from lifespan import REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH
def shell_quote(arg: str) -> str:
    """Quote ``arg`` so a POSIX shell reads it as exactly one argument.

    The value is converted with ``str()``, wrapped in single quotes, and every
    embedded single quote is replaced by the sequence ``'"'"'`` (close the
    single-quoted run, emit a double-quoted ``'``, reopen the run).
    """
    escaped = str(arg).replace("'", "'\"'\"'")
    return f"'{escaped}'"
# If these constants and paths are already defined elsewhere they can be reused;
# they may also be replaced by your own configuration source as needed.
# Fixed softBV environment settings. Each one reads an environment variable of
# the same name and falls back to the hard-coded remote path when it is unset.
SOFTBV_PROFILE = os.getenv("SOFTBV_PROFILE", "/cluster/home/koko125/script/softBV.sh")
SOFTBV_BIN = os.getenv("SOFTBV_BIN", "/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x")
DEFAULT_WORKDIR = os.getenv("DEFAULT_WORKDIR", "/cluster/home/koko125/sandbox")
@dataclass
class SoftBVContext:
    """State yielded by the server lifespan and read by the tool functions."""

    # SSH connection opened by the lifespan; tools run commands and SFTP over it.
    ssh_connection: asyncssh.SSHClientConnection
    # Default remote working directory (tools use it when the request has no cwd).
    workdir: str
    # Remote path of the script that is sourced before a command is run.
    profile: str
    # Remote path of the softBV executable.
    bin_path: str
@asynccontextmanager
async def softbv_lifespan(_server) -> AsyncIterator[SoftBVContext]:
    """
    FastMCP lifespan: open the SSH connection and yield the softBV context.

    - No extra DNS resolution or custom async steps are performed here, which
      avoids the ``socket.SOCK_STREAM`` environment-override problem.
    - Only connection setup and cleanup happen here; tools reach the yielded
      context through ``ctx.request_context.lifespan_context``.

    Host, user and key path are read from the ``REMOTE_HOST``, ``REMOTE_USER``
    and ``PRIVATE_KEY_PATH`` environment variables, falling back to the values
    imported from ``lifespan``.
    """
    remote_host = os.getenv("REMOTE_HOST", REMOTE_HOST)
    remote_user = os.getenv("REMOTE_USER", REMOTE_USER)
    private_key = os.getenv("PRIVATE_KEY_PATH", PRIVATE_KEY_PATH)

    connection: asyncssh.SSHClientConnection | None = None
    try:
        connection = await asyncssh.connect(
            remote_host,
            username=remote_user,
            client_keys=[private_key],
            # Remove this argument if host key verification is wanted.
            known_hosts=None,
            # Bound the connection attempt so startup cannot hang for long.
            connect_timeout=15,
        )
        context = SoftBVContext(
            ssh_connection=connection,
            workdir=DEFAULT_WORKDIR,
            profile=SOFTBV_PROFILE,
            bin_path=SOFTBV_BIN,
        )
        yield context
    finally:
        # The connection is still None when connect() itself failed.
        if connection:
            connection.close()
            await connection.wait_closed()
async def run_in_softbv_env(
    conn: asyncssh.SSHClientConnection,
    profile_path: str,
    cmd: str,
    cwd: str | None = None,
    check: bool = True,
) -> asyncssh.SSHCompletedProcess:
    """
    Source the softBV profile in a remote ``bash -lc`` shell, then run ``cmd``.

    Args:
        conn: SSH connection used to run the command.
        profile_path: Remote script to ``source`` before ``cmd``.
        cmd: Command text, inserted verbatim (the caller quotes its arguments).
        cwd: When given, the shell first changes into this directory.
        check: Passed through to ``conn.run``.

    Returns:
        Whatever ``conn.run`` returns for the composed command.

    NOTE(review): the steps are joined with ``;``, so ``cmd`` is still executed
    when the ``cd`` or ``source`` step fails.
    """
    steps = [f"source {shell_quote(profile_path)}", cmd]
    if cwd:
        steps.insert(0, f"cd {shell_quote(cwd)}")
    script = "; ".join(steps)
    # The whole script is quoted once more so bash receives it as one -c argument.
    return await conn.run(f"bash -lc {shell_quote(script)}", check=check)
# ===== MCP server =====
# Module-level FastMCP instance. The @mcp.tool() functions in this module
# register on it, and softbv_lifespan supplies the per-server SSH context.
mcp = FastMCP(
    name="softBV Tools",
    instructions="在远程服务器上激活 softBV 环境并执行相关计算的工具集。",
    lifespan=softbv_lifespan,
    streamable_http_path="/",
)
# ===== 辅助:列目录用于识别新生成文件 =====
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
    """Return the names in remote directory ``path``, or ``[]`` if listing fails.

    A fresh SFTP client is opened for each call. This module defines
    ``_listdir`` a second time further down with the same body; Python binds
    the name to the later definition.
    """
    async with conn.start_sftp_client() as sftp:
        try:
            names = await sftp.listdir(path)
        except Exception:
            # Best effort: an unreadable or missing directory counts as empty.
            names = []
    return names
# ===== 结构化输入模型:--md =====
# Structured input for the --md tool.
class SoftBVMDArgs(BaseModel):
    input_cif: str = Field(
        description="远程 CIF 文件路径(相对或绝对),作为 --md 的输入",
    )
    # Positional parameters, in the order listed by the program's help.
    # None means "not supplied", leaving the program to use its default.
    type: str | None = Field(
        default=None,
        description="conducting ion 类型(例如 'Li'",
    )
    os: int | None = Field(
        default=None,
        description="conducting ion 氧化态(整数)",
    )
    sf: float | None = Field(
        default=None,
        description="screening factor非正值使用默认",
    )
    temperature: float | None = Field(
        default=None,
        description="温度 K非正值默认 300",
    )
    t_end: float | None = Field(
        default=None,
        description="生产时间 ps非正值默认 10.0",
    )
    t_equil: float | None = Field(
        default=None,
        description="平衡时间 ps非正值默认 2.0",
    )
    dt: float | None = Field(
        default=None,
        description="时间步长 ps非正值默认 0.001",
    )
    t_log: float | None = Field(
        default=None,
        description="采样间隔 ps非正值每 100 步)",
    )
    # Remote working directory; the tool falls back to the lifespan workdir.
    cwd: str | None = Field(
        default=None,
        description="远程工作目录(默认使用生命周期中的 workdir",
    )
def _build_md_cmd(bin_path: str, args: SoftBVMDArgs, workdir: str) -> str:
    """Build the shell command line for ``softBV.x --md``.

    Args:
        bin_path: Remote path of the softBV executable.
        args: Request model. ``input_cif`` is resolved against ``workdir``
            unless it is absolute. The remaining values are emitted as
            positional arguments in the order type, os, sf, temperature,
            t_end, t_equil, dt, t_log.
        workdir: Remote directory used to resolve a relative ``input_cif``.

    Returns:
        The command string, with every argument quoted for a POSIX shell.

    Raises:
        ValueError: If a positional value is omitted (``None``) while a later
            one is set. The arguments are positional, so silently skipping the
            gap would shift every later value into the wrong slot.
    """
    input_abs = args.input_cif
    if not input_abs.startswith("/"):
        input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))

    positional = [
        ("type", args.type),
        ("os", args.os),
        ("sf", args.sf),
        ("temperature", args.temperature),
        ("t_end", args.t_end),
        ("t_equil", args.t_equil),
        ("dt", args.dt),
        ("t_log", args.t_log),
    ]
    # Trailing omissions are fine: the program applies its own defaults there.
    while positional and positional[-1][1] is None:
        positional.pop()
    skipped = [name for name, value in positional if value is None]
    if skipped:
        raise ValueError(
            "softBV --md arguments are positional; "
            f"{', '.join(skipped)} must be set because a later argument is provided"
        )

    parts: list[str] = [shell_quote(bin_path), shell_quote("--md"), shell_quote(input_abs)]
    parts.extend(shell_quote(str(value)) for _, value in positional)
    return " ".join(parts)
# ===== 结构化输入模型:--gen-cube =====
# Structured input for the --gen-cube tool.
class SoftBVGenCubeArgs(BaseModel):
    input_cif: str = Field(
        description="远程 CIF 文件路径(相对或绝对),作为 --gen-cube 的输入",
    )
    # Positional values; None means "not supplied".
    type: str | None = Field(
        default=None,
        description="conducting ion 类型(如 'Li'",
    )
    os: int | None = Field(
        default=None,
        description="conducting ion 氧化态(整数)",
    )
    sf: float | None = Field(
        default=None,
        description="screening factor非正值使用默认",
    )
    resolution: float | None = Field(
        default=None,
        description="体素分辨率(默认约 0.1",
    )
    # Boolean switches; each is emitted as a --flag:<name> argument when true.
    ignore_conducting_ion: bool = Field(
        default=False,
        description="flag:ignore_conducting_ion",
    )
    periodic: bool = Field(
        default=True,
        description="flag:periodic默认 True",
    )
    output_name: str | None = Field(
        default=None,
        description="输出文件名前缀(可选)",
    )
    # Remote working directory; the tool falls back to the lifespan workdir.
    cwd: str | None = Field(
        default=None,
        description="远程工作目录(默认使用生命周期中的 workdir",
    )
def _build_gen_cube_cmd(
    bin_path: str,
    args: SoftBVGenCubeArgs,
    workdir: str,
    log_path: str | None = None,
) -> str:
    """Build the shell command line for ``softBV.x --gen-cube``.

    This module defines ``_build_gen_cube_cmd`` again further down and Python
    binds the name to that later definition. This copy accepts the same
    optional ``log_path`` so the two signatures agree.

    Args:
        bin_path: Remote path of the softBV executable.
        args: Request model. ``input_cif`` is resolved against ``workdir``
            unless it is absolute. type, os, sf and resolution are emitted as
            positional arguments in that order, followed by the enabled
            ``--flag:`` switches and the optional output name.
        workdir: Remote directory used to resolve a relative ``input_cif``.
        log_path: When given, stdout and stderr are redirected to this file.

    Returns:
        The command string, with every argument quoted for a POSIX shell.

    Raises:
        ValueError: If a positional value is omitted (``None``) while a later
            one is set, which would shift later values into the wrong slot.
    """
    input_abs = args.input_cif
    if not input_abs.startswith("/"):
        input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))

    positional = [
        ("type", args.type),
        ("os", args.os),
        ("sf", args.sf),
        ("resolution", args.resolution),
    ]
    # Trailing omissions are fine: the program applies its own defaults there.
    while positional and positional[-1][1] is None:
        positional.pop()
    skipped = [name for name, value in positional if value is None]
    if skipped:
        raise ValueError(
            "softBV --gen-cube arguments are positional; "
            f"{', '.join(skipped)} must be set because a later argument is provided"
        )

    parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
    parts.extend(shell_quote(str(value)) for _, value in positional)
    if args.ignore_conducting_ion:
        parts.append(shell_quote("--flag:ignore_conducting_ion"))
    if args.periodic:
        parts.append(shell_quote("--flag:periodic"))
    if args.output_name:
        parts.append(shell_quote(args.output_name))

    cmd = " ".join(parts)
    if log_path:
        cmd = f"{cmd} > {shell_quote(log_path)} 2>&1"
    return cmd
# ===== 工具:环境信息检查 =====
# 工具:环境信息检查(修复版,避免超时)
async def _probe_remote_path(sftp, path: str, *, check_exec: bool) -> dict[str, Any]:
    """Stat ``path`` over SFTP and summarise it as exists / is_exec / size.

    A failing stat is reported as "does not exist" instead of being raised.
    ``is_exec`` is only evaluated when ``check_exec`` is true; it is true when
    any execute permission bit is set.
    """
    info: dict[str, Any] = {"exists": False, "is_exec": False, "size": None}
    try:
        attrs = await sftp.stat(path)
    except Exception:
        # Best effort: any stat failure leaves the "missing" defaults in place.
        return info
    info["exists"] = True
    info["size"] = int(attrs.size or 0)
    if check_exec:
        info["is_exec"] = bool(int(attrs.permissions or 0) & 0o111)
    return info


@mcp.tool()
async def softbv_info(ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Quick self-check of the remote softBV setup.

    - Uses SFTP to check that the working directory, the activation script and
      the binary exist, and whether the script and binary carry an execute
      bit. softBV.x itself is not run.
    - Sources the activation script and prints only a marker, the remote user
      and the current directory, avoiding long output or blocking.

    Returns a dict with the configured host/user/paths, the SFTP results, the
    activation status, the remote user and directory reported by the shell,
    the exit status and the first ten lines of stderr.
    """
    app = ctx.request_context.lifespan_context
    conn = app.ssh_connection

    # 1) SFTP checks. The defaults below stand if the SFTP session cannot be opened.
    workdir_info: dict[str, Any] = {"exists": False, "is_exec": False, "size": None}
    profile_info: dict[str, Any] = {"exists": False, "is_exec": False, "size": None}
    bin_info: dict[str, Any] = {"exists": False, "is_exec": False, "size": None}
    try:
        async with conn.start_sftp_client() as sftp:
            # The execute bit is not evaluated for the working directory.
            workdir_info = await _probe_remote_path(sftp, app.workdir, check_exec=False)
            profile_info = await _probe_remote_path(sftp, app.profile, check_exec=True)
            bin_info = await _probe_remote_path(sftp, app.bin_path, check_exec=True)
    except Exception as e:
        await ctx.warning(f"SFTP 检查失败: {e}")

    # 2) Source the environment and run a minimal command (no softBV.x --help,
    #    whose output is long): marker, user, cwd and an executable test.
    cmd = (
        "echo __SOFTBV_READY__ && echo $USER && pwd && (test -x "
        + shell_quote(app.bin_path)
        + " && echo __BIN_OK__ || echo __BIN_NOT_EXEC__)"
    )
    proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=app.workdir, check=False)

    lines = proc.stdout.splitlines() if proc.stdout else []
    ready = "__SOFTBV_READY__" in lines
    bin_ok = "__BIN_OK__" in lines

    # The two lines after the marker are the remote user and the directory.
    remote_user = None
    pwd = None
    if ready:
        marker_at = lines.index("__SOFTBV_READY__")
        if len(lines) > marker_at + 1:
            remote_user = lines[marker_at + 1].strip()
        if len(lines) > marker_at + 2:
            pwd = lines[marker_at + 2].strip()

    result = {
        "host": os.getenv("REMOTE_HOST", REMOTE_HOST),
        # Configured login user; the user reported by the shell is "remote_user".
        "user": os.getenv("REMOTE_USER", REMOTE_USER),
        "workdir": app.workdir,
        "profile": app.profile,
        "bin_path": app.bin_path,
        "sftp_check": {
            "workdir": workdir_info,
            "profile": profile_info,
            "bin": bin_info,
        },
        "activate_ready": ready,
        "remote_user": remote_user,
        "pwd": pwd,
        "bin_is_executable": bin_ok or bin_info["is_exec"],
        "exit_status": proc.exit_status,
        "stderr_head": "\n".join(proc.stderr.splitlines()[:10]) if proc.stderr else "",
    }

    # Friendly log messages.
    if not ready:
        await ctx.warning("softBV 环境未就绪(可能 source 脚本路径问题或权限不足)")
    if not result["bin_is_executable"]:
        await ctx.warning("softBV 二进制不可执行或不存在,请检查 bin_path 与权限chmod +x")
    if proc.exit_status != 0 and not proc.stderr:
        await ctx.debug("命令非零退出但无 stderr可能是某些子测试返回非零导致")
    return result
# ===== 工具:--md =====
@mcp.tool()
async def softbv_md(req: SoftBVMDArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run ``softBV.x --md`` and return a structured result:
    - cmd / cwd / exit_status / stdout / stderr
    - new_files: names that appeared in the working directory during the run,
      to help locate the output
    """
    state = ctx.request_context.lifespan_context
    ssh = state.ssh_connection
    workdir = req.cwd or state.workdir

    cmd = _build_md_cmd(state.bin_path, req, workdir)
    await ctx.info(f"softBV --md 执行: {cmd} (cwd={workdir})")

    # Snapshot the directory before and after the run; the difference is the output.
    names_before = set(await _listdir(ssh, workdir))
    proc = await run_in_softbv_env(ssh, state.profile, cmd=cmd, cwd=workdir, check=False)
    names_after = set(await _listdir(ssh, workdir))
    new_files = sorted(names_after.difference(names_before))

    if proc.exit_status != 0:
        await ctx.warning(f"--md 非零退出: {proc.exit_status}")
    else:
        await ctx.debug(f"--md 成功,新文件 {len(new_files)}")

    return {
        "cmd": cmd,
        "cwd": workdir,
        "exit_status": proc.exit_status,
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "new_files": new_files,
    }
async def run_in_softbv_env_stream(
    conn: asyncssh.SSHClientConnection,
    profile_path: str,
    cmd: str,
    cwd: str | None = None,
) -> asyncssh.SSHClientProcess:
    """Start ``cmd`` in a remote ``bash -lc`` shell and return the process handle.

    The shell optionally changes into ``cwd``, then sources ``profile_path``
    with its output discarded and its failure ignored (``|| true``), then runs
    ``cmd``. The handle comes from ``conn.create_process``; this function does
    not wait for the command to finish.
    """
    source_step = f"source {shell_quote(profile_path)} >/dev/null 2>&1 || true"
    if cwd:
        steps = [f"cd {shell_quote(cwd)}", source_step, cmd]
    else:
        steps = [source_step, cmd]
    script = "; ".join(steps)
    remote_command = f"bash -lc {shell_quote(script)}"
    return await conn.create_process(remote_command)
# 轮询目录,识别新生成文件
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
    """List remote directory ``path`` over SFTP; any failure yields ``[]``.

    Callers compare listings taken at different times to spot newly created
    files. A new SFTP client is opened for each call.
    """
    async with conn.start_sftp_client() as sftp:
        try:
            listing = await sftp.listdir(path)
        except Exception:
            # Best effort: treat a directory that cannot be listed as empty.
            listing = []
    return listing
# 轮询日志文件大小(作为心跳/粗略进度依据)
async def _stat_size(conn: asyncssh.SSHClientConnection, path: str) -> int | None:
    """Return the size of remote file ``path`` in bytes, or ``None`` if stat fails.

    Used as a heartbeat / rough progress signal while a log file grows. A size
    the server does not report is returned as 0.
    """
    async with conn.start_sftp_client() as sftp:
        try:
            attrs = await sftp.stat(path)
        except Exception:
            # For example, the file has not been created yet.
            return None
        return int(attrs.size or 0)
# 构造 gen-cube 命令(保持你之前的参数拼接逻辑)
def _build_gen_cube_cmd(
    bin_path: str,
    args: SoftBVGenCubeArgs,
    workdir: str,
    log_path: str | None = None,
) -> str:
    """Build the shell command line for ``softBV.x --gen-cube``.

    Args:
        bin_path: Remote path of the softBV executable.
        args: Request model. ``input_cif`` is resolved against ``workdir``
            unless it is absolute. type, os, sf and resolution are emitted as
            positional arguments in that order, followed by the enabled
            ``--flag:`` switches and the optional output name.
        workdir: Remote directory used to resolve a relative ``input_cif``.
        log_path: When given, stdout and stderr are redirected to this file so
            that the caller can poll it.

    Returns:
        The command string, with every argument quoted for a POSIX shell.

    Raises:
        ValueError: If a positional value is omitted (``None``) while a later
            one is set. The arguments are positional, so silently skipping the
            gap would shift every later value into the wrong slot.
    """
    input_abs = args.input_cif
    if not input_abs.startswith("/"):
        input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))

    positional = [
        ("type", args.type),
        ("os", args.os),
        ("sf", args.sf),
        ("resolution", args.resolution),
    ]
    # Trailing omissions are fine: the program applies its own defaults there.
    while positional and positional[-1][1] is None:
        positional.pop()
    skipped = [name for name, value in positional if value is None]
    if skipped:
        raise ValueError(
            "softBV --gen-cube arguments are positional; "
            f"{', '.join(skipped)} must be set because a later argument is provided"
        )

    parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
    parts.extend(shell_quote(str(value)) for _, value in positional)
    if args.ignore_conducting_ion:
        parts.append(shell_quote("--flag:ignore_conducting_ion"))
    if args.periodic:
        parts.append(shell_quote("--flag:periodic"))
    if args.output_name:
        parts.append(shell_quote(args.output_name))

    cmd = " ".join(parts)
    # Send all output to the log file so progress can be polled.
    if log_path:
        cmd = f"{cmd} > {shell_quote(log_path)} 2>&1"
    return cmd
# 修复版:长时运行的 softbv_gen_cube带心跳与超时保护
@mcp.tool()
async def softbv_gen_cube(req: SoftBVGenCubeArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run ``softBV.x --gen-cube`` as a long task, sending progress heartbeats so
    the client does not time the call out.

    - All program output is redirected to a timestamped log file in the
      working directory.
    - About every 10 s a progress report carries the elapsed time, the log
      size and the number of new files.
    - Returns cmd, cwd, exit_status, log_file, stdout_head and stdout_tail
      (first and last 40 log lines), stderr_head (always empty because stderr
      goes to the log), new_files and elapsed_sec.
    """
    import time

    app = ctx.request_context.lifespan_context
    conn = app.ssh_connection
    workdir = req.cwd or app.workdir

    # Directory snapshot taken before the run, used to identify new files.
    before = await _listdir(conn, workdir)
    before_set = set(before)

    # Remote log file, named after the current Unix time in seconds.
    log_name = f"softbv_gencube_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)

    # Start the task without waiting for it to finish.
    cmd = _build_gen_cube_cmd(app.bin_path, req, workdir, log_path=log_path)
    await ctx.info(f"启动 --gen-cube: {cmd}")
    proc = await run_in_softbv_env_stream(conn, app.profile, cmd=cmd, cwd=workdir)

    start_ts = time.time()
    heartbeat_sec = 10  # seconds between heartbeats
    max_guard_min = 60  # soft limit: only triggers a warning, never a kill
    guard_warned = False
    try:
        # Heartbeat loop: runs until the process reports an exit status.
        while proc.exit_status is None:
            elapsed = time.time() - start_ts
            log_size = await _stat_size(conn, log_path)
            now_files = await _listdir(conn, workdir)
            new_files_count = len(set(now_files) - before_set)
            # No real percentage is available, so elapsed time is scaled
            # against a nominal 25-minute run and capped below 1.0.
            await ctx.report_progress(
                progress=min(elapsed / (25 * 60), 0.99),
                total=1.0,
                message=f"gen-cube 运行中: 已用时 {int(elapsed)}s, 日志 {log_size or 0}B, 新文件 {new_files_count}",
            )
            await asyncio.sleep(heartbeat_sec)
            # Past the guard limit the task keeps running; warn once, not on
            # every heartbeat.
            if elapsed > max_guard_min * 60 and not guard_warned:
                guard_warned = True
                await ctx.warning("任务已超过守护上限时间,仍在运行(未强制中断)。如需更长时间,请增大上限。")
    finally:
        # Wait for the process to really finish (returns quickly if it already has).
        await proc.wait()

    exit_status = proc.exit_status
    after = await _listdir(conn, workdir)
    new_files = sorted(set(after) - before_set)

    # Read the first and last lines of the log to help locate the output.
    head = ""
    tail = ""
    async with conn.start_sftp_client() as sftp:
        try:
            async with sftp.open(log_path, "rb") as f:
                content = await f.read()
            text = content.decode("utf-8", errors="replace")
            lines = text.splitlines()
            head = "\n".join(lines[:40])
            tail = "\n".join(lines[-40:])
        except Exception:
            # Best effort: a missing or unreadable log leaves both excerpts empty.
            pass

    result = {
        "cmd": cmd,
        "cwd": workdir,
        "exit_status": exit_status,
        "log_file": log_path,
        "stdout_head": head,  # excerpt instead of the full stdout, to avoid huge payloads
        "stdout_tail": tail,  # end of the log
        "stderr_head": "",  # stderr is merged into the log file
        "new_files": new_files,
        "elapsed_sec": int(time.time() - start_ts),
    }
    if exit_status == 0:
        await ctx.info(f"gen-cube 完成,用时 {result['elapsed_sec']}s新文件 {len(new_files)}")
    else:
        await ctx.warning(f"gen-cube 退出码 {exit_status},请查看日志 {log_path}")
    return result
def create_softbv_mcp() -> FastMCP:
    """Factory for external importers (e.g. a Starlette app).

    Returns the module-level ``mcp`` instance; every call returns that same
    object.
    """
    server = mcp
    return server
# Script entry point: serve the MCP instance over the streamable-HTTP transport.
if __name__ == "__main__":
    mcp.run(transport="streamable-http")