softBV_mcp refactoring v2

Embedding copy
2025-10-22 23:59:23 +08:00
parent b9ba79d7a8
commit c0b2ec5983
8 changed files with 916 additions and 50 deletions

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 gyj155
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,96 @@
# Paper Semantic Search
Find similar papers using semantic search. Supports both local models (free) and OpenAI API (better quality).
## Features
- Fetch papers from OpenReview (e.g., ICLR 2026 submissions)
- Semantic search with example papers or text queries
- Automatic embedding caching
- Embedding model support: open-source (e.g., all-MiniLM-L6-v2) or OpenAI
## Quick Start
```bash
pip install -r requirements.txt
```
### 1. Prepare Papers
```python
from crawl import crawl_papers
crawl_papers(
    venue_id="ICLR.cc/2026/Conference/Submission",
    output_file="iclr2026_papers.json"
)
```
### 2. Search Papers
```python
from search import PaperSearcher
# Local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
# OpenAI model (better, requires API key)
# export OPENAI_API_KEY='your-key'
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
searcher.compute_embeddings()
# Search with example papers that you are interested in
examples = [
    {
        "title": "Your paper title",
        "abstract": "Your paper abstract..."
    }
]
results = searcher.search(examples=examples, top_k=100)
# Or search with text query
results = searcher.search(query="interesting topics", top_k=100)
searcher.display(results, n=10)
searcher.save(results, 'results.json')
```
## How It Works
1. Paper titles and abstracts are converted to embeddings
2. Embeddings are cached automatically
3. Your query is embedded using the same model
4. Cosine similarity finds the most similar papers
5. Results are ranked by similarity score
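
The ranking step is a plain cosine-similarity comparison between the query embedding and the cached paper embeddings, mirroring what `PaperSearcher.search` does internally. A minimal sketch (the array values and shapes are illustrative):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Illustrative data: 3 papers and 1 query embedded in the same 4-dim space.
paper_embeddings = np.array([
    [0.1, 0.3, 0.5, 0.1],
    [0.9, 0.1, 0.0, 0.2],
    [0.2, 0.4, 0.4, 0.1],
])
query_embedding = np.array([[0.15, 0.35, 0.45, 0.1]])

# Cosine similarity between the query and every paper, ranked descending.
similarities = cosine_similarity(query_embedding, paper_embeddings)[0]
top_indices = np.argsort(similarities)[::-1]
print(top_indices, similarities[top_indices])
```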
## Cache
Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete to recompute.
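
For reference, the cache name combines the papers file's stem, a short MD5 hash of its path, and the model type, as done by `_get_cache_file` in `search.py`:

```python
import hashlib
from pathlib import Path

papers_file = "iclr2026_papers.json"
base_name = Path(papers_file).stem
file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
print(f"cache_{base_name}_{file_hash}_local.npy")  # e.g. cache_iclr2026_papers_xxxxxxxx_local.npy
```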
## Example Output
```
================================================================================
Top 100 Results (showing 10)
================================================================================
1. [0.8456] Paper a
#12345 | foundation or frontier models, including LLMs
https://openreview.net/forum?id=xxx
2. [0.8234] Paper b
#12346 | applications to robotics, autonomy, planning
https://openreview.net/forum?id=yyy
```
## Tips
- Use 1-5 example papers for best results, or a paragraph describing the topic you are interested in
- The local model is good enough for most cases
- Use the OpenAI model for critical searches (~$1 for 18k queries)
If this is useful, please consider giving it a star~

View File

@@ -0,0 +1,66 @@
import requests
import json
import time


def fetch_submissions(venue_id, offset=0, limit=1000):
    url = "https://api2.openreview.net/notes"
    params = {
        "content.venueid": venue_id,
        "details": "replyCount,invitation",
        "limit": limit,
        "offset": offset,
        "sort": "number:desc"
    }
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, params=params, headers=headers)
    response.raise_for_status()
    return response.json()


def crawl_papers(venue_id, output_file):
    all_papers = []
    offset = 0
    limit = 1000
    print(f"Fetching papers from {venue_id}...")
    while True:
        data = fetch_submissions(venue_id, offset, limit)
        notes = data.get("notes", [])
        if not notes:
            break
        for note in notes:
            paper = {
                "id": note.get("id"),
                "number": note.get("number"),
                "title": note.get("content", {}).get("title", {}).get("value", ""),
                "authors": note.get("content", {}).get("authors", {}).get("value", []),
                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
            }
            all_papers.append(paper)
        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")
        if len(notes) < limit:
            break
        offset += limit
        time.sleep(0.5)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_papers, f, ensure_ascii=False, indent=2)
    print(f"\nTotal: {len(all_papers)} papers")
    print(f"Saved to {output_file}")
    return all_papers


if __name__ == "__main__":
    crawl_papers(
        venue_id="ICLR.cc/2026/Conference/Submission",
        output_file="iclr2026_papers.json"
    )

View File

@@ -0,0 +1,22 @@
from search import PaperSearcher
# Use local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
# Or use OpenAI (better quality)
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
searcher.compute_embeddings()
examples = [
    {
        "title": "Improving Developer Emotion Classification via LLM-Based Augmentation",
        "abstract": "Detecting developer emotion in the informative data stream of technical commit messages..."
    },
]
results = searcher.search(examples=examples, top_k=100)
searcher.display(results, n=10)
searcher.save(results, 'results.json')

View File

@@ -0,0 +1,6 @@
requests
numpy
scikit-learn
sentence-transformers
openai

View File

@@ -0,0 +1,156 @@
import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity


class PaperSearcher:
    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)
        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None
        if model_type == "openai":
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"
        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        base_name = Path(papers_file).stem
        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_file).parent / cache_name)

    def _load_cache(self):
        if os.path.exists(self.cache_file):
            try:
                self.embeddings = np.load(self.cache_file)
                if len(self.embeddings) == len(self.papers):
                    print(f"Loaded cache: {self.embeddings.shape}")
                    return True
                self.embeddings = None
            except:
                self.embeddings = None
        return False

    def _save_cache(self):
        np.save(self.cache_file, self.embeddings)
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        batch_size = 100
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend([item.embedding for item in response.data])
        return np.array(embeddings)

    def _embed_local(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def compute_embeddings(self, force=False):
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings
        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]
        if self.model_type == "openai":
            self.embeddings = self._embed_openai(texts)
        else:
            self.embeddings = self._embed_local(texts)
        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        if self.embeddings is None:
            self.compute_embeddings()
        if examples:
            texts = []
            for ex in examples:
                text = f"Title: {ex['title']}"
                if ex.get('abstract'):
                    text += f" Abstract: {ex['abstract']}"
                texts.append(text)
            if self.model_type == "openai":
                embs = self._embed_openai(texts)
            else:
                embs = self._embed_local(texts)
            query_emb = np.mean(embs, axis=0).reshape(1, -1)
        elif query:
            if self.model_type == "openai":
                query_emb = self._embed_openai(query).reshape(1, -1)
            else:
                query_emb = self._embed_local(query).reshape(1, -1)
        else:
            raise ValueError("Provide either examples or query")
        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")
        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']
            print(f"{i}. [{sim:.4f}] {paper['title']}")
            print(f"   #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f"   {paper['forum_url']}\n")

    def save(self, results, output_file):
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")

View File

@@ -10,7 +10,7 @@ from starlette.routing import Mount
from system_tools import create_system_mcp
from materialproject_mcp import create_materials_mcp
from softBV_remake import create_softbv_mcp  # previously: from softBV import create_softbv_mcp

# Create MCP instances
system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()
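
For context, the hunk above only creates the sub-server instances; they are then mounted into a single ASGI application via `starlette.routing.Mount` (already imported in this file). A minimal sketch of that composition, assuming the FastMCP instances expose an `sse_app()` ASGI adapter and using illustrative mount paths not taken from this commit:

```python
from starlette.applications import Starlette
from starlette.routing import Mount

softbv_mcp = create_softbv_mcp()

# Assumed composition: each FastMCP sub-server is exposed as an ASGI app.
app = Starlette(
    routes=[
        Mount("/system", app=system_mcp.sse_app()),
        Mount("/materials", app=materials_mcp.sse_app()),
        Mount("/softbv", app=softbv_mcp.sse_app()),
    ]
)
```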

View File

@@ -121,15 +121,26 @@ class BVPairItem(BaseModel):
    occ2: float
    bv: float = Field(description="Bond valence")

class CalBVResult(BaseModel):
    """
    Optimized structured parse result of the '--cal-bv' command.
    To keep the returned payload small, only a key summary and data previews are returned
    (this replaces the previous full bv_sums / bv_pairs / raw_output fields).
    """
    global_instability_index: float | None = Field(None, description="Global instability index (GII)")
    suggested_stability: str | None = Field(None, description="Suggested stability (e.g. 'stable')")

    # Summary information
    total_bv_sums: int = Field(0, description="Total number of bond valence sum (BVS) entries")
    total_bv_pairs: int = Field(0, description="Total number of bond valence pair entries")

    # Data preview (e.g. the first 20 entries)
    bv_sums_preview: list[BVSumItem] = Field(description="Preview of the bond valence sum list")
    bv_pairs_preview: list[BVPairItem] = Field(description="Preview of the bond valence pair list")

    # The raw output is still kept (truncated) in case the AI needs to parse it itself
    raw_output_head: str = Field(description="First 50 lines of the command's raw stdout")
    raw_output_tail: str = Field(description="Last 50 lines of the command's raw stdout")

class CubeAtomInfo(BaseModel):
    """Information about a single atom parsed from a .cube file"""
    index: int
@@ -198,6 +209,125 @@ class CalTotEnResult(BaseModel):
    total_energy_eV: float | None = Field(None, description="Computed total energy (eV)")
    screening_factor_used: float | None = Field(None, description="Screening factor (sf) used in the calculation")
    raw_output: str = Field(description="Full raw stdout of the command")

class AnalyzePathwayArgs(BaseModel):
    """Input arguments for the '--gh' (pathway analysis) command"""
    input_cif: str = Field(description="Remote CIF file path used to mark the sites")
    input_cube: str = Field(description="Remote cube file path containing the potential data")
    type: str = Field(description="Conducting ion type (e.g. 'Li')")
    os: int = Field(description="Oxidation state of the conducting ion (e.g. 1)")
    barrier_max: float | None = Field(None, description="Maximum barrier; only pathways below this value are shown")
    periodic: bool = Field(True, description="Whether the cube file is treated as periodic")

class PercolationThresholds(BaseModel):
    """Percolation thresholds parsed from the '--gh' output"""
    e_total: float | None = Field(None, alias="e", description="Threshold of the total network")
    e_1D: float | None = Field(None, description="Threshold of the 1D conduction network")
    e_2D: float | None = Field(None, description="Threshold of the 2D conduction network")
    e_3D: float | None = Field(None, description="Threshold of the 3D conduction network")
    e_loop: float | None = Field(None, description="Threshold of looping pathways")

class AnalyzePathwayResult(BaseModel):
    """Complete structured parse result of the '--gh' command"""
    command_used: str
    working_directory: str
    exit_status: int
    thresholds: PercolationThresholds | None = Field(None, description="Computed percolation thresholds")
    new_files: list[str] = Field(description="Files newly created by the run (e.g. *.gh.cif)")
    raw_output: str = Field(description="Full raw stdout of the command")

class MDDefaultArgs(BaseModel):
    """Default-parameter mode for '--md'"""
    input_cif: str = Field(description="Remote CIF file path")
    type: str = Field(description="Conducting ion type (e.g. 'Li')")
    os: int = Field(description="Oxidation state of the conducting ion (e.g. 1)")

class MDFullArgs(BaseModel):
    """Full-parameter mode for '--md'"""
    input_cif: str = Field(description="Remote CIF file path")
    type: str = Field(description="Conducting ion type")
    os: int = Field(description="Oxidation state of the conducting ion")
    sf: float = Field(description="Custom screening factor")
    temperature: float = Field(description="Temperature (K)")
    t_end: float = Field(description="Production time (ps)")
    t_equil: float = Field(description="Equilibration time (ps)")
    dt: float = Field(description="Time step (ps)")
    t_log: float = Field(description="Sampling interval (ps)")

class MDResultProperties(BaseModel):
    """Final physical properties parsed from the '--md' output"""
    displacement: float | None = None
    diffusivity: float | None = None
    mobility: float | None = None
    conductivity: float | None = None

class MDResult(BaseModel):
    """Complete structured execution result of the '--md' command"""
    command_used: str
    working_directory: str
    exit_status: int
    final_properties: MDResultProperties | None = Field(None, description="Computed final properties such as conductivity")
    log_file: str = Field(description="Path of the full log file on the remote server")
    output_head: str = Field(description="First 40 lines of the log file")
    output_tail: str = Field(description="Last 40 lines of the log file, containing the final results")
    new_files: list[str] = Field(description="Files newly created by the run")
    elapsed_seconds: int = Field(description="Total task duration in seconds")

class KMCDefaultArgs(BaseModel):
    """Default-parameter mode for '--kmc'"""
    input_cif: str = Field(description="Remote CIF file path")
    input_cube: str = Field(description="Remote cube file path")
    type: str = Field(description="Conducting ion type (e.g. 'Li')")
    os: int = Field(description="Oxidation state of the conducting ion (e.g. 1)")

class KMCFullArgs(BaseModel):
    """Full-parameter mode for '--kmc'"""
    input_cif: str = Field(description="Remote CIF file path")
    input_cube: str = Field(description="Remote cube file path")
    type: str = Field(description="Conducting ion type")
    os: int = Field(description="Oxidation state of the conducting ion")
    supercell: tuple[int, int, int] | None = Field(None, description="Supercell dimensions (s0, s1, s2)")
    sf: float | None = Field(None, description="Custom screening factor")
    temperature: float | None = Field(None, description="Temperature (K)")
    t_limit: float | None = Field(None, description="Time limit (ps)")
    step_limit: int | None = Field(None, description="Step limit")
    step_log: int | None = Field(None, description="Logging interval in steps")
    cutoff: float | None = Field(None, description="Coulomb repulsion cutoff radius")

class KMCResultVector(BaseModel):
    """Represents a scalar plus a vector component: (total, (x, y, z))"""
    total: float
    vector: tuple[float, float, float]

class KMCSiteOccupancy(BaseModel):
    """Occupancy information for a single site in the KMC results"""
    site_name: str
    occupancy: float
    multiplicity: int

class KMCResultProperties(BaseModel):
    """Final physical properties parsed from the '--kmc' output"""
    temperature: float | None = None
    displacement: KMCResultVector | None = None
    diffusivity: KMCResultVector | None = None
    mobility: KMCResultVector | None = None
    conductivity: KMCResultVector | None = None
    site_occupancy_summary: list[KMCSiteOccupancy] = Field(default_factory=list)

class KMCResult(BaseModel):
    """Complete structured execution result of the '--kmc' command"""
    command_used: str
    working_directory: str
    exit_status: int
    final_properties: KMCResultProperties | None = Field(None, description="Computed final KMC properties")
    log_file: str = Field(description="Path of the full log file on the remote server")
    output_head: str = Field(description="Head of the log file")
    output_tail: str = Field(description="Tail of the log file, containing the final results")
    new_files: list[str] = Field(description="Files newly created by the run")
    elapsed_seconds: int = Field(description="Total task duration in seconds")

# ========= 2. Helper functions =========

def shell_quote(arg: str) -> str:
@@ -342,7 +472,7 @@ async def _stat_size_safe(conn: asyncssh.SSHClientConnection, path: str) -> int
def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
    """
    Parse the raw stdout of '--cal-bv' into the optimized CalBVResult model (summary plus previews).
    """
    lines = raw_text.splitlines()
@@ -351,10 +481,9 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
    bv_sums = []
    bv_pairs = []
    # ... (GII and table parsing logic unchanged) ...
    in_sum_section = False
    in_pair_section = False

    for line in lines:
        if line.startswith("GII: Global instability index ="):
            try:
@@ -363,12 +492,10 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
                pass
        elif line.startswith("GII: suggested stability:"):
            try:
                stability = line.split(":", 1)[-1].strip().rstrip('.')
            except IndexError:
                pass

        if line.startswith("BV: name type occ"):
            in_sum_section = True
            in_pair_section = False
@@ -379,52 +506,45 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
            continue
        if line.startswith("BV:="):
            continue
        if not line.startswith("BV:"):
            in_sum_section = False
            in_pair_section = False

        # Use regular expressions for more robust parsing of variable-width columns
        if in_sum_section:
            match = re.match(r"BV:\s+(\S+)\s+(\S+)\s+([\d.]+)\s+([-\d.]+)", line)
            if match:
                try:
                    bv_sums.append(BVSumItem(
                        name=match.group(1), type=match.group(2),
                        occ=float(match.group(3)), bv_sum=float(match.group(4))
                    ))
                except (ValueError, IndexError):
                    continue
        elif in_pair_section:
            try:
                # Strip the "BV:" prefix and split
                parts = line[3:].strip().split()
                if len(parts) == 7:
                    bv_pairs.append(BVPairItem(
                        name1=parts[0], type1=parts[1], occ1=float(parts[2]),
                        name2=parts[3], type2=parts[4], occ2=float(parts[5]),
                        bv=float(parts[6])
                    ))
            except (ValueError, IndexError):
                continue

    # --- Key change: return a summary and previews instead of the full lists ---
    return CalBVResult(
        global_instability_index=gii,
        suggested_stability=stability,
        total_bv_sums=len(bv_sums),
        total_bv_pairs=len(bv_pairs),
        bv_sums_preview=bv_sums[:20],   # only the first 20 entries as a preview
        bv_pairs_preview=bv_pairs[:20], # only the first 20 entries as a preview
        raw_output_head="\n".join(lines[:50]),
        raw_output_tail="\n".join(lines[-50:])
    )
def _parse_print_cube_output(raw_text: str) -> PrintCubeResult:
    """
    Parse the raw stdout of '--print-cube' into the structured PrintCubeResult model
@@ -579,6 +699,91 @@ def _parse_cal_tot_en_output(raw_text: str) -> CalTotEnResult:
    )
def _parse_gh_output(raw_text: str) -> PercolationThresholds | None:
    """Parse the SOFTBV_GRAPH_NETWORK lines from the '--gh' stdout."""
    thresholds = {}
    # Regex matching the "e_type = value" format
    pattern = re.compile(r"SOFTBV_GRAPH_NETWORK: (e(?:_1D|_2D|_3D|_loop)?)\s*=\s*([-\d.]+)")
    for line in raw_text.splitlines():
        match = pattern.search(line)
        if match:
            key = match.group(1)
            value = float(match.group(2))
            # Keep 'e' as-is so it matches the alias in the Pydantic model
            if key == "e":
                thresholds["e"] = value
            else:
                thresholds[key] = value
    if not thresholds:
        return None
    return PercolationThresholds.model_validate(thresholds)


def _parse_md_output(raw_text: str) -> MDResultProperties:
    """Parse the final physical properties from the '--md' stdout."""
    properties = {}
    # Regex matching the "MD: key = value" format
    pattern = re.compile(r"MD: (displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)")
    for line in raw_text.splitlines():
        match = pattern.search(line)
        if match:
            key = match.group(1)
            value = float(match.group(2))
            properties[key] = value
    return MDResultProperties.model_validate(properties)


# Added in the helper-function area of softbv_mcp_refactored.py
def _parse_kmc_output(raw_text: str) -> KMCResultProperties:
    """Parse the final physical properties and site occupancies from the '--kmc' stdout."""
    properties = {"site_occupancy_summary": []}
    in_occupancy_summary = False
    # Matches lines of the form "key = total, (x, y, z)" or "key = total"
    vector_pattern = re.compile(
        r"KMC: (temperature|displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)(?:,\s*\(([-\d.eE]+),\s*([-\d.eE]+),\s*([-\d.eE]+)\))?"
    )
    # Matches the site-occupancy lines
    occupancy_pattern = re.compile(
        r"KMC:\s*\[\s*([^\]]+)\]\s+([-\d.]+)\s+\(multiplicity\s*=\s*(\d+)\)"
    )
    for line in raw_text.splitlines():
        if "KMC: summary of site-averaged occupancy:" in line:
            in_occupancy_summary = True
            continue
        vec_match = vector_pattern.search(line)
        if vec_match:
            key, total_str, x_str, y_str, z_str = vec_match.groups()
            total = float(total_str)
            if key == 'temperature':
                properties[key] = total
            elif x_str is not None:
                properties[key] = {
                    "total": total,
                    "vector": (float(x_str), float(y_str), float(z_str))
                }
            continue
        if in_occupancy_summary:
            occ_match = occupancy_pattern.search(line)
            if occ_match:
                site, occ, mult = occ_match.groups()
                properties["site_occupancy_summary"].append({
                    "site_name": site.strip(),
                    "occupancy": float(occ),
                    "multiplicity": int(mult)
                })
    return KMCResultProperties.model_validate(properties)
# ========= 3. Lifecycle management =========

@dataclass
@@ -865,6 +1070,8 @@ async def softbv_gen_cube(
    # In softbv_mcp_refactored.py, added immediately after softbv_gen_cube
# In softbv_mcp_refactored.py, find and replace the softbv_calculate_bv function
@mcp.tool()
async def softbv_calculate_bv(
    cif_path: str,
@@ -872,15 +1079,8 @@ async def softbv_calculate_bv(
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> CalBVResult:
    """
    Run 'softBV.x --cal-bv' to compute bond valence sums.
    Returns a structured object with a key summary and data previews to avoid oversized responses.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
@@ -889,7 +1089,6 @@ async def softbv_calculate_bv(
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    input_abs_path = cif_path
    if not input_abs_path.startswith("/"):
        input_abs_path = posixpath.normpath(posixpath.join(workdir, cif_path))
@@ -898,27 +1097,23 @@ async def softbv_calculate_bv(
    await ctx.info(f"Running command: {cmd} (working directory: {workdir})")
    proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)

    if proc.exit_status != 0:
        await ctx.warning(f"Command failed with exit code: {proc.exit_status}")
        # --- Key change: on failure, return an empty structure that matches the new model ---
        raw_error = f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
        return CalBVResult(
            bv_sums_preview=[],
            bv_pairs_preview=[],
            raw_output_head=raw_error[:1000],  # truncated in case the error message is also very long
            raw_output_tail=""
        )

    # Parse the successful output and return structured data with the summary and previews
    parsed_result = _parse_cal_bv_output(proc.stdout)
    return parsed_result
# In softbv_mcp_refactored.py, added immediately after softbv_calculate_bv
@mcp.tool()
async def softbv_print_cube(
    cube_path: str,
@@ -1144,6 +1339,310 @@ async def softbv_calculate_total_energy(
    # Parse and return the result
    return _parse_cal_tot_en_output(proc.stdout)
@mcp.tool()
async def softbv_analyze_pathway(
    args: AnalyzePathwayArgs,
    cwd: str | None = None,
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> AnalyzePathwayResult:
    """
    Run 'softBV.x --gh' to analyze conduction pathways and return the percolation thresholds.

    Args:
        args (AnalyzePathwayArgs): Structured object containing all input parameters.
        cwd (str | None): Remote working directory.
        ctx: MCP context, injected automatically by the framework.

    Returns:
        AnalyzePathwayResult: Structured result with the percolation thresholds, newly created files, and the raw output.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
    app_ctx = ctx.request_context.lifespan_context
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    # --- 1. Build the command ---
    def get_abs(path: str) -> str:
        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))

    cmd_parts = [
        shell_quote(app_ctx.bin_path),
        shell_quote("--gh"),
        shell_quote(get_abs(args.input_cif)),
        shell_quote(get_abs(args.input_cube)),
        shell_quote(args.type),
        shell_quote(str(args.os)),
    ]
    if args.barrier_max is not None:
        cmd_parts.append(shell_quote(str(args.barrier_max)))
    periodic_flag = "t" if args.periodic else "f"
    cmd_parts.append(shell_quote(periodic_flag))
    cmd = " ".join(cmd_parts)

    await ctx.info(f"Running pathway analysis: {cmd}")

    # --- 2. Execute and collect results ---
    files_before = await _listdir_safe(conn, workdir)
    proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)
    files_after = await _listdir_safe(conn, workdir)
    new_files = sorted(set(files_after) - set(files_before))

    if proc.exit_status != 0:
        await ctx.warning(f"'--gh' command failed with exit code: {proc.exit_status}")
        return AnalyzePathwayResult(
            command_used=cmd,
            working_directory=workdir,
            exit_status=proc.exit_status,
            thresholds=None,
            new_files=new_files,
            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
        )

    # --- 3. Parse and return the structured result ---
    parsed_thresholds = _parse_gh_output(proc.stdout)
    if not parsed_thresholds:
        await ctx.info("No percolation threshold information found in the output.")
    return AnalyzePathwayResult(
        command_used=cmd,
        working_directory=workdir,
        exit_status=proc.exit_status,
        thresholds=parsed_thresholds,
        new_files=new_files,
        raw_output=proc.stdout
    )
# Added in the tools area of softbv_mcp_refactored.py
@mcp.tool(name="softbv_run_md")
async def softbv_run_md(
    args: MDDefaultArgs | MDFullArgs,
    cwd: str | None = None,
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> MDResult:
    """
    Run 'softBV.x --md' for a molecular dynamics calculation. This is a long-running task.

    Two calling modes are supported:
    1. Default mode (MDDefaultArgs): only the CIF, ion type, and oxidation state are given.
    2. Full mode (MDFullArgs): all MD parameters are given.

    The tool reports periodic heartbeats and, once the task finishes, returns a structured
    result with the key quantities such as the conductivity.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
    app_ctx = ctx.request_context.lifespan_context
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    # --- 1. Build the command ---
    input_abs_path = args.input_cif
    if not input_abs_path.startswith("/"):
        input_abs_path = posixpath.normpath(posixpath.join(workdir, args.input_cif))
    cmd_parts = [
        shell_quote(app_ctx.bin_path),
        shell_quote("--md"),
        shell_quote(input_abs_path),
        shell_quote(args.type),
        shell_quote(str(args.os)),
    ]
    if isinstance(args, MDFullArgs):
        cmd_parts.extend([
            shell_quote(str(args.sf)),
            shell_quote(str(args.temperature)),
            shell_quote(str(args.t_end)),
            shell_quote(str(args.t_equil)),
            shell_quote(str(args.dt)),
            shell_quote(str(args.t_log)),
        ])

    # --- 2. Prepare the long-running task ---
    import time
    log_name = f"softbv_md_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)
    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"
    await ctx.info(f"Starting long-running MD task: {final_cmd} (working directory: {workdir})")

    files_before = await _listdir_safe(conn, workdir)
    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
    start_time = time.monotonic()

    # --- 3. Heartbeat loop to avoid timeouts ---
    try:
        while proc.exit_status is None:
            await asyncio.sleep(10)  # report every 10 seconds
            elapsed = int(time.monotonic() - start_time)
            log_size = await _stat_size_safe(conn, log_path)
            # Use ctx.report_progress as the heartbeat mechanism
            await ctx.report_progress(
                progress=min(elapsed / (30 * 60), 0.99),  # approximate progress against a 30-minute baseline
                message=f"MD calculation in progress... elapsed {elapsed} s, log size: {log_size or 0} bytes."
            )
    finally:
        await proc.wait()  # make sure the process has finished

    # --- 4. Collect and return the results ---
    elapsed_seconds = int(time.monotonic() - start_time)
    files_after = await _listdir_safe(conn, workdir)
    new_files = sorted(set(files_after) - set(files_before))

    # Read the head and tail of the log file
    log_head, log_tail, full_log = "", "", ""
    try:
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
                lines = await f.readlines()
                full_log = "".join(lines)
                log_head = "".join(lines[:40])
                log_tail = "".join(lines[-40:])
    except Exception as e:
        await ctx.warning(f"Failed to read log file '{log_path}': {e}")

    # Parse the final properties from the log
    final_properties = _parse_md_output(full_log)

    result = MDResult(
        command_used=final_cmd,
        working_directory=workdir,
        exit_status=proc.exit_status,
        final_properties=final_properties,
        log_file=log_path,
        output_head=log_head,
        output_tail=log_tail,
        new_files=new_files,
        elapsed_seconds=elapsed_seconds,
    )
    if result.exit_status == 0:
        await ctx.info(f"MD task completed successfully in {elapsed_seconds} s.")
    else:
        await ctx.error(f"MD task failed with exit code {result.exit_status}. Please check the log: {log_path}")
    return result
# Added in the tools area of softbv_mcp_refactored.py
@mcp.tool(name="softbv_run_kmc")
async def softbv_run_kmc(
    args: KMCDefaultArgs | KMCFullArgs,
    cwd: str | None = None,
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> KMCResult:
    """
    Run 'softBV.x --kmc' for a kinetic Monte Carlo calculation. This is a long-running task.
    Two calling modes are supported: default mode and full-parameter mode.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
    app_ctx = ctx.request_context.lifespan_context
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    # --- 1. Build the command ---
    def get_abs(path: str) -> str:
        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))

    cmd_parts = [
        shell_quote(app_ctx.bin_path),
        shell_quote("--kmc"),
        shell_quote(get_abs(args.input_cif)),
        shell_quote(get_abs(args.input_cube)),
        shell_quote(args.type),
        shell_quote(str(args.os)),
    ]
    if isinstance(args, KMCFullArgs):
        if args.supercell:
            cmd_parts.extend(map(lambda x: shell_quote(str(x)), args.supercell))
        # Handle all optional float/int parameters in one loop
        for param in ['sf', 'temperature', 't_limit', 'step_limit', 'step_log', 'cutoff']:
            value = getattr(args, param, None)
            if value is not None:
                cmd_parts.append(shell_quote(str(value)))

    # --- 2. Prepare the long-running task (same pattern as gen-cube and md) ---
    import time
    log_name = f"softbv_kmc_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)
    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"
    await ctx.info(f"Starting long-running KMC task: {final_cmd} (working directory: {workdir})")

    files_before = await _listdir_safe(conn, workdir)
    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
    start_time = time.monotonic()

    # --- 3. Heartbeat loop (same pattern as gen-cube and md) ---
    try:
        while proc.exit_status is None:
            await asyncio.sleep(10)
            elapsed = int(time.monotonic() - start_time)
            log_size = await _stat_size_safe(conn, log_path)
            await ctx.report_progress(
                progress=min(elapsed / (30 * 60), 0.99),
                message=f"KMC calculation in progress... elapsed {elapsed} s, log size: {log_size or 0} bytes."
            )
    finally:
        await proc.wait()

    # --- 4. Collect and return the results ---
    elapsed_seconds = int(time.monotonic() - start_time)
    files_after = await _listdir_safe(conn, workdir)
    new_files = sorted(set(files_after) - set(files_before))

    log_head, log_tail, full_log = "", "", ""
    try:
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
                lines = await f.readlines()
                full_log = "".join(lines)
                log_head = "".join(lines[:40])
                log_tail = "".join(lines[-40:])
    except Exception as e:
        await ctx.warning(f"Failed to read log file '{log_path}': {e}")

    # Parse the final properties from the full log
    final_properties = _parse_kmc_output(full_log)

    result = KMCResult(
        command_used=final_cmd,
        working_directory=workdir,
        exit_status=proc.exit_status,
        final_properties=final_properties,
        log_file=log_path,
        output_head=log_head,
        output_tail=log_tail,
        new_files=new_files,
        elapsed_seconds=elapsed_seconds,
    )
    if result.exit_status == 0:
        await ctx.info(f"KMC task completed successfully in {elapsed_seconds} s.")
    else:
        await ctx.error(f"KMC task failed with exit code {result.exit_status}. Please check the log: {log_path}")
    return result
# ========= 6. Factory function and main entry point =========

def create_softbv_mcp() -> FastMCP: