softBV_mcp refactor v2
Embedding copy

mcp/SearchPaperByEmbedding/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
```
MIT License

Copyright (c) 2025 gyj155

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

mcp/SearchPaperByEmbedding/README.md (new file, 96 lines)
@@ -0,0 +1,96 @@
# Paper Semantic Search

Find similar papers using semantic search. Supports both local models (free) and the OpenAI API (better quality).

## Features

- Fetch papers from OpenReview (e.g., ICLR 2026 submissions)
- Semantic search with example papers or text queries
- Automatic embedding caching
- Embedding model support: open-source (e.g., all-MiniLM-L6-v2) or OpenAI

## Quick Start

```bash
pip install -r requirements.txt
```

### 1. Prepare Papers

```python
from crawl import crawl_papers

crawl_papers(
    venue_id="ICLR.cc/2026/Conference/Submission",
    output_file="iclr2026_papers.json"
)
```

### 2. Search Papers

```python
from search import PaperSearcher

# Local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')

# OpenAI model (better quality, requires an API key)
# export OPENAI_API_KEY='your-key'
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')

searcher.compute_embeddings()

# Search with example papers you are interested in
examples = [
    {
        "title": "Your paper title",
        "abstract": "Your paper abstract..."
    }
]

results = searcher.search(examples=examples, top_k=100)

# Or search with a text query
results = searcher.search(query="interesting topics", top_k=100)

searcher.display(results, n=10)
searcher.save(results, 'results.json')
```

## How It Works

1. Paper titles and abstracts are converted to embeddings
2. Embeddings are cached automatically
3. Your query is embedded using the same model
4. Cosine similarity finds the most similar papers
5. Results are ranked by similarity score
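
The retrieval core behind steps 3-5 is a single cosine-similarity pass over the cached embedding matrix. A minimal self-contained sketch of that ranking step (random arrays stand in for real embeddings; the variable names are illustrative, not from this package):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Stand-ins: in the real flow these come from the embedding model and the cache
embeddings = np.random.rand(500, 384)  # (n_papers, dim) cached paper embeddings
query_emb = np.random.rand(1, 384)     # (1, dim) embedded query or averaged examples

scores = cosine_similarity(query_emb, embeddings)[0]  # one similarity per paper
top_indices = np.argsort(scores)[::-1][:10]           # highest similarity first
for rank, idx in enumerate(top_indices, 1):
    print(f"{rank}. paper #{idx}  similarity={scores[idx]:.4f}")
```

When several example papers are given, their embeddings are averaged into one query vector before this step, which is why a few focused examples tend to work better than a large mixed set.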

## Cache

Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete the file to recompute.
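
For reference, the cache name combines the input file's stem, the first 8 hex digits of an MD5 hash of the path string, and the model type; a sketch mirroring `_get_cache_file` in `search.py` (the printed filename is schematic):

```python
import hashlib
from pathlib import Path

def cache_path(papers_file: str, model_type: str) -> str:
    # Mirrors PaperSearcher._get_cache_file: stem + short path hash + model type
    base_name = Path(papers_file).stem
    file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
    return str(Path(papers_file).parent / f"cache_{base_name}_{file_hash}_{model_type}.npy")

print(cache_path("iclr2026_papers.json", "local"))
# cache_iclr2026_papers_<8-char-hash>_local.npy
```

Note that the hash covers the file *path*, not its contents, so a changed paper list under the same path is only invalidated by the length check in `_load_cache`.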

## Example Output

```
================================================================================
Top 100 Results (showing 10)
================================================================================

1. [0.8456] Paper a
   #12345 | foundation or frontier models, including LLMs
   https://openreview.net/forum?id=xxx

2. [0.8234] Paper b
   #12346 | applications to robotics, autonomy, planning
   https://openreview.net/forum?id=yyy
```

## Tips

- Use 1-5 example papers, or a paragraph describing your topic of interest, for best results
- The local model is good enough for most cases
- Use the OpenAI model for critical searches (~$1 for 18k queries)

If you find this useful, please consider giving it a star!

mcp/SearchPaperByEmbedding/crawl.py (new file, 66 lines)
@@ -0,0 +1,66 @@
```python
import requests
import json
import time


def fetch_submissions(venue_id, offset=0, limit=1000):
    """Fetch one page of submissions from the OpenReview API v2."""
    url = "https://api2.openreview.net/notes"
    params = {
        "content.venueid": venue_id,
        "details": "replyCount,invitation",
        "limit": limit,
        "offset": offset,
        "sort": "number:desc"
    }
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, params=params, headers=headers)
    response.raise_for_status()
    return response.json()


def crawl_papers(venue_id, output_file):
    """Page through all submissions for a venue and save them as JSON."""
    all_papers = []
    offset = 0
    limit = 1000

    print(f"Fetching papers from {venue_id}...")

    while True:
        data = fetch_submissions(venue_id, offset, limit)
        notes = data.get("notes", [])

        if not notes:
            break

        for note in notes:
            paper = {
                "id": note.get("id"),
                "number": note.get("number"),
                "title": note.get("content", {}).get("title", {}).get("value", ""),
                "authors": note.get("content", {}).get("authors", {}).get("value", []),
                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
            }
            all_papers.append(paper)

        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")

        if len(notes) < limit:
            break

        offset += limit
        time.sleep(0.5)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_papers, f, ensure_ascii=False, indent=2)

    print(f"\nTotal: {len(all_papers)} papers")
    print(f"Saved to {output_file}")
    return all_papers


if __name__ == "__main__":
    crawl_papers(
        venue_id="ICLR.cc/2026/Conference/Submission",
        output_file="iclr2026_papers.json"
    )
```

mcp/SearchPaperByEmbedding/demo.py (new file, 22 lines)
@@ -0,0 +1,22 @@
```python
from search import PaperSearcher

# Use local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')

# Or use OpenAI (better quality)
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')

searcher.compute_embeddings()

examples = [
    {
        "title": "Improving Developer Emotion Classification via LLM-Based Augmentation",
        "abstract": "Detecting developer emotion in the informative data stream of technical commit messages..."
    },
]

results = searcher.search(examples=examples, top_k=100)

searcher.display(results, n=10)
searcher.save(results, 'results.json')
```

mcp/SearchPaperByEmbedding/requirements.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
```
requests
numpy
scikit-learn
sentence-transformers
openai
```

mcp/SearchPaperByEmbedding/search.py (new file, 156 lines)
@@ -0,0 +1,156 @@
```python
import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity


class PaperSearcher:
    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)

        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None

        if model_type == "openai":
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"

        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        base_name = Path(papers_file).stem
        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_file).parent / cache_name)

    def _load_cache(self):
        if os.path.exists(self.cache_file):
            try:
                self.embeddings = np.load(self.cache_file)
                if len(self.embeddings) == len(self.papers):
                    print(f"Loaded cache: {self.embeddings.shape}")
                    return True
                self.embeddings = None
            except Exception:
                self.embeddings = None
        return False

    def _save_cache(self):
        np.save(self.cache_file, self.embeddings)
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        batch_size = 100

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend([item.embedding for item in response.data])

        return np.array(embeddings)

    def _embed_local(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def compute_embeddings(self, force=False):
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings

        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]

        if self.model_type == "openai":
            self.embeddings = self._embed_openai(texts)
        else:
            self.embeddings = self._embed_local(texts)

        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        if self.embeddings is None:
            self.compute_embeddings()

        if examples:
            texts = []
            for ex in examples:
                text = f"Title: {ex['title']}"
                if ex.get('abstract'):
                    text += f" Abstract: {ex['abstract']}"
                texts.append(text)

            if self.model_type == "openai":
                embs = self._embed_openai(texts)
            else:
                embs = self._embed_local(texts)

            # Average the example embeddings into a single query vector
            query_emb = np.mean(embs, axis=0).reshape(1, -1)

        elif query:
            if self.model_type == "openai":
                query_emb = self._embed_openai(query).reshape(1, -1)
            else:
                query_emb = self._embed_local(query).reshape(1, -1)
        else:
            raise ValueError("Provide either examples or query")

        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")

        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']

            print(f"{i}. [{sim:.4f}] {paper['title']}")
            print(f"   #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f"   {paper['forum_url']}\n")

    def save(self, results, output_file):
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")
```

```diff
@@ -10,7 +10,7 @@ from starlette.routing import Mount
 
 from system_tools import create_system_mcp
 from materialproject_mcp import create_materials_mcp
-from softBV import create_softbv_mcp
+from softBV_remake import create_softbv_mcp
 # Create the MCP instances
 system_mcp = create_system_mcp()
 materials_mcp = create_materials_mcp()
```

```diff
@@ -121,15 +121,26 @@ class BVPairItem(BaseModel):
     occ2: float
     bv: float = Field(description="bond valence")
 
 
 class CalBVResult(BaseModel):
-    """Complete structured parse result of the '--cal-bv' command"""
+    """
+    Optimized structured parse result of the '--cal-bv' command.
+    To avoid oversized return payloads, only a key summary and data previews are returned.
+    """
     global_instability_index: float | None = Field(None, description="Global instability index (GII)")
     suggested_stability: str | None = Field(None, description="Suggested stability (e.g. 'stable')")
-    bv_sums: list[BVSumItem] = Field(description="List of bond valence sums for each ion")
-    bv_pairs: list[BVPairItem] = Field(description="List of bond valences between ion pairs")
-    raw_output: str = Field(description="Full raw stdout of the command")
+
+    # Summary information
+    total_bv_sums: int = Field(0, description="Total number of bond valence sum (BVS) entries")
+    total_bv_pairs: int = Field(0, description="Total number of bond valence pair (BV pairs) entries")
+
+    # Data previews (e.g. the first 20 entries)
+    bv_sums_preview: list[BVSumItem] = Field(description="Preview of the bond valence sum list")
+    bv_pairs_preview: list[BVPairItem] = Field(description="Preview of the bond valence pair list")
+
+    # The raw output is still kept in case the AI needs to parse it itself
+    raw_output_head: str = Field(description="First 50 lines of the command's raw stdout")
+    raw_output_tail: str = Field(description="Last 50 lines of the command's raw stdout")
 
 
 class CubeAtomInfo(BaseModel):
     """Information about a single atom parsed from a .cube file"""
     index: int
@@ -198,6 +209,125 @@ class CalTotEnResult(BaseModel):
     total_energy_eV: float | None = Field(None, description="Computed total energy (eV)")
     screening_factor_used: float | None = Field(None, description="Screening factor (sf) used in the calculation")
     raw_output: str = Field(description="Full raw stdout of the command")
+
+
+class AnalyzePathwayArgs(BaseModel):
+    """Input arguments for the '--gh' (pathway analysis) command"""
+    input_cif: str = Field(description="Remote CIF file path used to label the sites")
+    input_cube: str = Field(description="Remote cube file path containing the potential data")
+    type: str = Field(description="Conducting ion type (e.g. 'Li')")
+    os: int = Field(description="Conducting ion oxidation state (e.g. 1)")
+    barrier_max: float | None = Field(None, description="Maximum barrier; only pathways below this value are shown")
+    periodic: bool = Field(True, description="Whether to treat the cube file as periodic")
+
+
+class PercolationThresholds(BaseModel):
+    """Percolation thresholds parsed from the '--gh' command output"""
+    e_total: float | None = Field(None, alias="e", description="Threshold of the total network")
+    e_1D: float | None = Field(None, description="Threshold of the 1D conduction network")
+    e_2D: float | None = Field(None, description="Threshold of the 2D conduction network")
+    e_3D: float | None = Field(None, description="Threshold of the 3D conduction network")
+    e_loop: float | None = Field(None, description="Threshold of looping pathways")
+
+
+class AnalyzePathwayResult(BaseModel):
+    """Complete structured parse result of the '--gh' command"""
+    command_used: str
+    working_directory: str
+    exit_status: int
+    thresholds: PercolationThresholds | None = Field(None, description="Computed percolation thresholds")
+    new_files: list[str] = Field(description="Files newly created by the run (e.g. *.gh.cif)")
+    raw_output: str = Field(description="Full raw stdout of the command")
+
+
+class MDDefaultArgs(BaseModel):
+    """Default-argument mode for '--md'"""
+    input_cif: str = Field(description="Remote CIF file path")
+    type: str = Field(description="Conducting ion type (e.g. 'Li')")
+    os: int = Field(description="Conducting ion oxidation state (e.g. 1)")
+
+
+class MDFullArgs(BaseModel):
+    """Full-argument mode for '--md'"""
+    input_cif: str = Field(description="Remote CIF file path")
+    type: str = Field(description="Conducting ion type")
+    os: int = Field(description="Conducting ion oxidation state")
+    sf: float = Field(description="Custom screening factor")
+    temperature: float = Field(description="Temperature (K)")
+    t_end: float = Field(description="Production time (ps)")
+    t_equil: float = Field(description="Equilibration time (ps)")
+    dt: float = Field(description="Time step (ps)")
+    t_log: float = Field(description="Sampling interval (ps)")
+
+
+class MDResultProperties(BaseModel):
+    """Final physical properties parsed from the '--md' output"""
+    displacement: float | None = None
+    diffusivity: float | None = None
+    mobility: float | None = None
+    conductivity: float | None = None
+
+
+class MDResult(BaseModel):
+    """Complete structured execution result of the '--md' command"""
+    command_used: str
+    working_directory: str
+    exit_status: int
+    final_properties: MDResultProperties | None = Field(None, description="Final computed properties such as conductivity")
+    log_file: str = Field(description="Full path of the log file on the remote server")
+    output_head: str = Field(description="First 40 lines of the log file")
+    output_tail: str = Field(description="Last 40 lines of the log file, containing the final results")
+    new_files: list[str] = Field(description="Files newly created by the run")
+    elapsed_seconds: int = Field(description="Total task time (seconds)")
+
+
+class KMCDefaultArgs(BaseModel):
+    """Default-argument mode for '--kmc'"""
+    input_cif: str = Field(description="Remote CIF file path")
+    input_cube: str = Field(description="Remote cube file path")
+    type: str = Field(description="Conducting ion type (e.g. 'Li')")
+    os: int = Field(description="Conducting ion oxidation state (e.g. 1)")
+
+
+class KMCFullArgs(BaseModel):
+    """Full-argument mode for '--kmc'"""
+    input_cif: str = Field(description="Remote CIF file path")
+    input_cube: str = Field(description="Remote cube file path")
+    type: str = Field(description="Conducting ion type")
+    os: int = Field(description="Conducting ion oxidation state")
+    supercell: tuple[int, int, int] | None = Field(None, description="Supercell dimensions (s0, s1, s2)")
+    sf: float | None = Field(None, description="Custom screening factor")
+    temperature: float | None = Field(None, description="Temperature (K)")
+    t_limit: float | None = Field(None, description="Time limit (ps)")
+    step_limit: int | None = Field(None, description="Step limit")
+    step_log: int | None = Field(None, description="Logging interval in steps")
+    cutoff: float | None = Field(None, description="Coulomb repulsion cutoff radius")
+
+
+class KMCResultVector(BaseModel):
+    """A scalar total together with a vector component (total, (x, y, z))"""
+    total: float
+    vector: tuple[float, float, float]
+
+
+class KMCSiteOccupancy(BaseModel):
+    """A single site-occupancy entry in the KMC results"""
+    site_name: str
+    occupancy: float
+    multiplicity: int
+
+
+class KMCResultProperties(BaseModel):
+    """Final physical properties parsed from the '--kmc' output"""
+    temperature: float | None = None
+    displacement: KMCResultVector | None = None
+    diffusivity: KMCResultVector | None = None
+    mobility: KMCResultVector | None = None
+    conductivity: KMCResultVector | None = None
+    site_occupancy_summary: list[KMCSiteOccupancy] = Field(default_factory=list)
+
+
+class KMCResult(BaseModel):
+    """Complete structured execution result of the '--kmc' command"""
+    command_used: str
+    working_directory: str
+    exit_status: int
+    final_properties: KMCResultProperties | None = Field(None, description="Final computed KMC properties")
+    log_file: str = Field(description="Full path of the log file on the remote server")
+    output_head: str = Field(description="Head of the log file")
+    output_tail: str = Field(description="Tail of the log file, containing the final results")
+    new_files: list[str] = Field(description="Files newly created by the run")
+    elapsed_seconds: int = Field(description="Total task time (seconds)")
 # ========= 2. Helper functions =========
 
 def shell_quote(arg: str) -> str:
```
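
To make the two argument modes concrete, here is a minimal sketch instantiating the new Pydantic models; every field value below is hypothetical:

```python
# Default mode: only the structure and the conducting ion are specified
md_quick = MDDefaultArgs(input_cif="LiFePO4.cif", type="Li", os=1)

# Full mode: every MD parameter is given explicitly
md_full = MDFullArgs(
    input_cif="LiFePO4.cif", type="Li", os=1,
    sf=0.6, temperature=300.0,
    t_end=100.0, t_equil=10.0, dt=0.002, t_log=0.1,
)

# PercolationThresholds is populated via the alias: the parser passes "e"
# for the total-network threshold, which lands on the e_total field
th = PercolationThresholds.model_validate({"e": 0.45, "e_1D": 0.52})
print(th.e_total, th.e_1D)  # 0.45 0.52
```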

```diff
@@ -342,7 +472,7 @@ async def _stat_size_safe(conn: asyncssh.SSHClientConnection, path: str) -> int
 
 def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
     """
-    Parse the raw '--cal-bv' stdout into a structured CalBVResult model.
+    Parse the raw '--cal-bv' stdout into the optimized CalBVResult model with summary and previews.
     """
     lines = raw_text.splitlines()
 
@@ -351,10 +481,9 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
     bv_sums = []
     bv_pairs = []
 
+    # ...(the GII and table parsing logic is unchanged)...
     in_sum_section = False
     in_pair_section = False
-
-    # Parse GII and stability
     for line in lines:
         if line.startswith("GII: Global instability index ="):
             try:
@@ -363,12 +492,10 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
                 pass
         elif line.startswith("GII: suggested stability:"):
             try:
-                stability = line.split(":", 1)[-1].strip()
+                stability = line.split(":", 1)[-1].strip().rstrip('.')
             except IndexError:
                 pass
 
-    # Parse the tables
-    for line in lines:
         if line.startswith("BV: name type occ"):
             in_sum_section = True
             in_pair_section = False
@@ -379,52 +506,45 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
             continue
         if line.startswith("BV:="):
             continue
 
         if not line.startswith("BV:"):
             in_sum_section = False
             in_pair_section = False
 
-        # Use regular expressions for more robust parsing of variable-width columns
         if in_sum_section:
             match = re.match(r"BV:\s+(\S+)\s+(\S+)\s+([\d.]+)\s+([-\d.]+)", line)
             if match:
                 try:
                     bv_sums.append(BVSumItem(
-                        name=match.group(1),
-                        type=match.group(2),
-                        occ=float(match.group(3)),
-                        bv_sum=float(match.group(4))
+                        name=match.group(1), type=match.group(2),
+                        occ=float(match.group(3)), bv_sum=float(match.group(4))
                     ))
                 except (ValueError, IndexError):
                     continue
 
         elif in_pair_section:
-            # This is a more complex row; split it more safely
             try:
-                # Strip the "BV:" prefix and split
                 parts = line[3:].strip().split()
                 if len(parts) == 7:
                     bv_pairs.append(BVPairItem(
-                        name1=parts[0],
-                        type1=parts[1],
-                        occ1=float(parts[2]),
-                        name2=parts[3],
-                        type2=parts[4],
-                        occ2=float(parts[5]),
+                        name1=parts[0], type1=parts[1], occ1=float(parts[2]),
+                        name2=parts[3], type2=parts[4], occ2=float(parts[5]),
                         bv=float(parts[6])
                     ))
             except (ValueError, IndexError):
                 continue
 
+    # --- Key change: return a summary and previews instead of the full lists ---
     return CalBVResult(
         global_instability_index=gii,
         suggested_stability=stability,
-        bv_sums=bv_sums,
-        bv_pairs=bv_pairs,
-        raw_output=raw_text
+        total_bv_sums=len(bv_sums),
+        total_bv_pairs=len(bv_pairs),
+        bv_sums_preview=bv_sums[:20],   # only the first 20 entries as a preview
+        bv_pairs_preview=bv_pairs[:20], # only the first 20 entries as a preview
+        raw_output_head="\n".join(lines[:50]),
+        raw_output_tail="\n".join(lines[-50:])
    )
 
 
 def _parse_print_cube_output(raw_text: str) -> PrintCubeResult:
     """
     Parse the raw '--print-cube' stdout into a structured PrintCubeResult model.
```
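
The BVS table rows are matched column by column by the regex above; a quick self-contained check against a sample line (the line itself is invented to fit the pattern, not copied from real softBV output):

```python
import re

line = "BV: Li1   Li   1.000   0.983"  # hypothetical '--cal-bv' table row
match = re.match(r"BV:\s+(\S+)\s+(\S+)\s+([\d.]+)\s+([-\d.]+)", line)
if match:
    name, ion_type = match.group(1), match.group(2)
    occ, bv_sum = float(match.group(3)), float(match.group(4))
    print(name, ion_type, occ, bv_sum)  # Li1 Li 1.0 0.983
```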

```diff
@@ -579,6 +699,91 @@ def _parse_cal_tot_en_output(raw_text: str) -> CalTotEnResult:
     )
 
 
+def _parse_gh_output(raw_text: str) -> PercolationThresholds | None:
+    """Parse the SOFTBV_GRAPH_NETWORK lines from the '--gh' stdout."""
+    thresholds = {}
+    # Regex matching the "e_type = value" format
+    pattern = re.compile(r"SOFTBV_GRAPH_NETWORK: (e(?:_1D|_2D|_3D|_loop)?)\s*=\s*([-\d.]+)")
+
+    for line in raw_text.splitlines():
+        match = pattern.search(line)
+        if match:
+            key = match.group(1)
+            value = float(match.group(2))
+
+            # Keep 'e' as-is so it maps onto the alias in the Pydantic model
+            if key == "e":
+                thresholds["e"] = value
+            else:
+                thresholds[key] = value
+
+    if not thresholds:
+        return None
+
+    return PercolationThresholds.model_validate(thresholds)
+
+
+def _parse_md_output(raw_text: str) -> MDResultProperties:
+    """Parse the final physical properties from the '--md' stdout."""
+    properties = {}
+    # Regex matching the "MD: key = value" format
+    pattern = re.compile(r"MD: (displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)")
+
+    for line in raw_text.splitlines():
+        match = pattern.search(line)
+        if match:
+            key = match.group(1)
+            value = float(match.group(2))
+            properties[key] = value
+
+    return MDResultProperties.model_validate(properties)
+
+
+# Add this in the helper-function area of softbv_mcp_refactored.py
+
+def _parse_kmc_output(raw_text: str) -> KMCResultProperties:
+    """Parse the final physical properties and site occupancies from the '--kmc' stdout."""
+    properties = {"site_occupancy_summary": []}
+    in_occupancy_summary = False
+
+    # Matches lines of the form "key = total, (x, y, z)" or "key = total"
+    vector_pattern = re.compile(
+        r"KMC: (temperature|displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)(?:,\s*\(([-\d.eE]+),\s*([-\d.eE]+),\s*([-\d.eE]+)\))?"
+    )
+    # Matches the site-occupancy lines
+    occupancy_pattern = re.compile(
+        r"KMC:\s*\[\s*([^\]]+)\]\s+([-\d.]+)\s+\(multiplicity\s*=\s*(\d+)\)"
+    )
+
+    for line in raw_text.splitlines():
+        if "KMC: summary of site-averaged occupancy:" in line:
+            in_occupancy_summary = True
+            continue
+
+        vec_match = vector_pattern.search(line)
+        if vec_match:
+            key, total_str, x_str, y_str, z_str = vec_match.groups()
+            total = float(total_str)
+            if key == 'temperature':
+                properties[key] = total
+            elif x_str is not None:
+                properties[key] = {
+                    "total": total,
+                    "vector": (float(x_str), float(y_str), float(z_str))
+                }
+            continue
+
+        if in_occupancy_summary:
+            occ_match = occupancy_pattern.search(line)
+            if occ_match:
+                site, occ, mult = occ_match.groups()
+                properties["site_occupancy_summary"].append({
+                    "site_name": site.strip(),
+                    "occupancy": float(occ),
+                    "multiplicity": int(mult)
+                })
+
+    return KMCResultProperties.model_validate(properties)
+
+
 # ========= 3. Lifecycle management =========
 
 @dataclass
```
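
The same pattern-check approach works for the new parsers; the sample lines below are invented to match the `SOFTBV_GRAPH_NETWORK` regex, not copied from real softBV output:

```python
import re

pattern = re.compile(r"SOFTBV_GRAPH_NETWORK: (e(?:_1D|_2D|_3D|_loop)?)\s*=\s*([-\d.]+)")
sample = (
    "SOFTBV_GRAPH_NETWORK: e = 0.45\n"
    "SOFTBV_GRAPH_NETWORK: e_1D = 0.52\n"
)

thresholds = {}
for line in sample.splitlines():
    m = pattern.search(line)
    if m:
        thresholds[m.group(1)] = float(m.group(2))

print(thresholds)  # {'e': 0.45, 'e_1D': 0.52}
```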

```diff
@@ -865,6 +1070,8 @@ async def softbv_gen_cube(
 
 # In softbv_mcp_refactored.py, add this right after softbv_gen_cube
 
+# In softbv_mcp_refactored.py, find and replace the softbv_calculate_bv function
+
 @mcp.tool()
 async def softbv_calculate_bv(
     cif_path: str,
@@ -872,15 +1079,8 @@ async def softbv_calculate_bv(
     ctx: Context[ServerSession, SoftBVContext] | None = None,
 ) -> CalBVResult:
     """
-    Run 'softBV.x --cal-bv' to compute and return the bond valence sums.
-
-    Args:
-        cif_path (str): CIF file path on the remote server (relative or absolute).
-        cwd (str | None): Remote working directory; if not given, the default sandbox directory is used.
-        ctx: MCP context, injected automatically by the framework.
-
-    Returns:
-        CalBVResult: structured object with the parsed bond valence sums, global instability index, and raw output.
+    Run 'softBV.x --cal-bv' to compute the bond valence sums.
+    Returns a structured object with a key summary and data previews, to avoid oversized content.
     """
     if ctx is None:
         raise ValueError("Context is required for this operation.")
@@ -889,7 +1089,6 @@ async def softbv_calculate_bv(
     conn = app_ctx.ssh_connection
     workdir = cwd or app_ctx.workdir
 
-    # Build the path and command
     input_abs_path = cif_path
     if not input_abs_path.startswith("/"):
         input_abs_path = posixpath.normpath(posixpath.join(workdir, cif_path))
@@ -898,27 +1097,23 @@ async def softbv_calculate_bv(
 
     await ctx.info(f"Running command: {cmd} (working directory: {workdir})")
 
-    # Run the command
     proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)
 
     if proc.exit_status != 0:
         await ctx.warning(f"Command failed, exit code: {proc.exit_status}")
-        # Even on failure, return a structure containing the error information
+        # --- Key change: on failure, return an empty structure matching the new model ---
+        raw_error = f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
         return CalBVResult(
-            global_instability_index=None,
-            suggested_stability=None,
-            bv_sums=[],
-            bv_pairs=[],
-            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
+            bv_sums_preview=[],
+            bv_pairs_preview=[],
+            raw_output_head=raw_error[:1000],  # truncated in case the error message is also very long
+            raw_output_tail=""
         )
 
-    # Parse the successful output and return structured data
+    # Parse the successful output and return structured data with a summary and previews
     parsed_result = _parse_cal_bv_output(proc.stdout)
     return parsed_result
 
 
-# In softbv_mcp_refactored.py, add this right after softbv_calculate_bv
-
 @mcp.tool()
 async def softbv_print_cube(
     cube_path: str,
```

```diff
@@ -1144,6 +1339,310 @@ async def softbv_calculate_total_energy(
     # Parse and return the result
     return _parse_cal_tot_en_output(proc.stdout)
+
+
+@mcp.tool()
+async def softbv_analyze_pathway(
+    args: AnalyzePathwayArgs,
+    cwd: str | None = None,
+    ctx: Context[ServerSession, SoftBVContext] | None = None,
+) -> AnalyzePathwayResult:
+    """
+    Run 'softBV.x --gh' to analyze conduction pathways and return percolation thresholds.
+
+    Args:
+        args (AnalyzePathwayArgs): structured object holding all input parameters.
+        cwd (str | None): remote working directory.
+        ctx: MCP context, injected automatically by the framework.
+
+    Returns:
+        AnalyzePathwayResult: structured result with the percolation thresholds, newly created files, and the raw output.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+
+    app_ctx = ctx.request_context.lifespan_context
+    conn = app_ctx.ssh_connection
+    workdir = cwd or app_ctx.workdir
+
+    # --- 1. Build the command ---
+    def get_abs(path: str) -> str:
+        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))
+
+    cmd_parts = [
+        shell_quote(app_ctx.bin_path),
+        shell_quote("--gh"),
+        shell_quote(get_abs(args.input_cif)),
+        shell_quote(get_abs(args.input_cube)),
+        shell_quote(args.type),
+        shell_quote(str(args.os)),
+    ]
+    if args.barrier_max is not None:
+        cmd_parts.append(shell_quote(str(args.barrier_max)))
+
+    periodic_flag = "t" if args.periodic else "f"
+    cmd_parts.append(shell_quote(periodic_flag))
+
+    cmd = " ".join(cmd_parts)
+
+    await ctx.info(f"Running pathway analysis: {cmd}")
+
+    # --- 2. Execute and collect results ---
+    files_before = await _listdir_safe(conn, workdir)
+    proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)
+    files_after = await _listdir_safe(conn, workdir)
+
+    new_files = sorted(set(files_after) - set(files_before))
+
+    if proc.exit_status != 0:
+        await ctx.warning(f"'--gh' command failed, exit code: {proc.exit_status}")
+        return AnalyzePathwayResult(
+            command_used=cmd,
+            working_directory=workdir,
+            exit_status=proc.exit_status,
+            thresholds=None,
+            new_files=new_files,
+            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
+        )
+
+    # --- 3. Parse and return the structured result ---
+    parsed_thresholds = _parse_gh_output(proc.stdout)
+    if not parsed_thresholds:
+        await ctx.info("No percolation threshold information found in the output.")
+
+    return AnalyzePathwayResult(
+        command_used=cmd,
+        working_directory=workdir,
+        exit_status=proc.exit_status,
+        thresholds=parsed_thresholds,
+        new_files=new_files,
+        raw_output=proc.stdout
+    )
+
+
+# Add this function in the tool area of softbv_mcp_refactored.py
+
+@mcp.tool(name="softbv_run_md")
+async def softbv_run_md(
+    args: MDDefaultArgs | MDFullArgs,
+    cwd: str | None = None,
+    ctx: Context[ServerSession, SoftBVContext] | None = None,
+) -> MDResult:
+    """
+    Run 'softBV.x --md' for a molecular dynamics calculation. This is a long-running task.
+
+    Two calling modes are supported:
+    1. Default mode (MDDefaultArgs): provide only the CIF, ion type, and oxidation state.
+    2. Full mode (MDFullArgs): provide all MD parameters.
+
+    The tool reports a periodic heartbeat and, when the task finishes, returns a
+    structured result with key quantities such as the conductivity.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+
+    app_ctx = ctx.request_context.lifespan_context
+    conn = app_ctx.ssh_connection
+    workdir = cwd or app_ctx.workdir
+
+    # --- 1. Build the command ---
+    input_abs_path = args.input_cif
+    if not input_abs_path.startswith("/"):
+        input_abs_path = posixpath.normpath(posixpath.join(workdir, args.input_cif))
+
+    cmd_parts = [
+        shell_quote(app_ctx.bin_path),
+        shell_quote("--md"),
+        shell_quote(input_abs_path),
+        shell_quote(args.type),
+        shell_quote(str(args.os)),
+    ]
+
+    if isinstance(args, MDFullArgs):
+        cmd_parts.extend([
+            shell_quote(str(args.sf)),
+            shell_quote(str(args.temperature)),
+            shell_quote(str(args.t_end)),
+            shell_quote(str(args.t_equil)),
+            shell_quote(str(args.dt)),
+            shell_quote(str(args.t_log)),
+        ])
+
+    # --- 2. Prepare the long-running task ---
+    import time
+    log_name = f"softbv_md_{int(time.time())}.log"
+    log_path = posixpath.join(workdir, log_name)
+
+    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"
+
+    await ctx.info(f"Starting long MD task: {final_cmd} (working directory: {workdir})")
+
+    files_before = await _listdir_safe(conn, workdir)
+    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
+    start_time = time.monotonic()
+
+    # --- 3. Heartbeat loop to avoid timeouts ---
+    try:
+        while proc.exit_status is None:
+            await asyncio.sleep(10)  # report every 10 seconds
+
+            elapsed = int(time.monotonic() - start_time)
+            log_size = await _stat_size_safe(conn, log_path)
+
+            # Use ctx.report_progress as the heartbeat mechanism [1]
+            await ctx.report_progress(
+                progress=min(elapsed / (30 * 60), 0.99),  # rough progress against a 30-minute baseline
+                message=f"MD calculation running... {elapsed} s elapsed, log size: {log_size or 0} bytes."
+            )
+    finally:
+        await proc.wait()  # make sure the process has finished
+
+    # --- 4. Collect and return the results ---
+    elapsed_seconds = int(time.monotonic() - start_time)
+    files_after = await _listdir_safe(conn, workdir)
+    new_files = sorted(set(files_after) - set(files_before))
+
+    # Read the head and tail of the log file
+    log_head, log_tail, full_log = "", "", ""
+    try:
+        async with conn.start_sftp_client() as sftp:
+            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
+                lines = await f.readlines()
+                full_log = "".join(lines)
+                log_head = "".join(lines[:40])
+                log_tail = "".join(lines[-40:])
+    except Exception as e:
+        await ctx.warning(f"Failed to read log file '{log_path}': {e}")
+
+    # Parse the final properties from the log
+    final_properties = _parse_md_output(full_log)
+
+    result = MDResult(
+        command_used=final_cmd,
+        working_directory=workdir,
+        exit_status=proc.exit_status,
+        final_properties=final_properties,
+        log_file=log_path,
+        output_head=log_head,
+        output_tail=log_tail,
+        new_files=new_files,
+        elapsed_seconds=elapsed_seconds,
+    )
+
+    if result.exit_status == 0:
+        await ctx.info(f"MD task completed successfully in {elapsed_seconds} s.")
+    else:
+        await ctx.error(f"MD task failed, exit code: {result.exit_status}. Check the log: {log_path}")
+
+    return result
+
+
+# Add this function in the tool area of softbv_mcp_refactored.py
+
+@mcp.tool(name="softbv_run_kmc")
+async def softbv_run_kmc(
+    args: KMCDefaultArgs | KMCFullArgs,
+    cwd: str | None = None,
+    ctx: Context[ServerSession, SoftBVContext] | None = None,
+) -> KMCResult:
+    """
+    Run 'softBV.x --kmc' for a kinetic Monte Carlo calculation. This is a long-running task.
+
+    Two calling modes are supported: default mode and full-argument mode.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+
+    app_ctx = ctx.request_context.lifespan_context
+    conn = app_ctx.ssh_connection
+    workdir = cwd or app_ctx.workdir
+
+    # --- 1. Build the command ---
+    def get_abs(path: str) -> str:
+        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))
+
+    cmd_parts = [
+        shell_quote(app_ctx.bin_path),
+        shell_quote("--kmc"),
+        shell_quote(get_abs(args.input_cif)),
+        shell_quote(get_abs(args.input_cube)),
+        shell_quote(args.type),
+        shell_quote(str(args.os)),
+    ]
+
+    if isinstance(args, KMCFullArgs):
+        if args.supercell:
+            cmd_parts.extend(map(lambda x: shell_quote(str(x)), args.supercell))
+
+        # Handle all optional float/int parameters in one loop
+        for param in ['sf', 'temperature', 't_limit', 'step_limit', 'step_log', 'cutoff']:
+            value = getattr(args, param, None)
+            if value is not None:
+                cmd_parts.append(shell_quote(str(value)))
+
+    # --- 2. Prepare the long-running task (same as gen-cube and md) ---
+    import time
+    log_name = f"softbv_kmc_{int(time.time())}.log"
+    log_path = posixpath.join(workdir, log_name)
+
+    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"
+
+    await ctx.info(f"Starting long KMC task: {final_cmd} (working directory: {workdir})")
+
+    files_before = await _listdir_safe(conn, workdir)
+    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
+    start_time = time.monotonic()
+
+    # --- 3. Heartbeat loop (same as gen-cube and md) ---
+    try:
+        while proc.exit_status is None:
+            await asyncio.sleep(10)
+            elapsed = int(time.monotonic() - start_time)
+            log_size = await _stat_size_safe(conn, log_path)
+            await ctx.report_progress(
+                progress=min(elapsed / (30 * 60), 0.99),
+                message=f"KMC calculation running... {elapsed} s elapsed, log size: {log_size or 0} bytes."
+            )
+    finally:
+        await proc.wait()
+
+    # --- 4. Collect and return the results ---
+    elapsed_seconds = int(time.monotonic() - start_time)
+    files_after = await _listdir_safe(conn, workdir)
+    new_files = sorted(set(files_after) - set(files_before))
+
+    log_head, log_tail, full_log = "", "", ""
+    try:
+        async with conn.start_sftp_client() as sftp:
+            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
+                lines = await f.readlines()
+                full_log = "".join(lines)
+                log_head = "".join(lines[:40])
+                log_tail = "".join(lines[-40:])
+    except Exception as e:
+        await ctx.warning(f"Failed to read log file '{log_path}': {e}")
+
+    # Parse the final properties from the full log
+    final_properties = _parse_kmc_output(full_log)
+
+    result = KMCResult(
+        command_used=final_cmd,
+        working_directory=workdir,
+        exit_status=proc.exit_status,
+        final_properties=final_properties,
+        log_file=log_path,
+        output_head=log_head,
+        output_tail=log_tail,
+        new_files=new_files,
+        elapsed_seconds=elapsed_seconds,
+    )
+
+    if result.exit_status == 0:
+        await ctx.info(f"KMC task completed successfully in {elapsed_seconds} s.")
+    else:
+        await ctx.error(f"KMC task failed, exit code: {result.exit_status}. Check the log: {log_path}")
+
+    return result
+
+
 # ========= 6. Factory function and main entry point =========
 
 def create_softbv_mcp() -> FastMCP:
```
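
Both long-running tools share the same keep-alive shape: poll the remote process, report progress every 10 seconds, and always reap it in a `finally`. A stripped-down sketch of that pattern, with `proc` and `ctx` standing in for the asyncssh process and MCP context used above:

```python
import asyncio
import time

async def heartbeat_until_done(proc, ctx, label: str, budget_s: int = 30 * 60) -> int:
    """Poll a remote process and emit progress so the MCP client does not time out."""
    start = time.monotonic()
    try:
        while proc.exit_status is None:      # asyncssh sets exit_status once the process ends
            await asyncio.sleep(10)          # one heartbeat every 10 seconds
            elapsed = int(time.monotonic() - start)
            await ctx.report_progress(
                progress=min(elapsed / budget_s, 0.99),  # approximate, capped below 1.0
                message=f"{label} running... {elapsed} s elapsed",
            )
    finally:
        await proc.wait()                    # always reap the process, even on cancellation
    return int(time.monotonic() - start)
```

The 30-minute baseline makes the progress value monotone but approximate; the message carries the real signal (elapsed time and log growth).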