464 lines
19 KiB
Python
464 lines
19 KiB
Python
# softbv_mcp.py
|
||
#
|
||
# 在远程服务器上激活 softBV 环境并执行计算(支持 --md 与 --gen-cube 两个专用工具)。
|
||
# - 生命周期:建立 SSH 连接并注入上下文
|
||
# - 激活环境:source /cluster/home/koko125/script/softBV.sh
|
||
# - 工作目录:/cluster/home/koko125/sandbox
|
||
# - 可执行文件:/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x
|
||
#
|
||
# 依赖:
|
||
# pip install "mcp[cli]" asyncssh pydantic
|
||
#
|
||
# 用法(Starlette 挂载示例见你现有 main.py,导入 create_softbv_mcp 即可):
|
||
# from softbv_mcp import create_softbv_mcp
|
||
# softbv_mcp = create_softbv_mcp()
|
||
# Mount("/softbv", app=softbv_mcp.streamable_http_app())
|
||
#
|
||
# 可通过环境变量覆盖连接信息:
|
||
# REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH, SOFTBV_PROFILE, SOFTBV_BIN, DEFAULT_WORKDIR
|
||
|
||
import os
|
||
import posixpath
|
||
import asyncio
|
||
from dataclasses import dataclass
|
||
from socket import socket
|
||
from typing import Any
|
||
from collections.abc import AsyncIterator
|
||
from contextlib import asynccontextmanager
|
||
|
||
import asyncssh
|
||
from pydantic import BaseModel, Field
|
||
|
||
from mcp.server.fastmcp import FastMCP, Context
|
||
from mcp.server.session import ServerSession
|
||
|
||
from lifespan import REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH
|
||
|
||
def shell_quote(arg: str) -> str:
|
||
"""安全地把字符串作为单个 shell 参数(POSIX)。"""
|
||
return "'" + str(arg).replace("'", "'\"'\"'") + "'"
|
||
|
||
# 如果你已经定义了这些常量与路径,可保持复用;也可按需改为你自己的配置源
|
||
|
||
|
||
# 固定的 softBV 环境信息(可用环境变量覆盖)
|
||
SOFTBV_PROFILE = os.getenv("SOFTBV_PROFILE", "/cluster/home/koko125/script/softBV.sh")
|
||
SOFTBV_BIN = os.getenv("SOFTBV_BIN", "/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x")
|
||
DEFAULT_WORKDIR = os.getenv("DEFAULT_WORKDIR", "/cluster/home/koko125/sandbox")
|
||
|
||
@dataclass
|
||
class SoftBVContext:
|
||
ssh_connection: asyncssh.SSHClientConnection
|
||
workdir: str
|
||
profile: str
|
||
bin_path: str
|
||
|
||
@asynccontextmanager
|
||
async def softbv_lifespan(_server) -> AsyncIterator[SoftBVContext]:
|
||
"""
|
||
FastMCP 生命周期:建立 SSH 连接并注入 softBV 上下文。
|
||
- 不再做额外的 DNS 解析或自定义异步步骤,避免 socket.SOCK_STREAM 的环境覆盖问题
|
||
- 仅负责连接与清理;工具中通过 ctx.request_context.lifespan_context 访问该上下文 [1]
|
||
"""
|
||
# 允许用环境变量覆盖连接信息
|
||
host = os.getenv("REMOTE_HOST", REMOTE_HOST)
|
||
user = os.getenv("REMOTE_USER", REMOTE_USER)
|
||
key_path = os.getenv("PRIVATE_KEY_PATH", PRIVATE_KEY_PATH)
|
||
|
||
conn: asyncssh.SSHClientConnection | None = None
|
||
try:
|
||
conn = await asyncssh.connect(
|
||
host,
|
||
username=user,
|
||
client_keys=[key_path],
|
||
known_hosts=None, # 如需主机指纹校验,可移除此参数
|
||
connect_timeout=15, # 避免长时间挂起
|
||
)
|
||
yield SoftBVContext(
|
||
ssh_connection=conn,
|
||
workdir=DEFAULT_WORKDIR,
|
||
profile=SOFTBV_PROFILE,
|
||
bin_path=SOFTBV_BIN,
|
||
)
|
||
finally:
|
||
if conn:
|
||
conn.close()
|
||
await conn.wait_closed()
|
||
async def run_in_softbv_env(
|
||
conn: asyncssh.SSHClientConnection,
|
||
profile_path: str,
|
||
cmd: str,
|
||
cwd: str | None = None,
|
||
check: bool = True,
|
||
) -> asyncssh.SSHCompletedProcess:
|
||
"""
|
||
在远端 bash 会话中激活 softBV 环境并执行 cmd。
|
||
如提供 cwd,则先 cd 到该目录。
|
||
"""
|
||
parts = []
|
||
if cwd:
|
||
parts.append(f"cd {shell_quote(cwd)}")
|
||
parts.append(f"source {shell_quote(profile_path)}")
|
||
parts.append(cmd)
|
||
composite = "; ".join(parts)
|
||
full = f"bash -lc {shell_quote(composite)}"
|
||
return await conn.run(full, check=check)
|
||
|
||
# ===== MCP 服务器 =====
|
||
mcp = FastMCP(
|
||
name="softBV Tools",
|
||
instructions="在远程服务器上激活 softBV 环境并执行相关计算的工具集。",
|
||
lifespan=softbv_lifespan,
|
||
streamable_http_path="/",
|
||
)
|
||
|
||
# ===== 辅助:列目录用于识别新生成文件 =====
|
||
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
|
||
async with conn.start_sftp_client() as sftp:
|
||
try:
|
||
return await sftp.listdir(path)
|
||
except Exception:
|
||
return []
|
||
|
||
# ===== 结构化输入模型:--md =====
|
||
class SoftBVMDArgs(BaseModel):
|
||
input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --md 的输入")
|
||
# 位置参数(按帮助中的顺序;None 表示不提供,让程序使用默认)
|
||
type: str | None = Field(default=None, description="conducting ion 类型(例如 'Li')")
|
||
os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
|
||
sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
|
||
temperature: float | None = Field(default=None, description="温度 K(非正值默认 300)")
|
||
t_end: float | None = Field(default=None, description="生产时间 ps(非正值默认 10.0)")
|
||
t_equil: float | None = Field(default=None, description="平衡时间 ps(非正值默认 2.0)")
|
||
dt: float | None = Field(default=None, description="时间步长 ps(非正值默认 0.001)")
|
||
t_log: float | None = Field(default=None, description="采样间隔 ps(非正值每 100 步)")
|
||
cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
|
||
|
||
def _build_md_cmd(bin_path: str, args: SoftBVMDArgs, workdir: str) -> str:
|
||
input_abs = args.input_cif
|
||
if not input_abs.startswith("/"):
|
||
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||
parts: list[str] = [shell_quote(bin_path), shell_quote("--md"), shell_quote(input_abs)]
|
||
for val in [args.type, args.os, args.sf, args.temperature, args.t_end, args.t_equil, args.dt, args.t_log]:
|
||
if val is not None:
|
||
parts.append(shell_quote(str(val)))
|
||
return " ".join(parts)
|
||
|
||
# ===== 结构化输入模型:--gen-cube =====
|
||
class SoftBVGenCubeArgs(BaseModel):
|
||
input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --gen-cube 的输入")
|
||
type: str | None = Field(default=None, description="conducting ion 类型(如 'Li')")
|
||
os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
|
||
sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
|
||
resolution: float | None = Field(default=None, description="体素分辨率(默认约 0.1)")
|
||
ignore_conducting_ion: bool = Field(default=False, description="flag:ignore_conducting_ion")
|
||
periodic: bool = Field(default=True, description="flag:periodic(默认 True)")
|
||
output_name: str | None = Field(default=None, description="输出文件名前缀(可选)")
|
||
cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
|
||
|
||
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str) -> str:
|
||
input_abs = args.input_cif
|
||
if not input_abs.startswith("/"):
|
||
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
|
||
for val in [args.type, args.os, args.sf, args.resolution]:
|
||
if val is not None:
|
||
parts.append(shell_quote(str(val)))
|
||
if args.ignore_conducting_ion:
|
||
parts.append(shell_quote("--flag:ignore_conducting_ion"))
|
||
if args.periodic:
|
||
parts.append(shell_quote("--flag:periodic"))
|
||
if args.output_name:
|
||
parts.append(shell_quote(args.output_name))
|
||
return " ".join(parts)
|
||
|
||
# ===== 工具:环境信息检查 =====
|
||
# 工具:环境信息检查(修复版,避免超时)
|
||
@mcp.tool()
|
||
async def softbv_info(ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||
"""
|
||
快速自检:
|
||
- SFTP 检查工作目录、激活脚本、二进制是否存在/可执行(无需运行 softBV.x)
|
||
- 激活环境后仅输出标记与当前工作目录,避免长输出或阻塞
|
||
"""
|
||
app = ctx.request_context.lifespan_context
|
||
conn = app.ssh_connection
|
||
|
||
# 1) 通过 SFTP 快速检查文件与目录状态(不会长时间阻塞)
|
||
def stat_path_safe(path: str) -> dict[str, Any]:
|
||
return {"exists": False, "is_exec": False, "size": None}
|
||
|
||
workdir_info = stat_path_safe(app.workdir)
|
||
profile_info = stat_path_safe(app.profile)
|
||
bin_info = stat_path_safe(app.bin_path)
|
||
|
||
try:
|
||
async with conn.start_sftp_client() as sftp:
|
||
# workdir
|
||
try:
|
||
attrs = await sftp.stat(app.workdir)
|
||
workdir_info["exists"] = True
|
||
workdir_info["size"] = int(attrs.size or 0)
|
||
except Exception:
|
||
pass
|
||
|
||
# profile
|
||
try:
|
||
attrs = await sftp.stat(app.profile)
|
||
profile_info["exists"] = True
|
||
profile_info["size"] = int(attrs.size or 0)
|
||
perms = int(attrs.permissions or 0)
|
||
profile_info["is_exec"] = bool(perms & 0o111)
|
||
except Exception:
|
||
pass
|
||
|
||
# bin
|
||
try:
|
||
attrs = await sftp.stat(app.bin_path)
|
||
bin_info["exists"] = True
|
||
bin_info["size"] = int(attrs.size or 0)
|
||
perms = int(attrs.permissions or 0)
|
||
bin_info["is_exec"] = bool(perms & 0o111)
|
||
except Exception:
|
||
pass
|
||
except Exception as e:
|
||
await ctx.warning(f"SFTP 检查失败: {e}")
|
||
|
||
# 2) 激活环境并做极简命令(避免 softBV.x --help 的长输出)
|
||
# 仅返回当前用户、PWD 与二进制可执行判断;不实际运行 softBV.x
|
||
cmd = "echo __SOFTBV_READY__ && echo $USER && pwd && (test -x " + shell_quote(app.bin_path) + " && echo __BIN_OK__ || echo __BIN_NOT_EXEC__)"
|
||
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=app.workdir, check=False)
|
||
|
||
# 解析输出行
|
||
lines = proc.stdout.splitlines() if proc.stdout else []
|
||
ready = "__SOFTBV_READY__" in lines
|
||
user = None
|
||
pwd = None
|
||
bin_ok = "__BIN_OK__" in lines
|
||
|
||
# 尝试定位 user/pwd(ready 之后的两行)
|
||
if ready:
|
||
idx = lines.index("__SOFTBV_READY__")
|
||
if len(lines) > idx + 1:
|
||
user = lines[idx + 1].strip()
|
||
if len(lines) > idx + 2:
|
||
pwd = lines[idx + 2].strip()
|
||
|
||
result = {
|
||
"host": os.getenv("REMOTE_HOST", REMOTE_HOST),
|
||
"user": os.getenv("REMOTE_USER", REMOTE_USER),
|
||
"workdir": app.workdir,
|
||
"profile": app.profile,
|
||
"bin_path": app.bin_path,
|
||
"sftp_check": {
|
||
"workdir": workdir_info,
|
||
"profile": profile_info,
|
||
"bin": bin_info,
|
||
},
|
||
"activate_ready": ready,
|
||
"pwd": pwd,
|
||
"bin_is_executable": bin_ok or bin_info["is_exec"],
|
||
"exit_status": proc.exit_status,
|
||
"stderr_head": "\n".join(proc.stderr.splitlines()[:10]) if proc.stderr else "",
|
||
}
|
||
|
||
# 友好日志
|
||
if not ready:
|
||
await ctx.warning("softBV 环境未就绪(可能 source 脚本路径问题或权限不足)")
|
||
if not result["bin_is_executable"]:
|
||
await ctx.warning("softBV 二进制不可执行或不存在,请检查 bin_path 与权限(chmod +x)")
|
||
if proc.exit_status != 0 and not proc.stderr:
|
||
await ctx.debug("命令非零退出但无 stderr,可能是某些子测试返回非零导致")
|
||
|
||
return result
|
||
# ===== 工具:--md =====
|
||
@mcp.tool()
|
||
async def softbv_md(req: SoftBVMDArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||
"""
|
||
执行 softBV.x --md,返回结构化结果:
|
||
- cmd/cwd/exit_status/stdout/stderr
|
||
- new_files:执行后新增文件列表,便于定位输出
|
||
"""
|
||
app = ctx.request_context.lifespan_context
|
||
conn = app.ssh_connection
|
||
workdir = req.cwd or app.workdir
|
||
cmd = _build_md_cmd(app.bin_path, req, workdir)
|
||
|
||
await ctx.info(f"softBV --md 执行: {cmd} (cwd={workdir})")
|
||
pre_list = await _listdir(conn, workdir)
|
||
|
||
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=workdir, check=False)
|
||
|
||
post_list = await _listdir(conn, workdir)
|
||
new_files = sorted(set(post_list) - set(pre_list))
|
||
if proc.exit_status == 0:
|
||
await ctx.debug(f"--md 成功,新文件 {len(new_files)} 个")
|
||
else:
|
||
await ctx.warning(f"--md 非零退出: {proc.exit_status}")
|
||
|
||
return {
|
||
"cmd": cmd,
|
||
"cwd": workdir,
|
||
"exit_status": proc.exit_status,
|
||
"stdout": proc.stdout,
|
||
"stderr": proc.stderr,
|
||
"new_files": new_files,
|
||
}
|
||
|
||
async def run_in_softbv_env_stream(
|
||
conn: asyncssh.SSHClientConnection,
|
||
profile_path: str,
|
||
cmd: str,
|
||
cwd: str | None = None,
|
||
) -> asyncssh.SSHClientProcess:
|
||
parts = []
|
||
if cwd:
|
||
parts.append(f"cd {shell_quote(cwd)}")
|
||
parts.append(f"source {shell_quote(profile_path)} >/dev/null 2>&1 || true")
|
||
parts.append(cmd)
|
||
composite = "; ".join(parts)
|
||
full = f"bash -lc {shell_quote(composite)}"
|
||
# 不阻塞,返回进程句柄
|
||
proc = await conn.create_process(full)
|
||
return proc
|
||
|
||
# 轮询目录,识别新生成文件
|
||
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
|
||
async with conn.start_sftp_client() as sftp:
|
||
try:
|
||
return await sftp.listdir(path)
|
||
except Exception:
|
||
return []
|
||
|
||
# 轮询日志文件大小(作为心跳/粗略进度依据)
|
||
async def _stat_size(conn: asyncssh.SSHClientConnection, path: str) -> int | None:
|
||
async with conn.start_sftp_client() as sftp:
|
||
try:
|
||
attrs = await sftp.stat(path)
|
||
return int(attrs.size or 0)
|
||
except Exception:
|
||
return None
|
||
|
||
# 构造 gen-cube 命令(保持你之前的参数拼接逻辑)
|
||
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str, log_path: str | None = None) -> str:
|
||
input_abs = args.input_cif
|
||
if not input_abs.startswith("/"):
|
||
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
|
||
for val in [args.type, args.os, args.sf, args.resolution]:
|
||
if val is not None:
|
||
parts.append(shell_quote(str(val)))
|
||
if args.ignore_conducting_ion:
|
||
parts.append(shell_quote("--flag:ignore_conducting_ion"))
|
||
if args.periodic:
|
||
parts.append(shell_quote("--flag:periodic"))
|
||
if args.output_name:
|
||
parts.append(shell_quote(args.output_name))
|
||
cmd = " ".join(parts)
|
||
# 将输出重定向到日志,便于轮询
|
||
if log_path:
|
||
cmd = f"{cmd} > {shell_quote(log_path)} 2>&1"
|
||
return cmd
|
||
|
||
# 修复版:长时运行的 softbv_gen_cube,带心跳与超时保护
|
||
@mcp.tool()
|
||
async def softbv_gen_cube(req: SoftBVGenCubeArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||
"""
|
||
执行 softBV.x --gen-cube,支持长任务心跳,避免在 <25min 内被客户端强制超时。
|
||
- 每隔 10s 上报一次进度(心跳),包含已用时/日志大小/新增文件数
|
||
- 结束后返回 stdout_head(来自日志文件片段)、stderr_head(如有)、exit_status、新增文件
|
||
"""
|
||
app = ctx.request_context.lifespan_context
|
||
conn = app.ssh_connection
|
||
workdir = req.cwd or app.workdir
|
||
|
||
# 预先记录目录内容,用于结束后差集
|
||
before = await _listdir(conn, workdir)
|
||
|
||
# 远端日志文件路径(按时间戳命名)
|
||
import time
|
||
log_name = f"softbv_gencube_{int(time.time())}.log"
|
||
log_path = posixpath.join(workdir, log_name)
|
||
|
||
# 启动长任务,不阻塞当前协程
|
||
cmd = _build_gen_cube_cmd(app.bin_path, req, workdir, log_path=log_path)
|
||
await ctx.info(f"启动 --gen-cube: {cmd}")
|
||
proc = await run_in_softbv_env_stream(conn, app.profile, cmd=cmd, cwd=workdir)
|
||
|
||
# 心跳循环:直到进程退出
|
||
start_ts = time.time()
|
||
heartbeat_sec = 10 # 每 10 秒发送一次心跳
|
||
max_guard_min = 60 # 保险上限(服务端不主动终止;如客户端有限制可调大)
|
||
try:
|
||
while True:
|
||
# 进程是否已退出
|
||
if proc.exit_status is not None:
|
||
break
|
||
|
||
# 采集状态:已用时、日志大小、新增文件数
|
||
elapsed = time.time() - start_ts
|
||
log_size = await _stat_size(conn, log_path)
|
||
now_files = await _listdir(conn, workdir)
|
||
new_files_count = len(set(now_files) - set(before))
|
||
|
||
# 这里无法获知真实百分比,使用“心跳/已用时提示”
|
||
await ctx.report_progress(
|
||
progress=min(elapsed / (25 * 60), 0.99), # 以 25min 为目标上限做近似刻度
|
||
total=1.0,
|
||
message=f"gen-cube 运行中: 已用时 {int(elapsed)}s, 日志 {log_size or 0}B, 新文件 {new_files_count}",
|
||
)
|
||
|
||
# 避免客户端超时:持续心跳即可。[1]
|
||
await asyncio.sleep(heartbeat_sec)
|
||
|
||
# 简易守护:超过 max_guard_min 仍未结束也不强制中断(由远端决定)
|
||
if elapsed > max_guard_min * 60:
|
||
await ctx.warning("任务已超过守护上限时间,仍在运行(未强制中断)。如需更长时间,请增大上限。")
|
||
# 不 break,继续等待,交由远端任务完成
|
||
finally:
|
||
# 等待进程真正结束(如果已结束,这里是快速返回)
|
||
await proc.wait()
|
||
|
||
# 结束后采集结果
|
||
exit_status = proc.exit_status
|
||
after = await _listdir(conn, workdir)
|
||
new_files = sorted(set(after) - set(before))
|
||
|
||
# 读取日志片段(头/尾),帮助定位输出
|
||
async with conn.start_sftp_client() as sftp:
|
||
head = ""
|
||
tail = ""
|
||
try:
|
||
async with sftp.open(log_path, "rb") as f:
|
||
content = await f.read()
|
||
text = content.decode("utf-8", errors="replace")
|
||
lines = text.splitlines()
|
||
head = "\n".join(lines[:40])
|
||
tail = "\n".join(lines[-40:])
|
||
except Exception:
|
||
pass
|
||
|
||
# 输出结构化结果
|
||
result = {
|
||
"cmd": cmd,
|
||
"cwd": workdir,
|
||
"exit_status": exit_status,
|
||
"log_file": log_path,
|
||
"stdout_head": head, # 代替一次性 stdout,避免大输出
|
||
"stderr_head": "", # 统一日志到文件,stderr_head 可留空
|
||
"new_files": new_files,
|
||
"elapsed_sec": int(time.time() - start_ts),
|
||
}
|
||
|
||
if exit_status == 0:
|
||
await ctx.info(f"gen-cube 完成,用时 {result['elapsed_sec']}s,新文件 {len(new_files)} 个")
|
||
else:
|
||
await ctx.warning(f"gen-cube 退出码 {exit_status},请查看日志 {log_path}")
|
||
|
||
return result
|
||
def create_softbv_mcp() -> FastMCP:
|
||
"""供外部(Starlette)导入的工厂函数。"""
|
||
return mcp
|
||
|
||
if __name__ == "__main__":
|
||
mcp.run(transport="streamable-http") |