softBV_mcp refactor v2
Embedding copy
21
mcp/SearchPaperByEmbedding/LICENSE
Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 gyj155

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
96
mcp/SearchPaperByEmbedding/README.md
Normal file
@@ -0,0 +1,96 @@
# Paper Semantic Search

Find similar papers using semantic search. Supports both local models (free) and the OpenAI API (better quality).

## Features

- Fetch papers from OpenReview (e.g., ICLR 2026 submissions)
- Semantic search with example papers or text queries
- Automatic embedding caching
- Embedding model support: open-source (e.g., all-MiniLM-L6-v2) or OpenAI

## Quick Start

```bash
pip install -r requirements.txt
```

### 1. Prepare Papers

```python
from crawl import crawl_papers

crawl_papers(
    venue_id="ICLR.cc/2026/Conference/Submission",
    output_file="iclr2026_papers.json"
)
```

### 2. Search Papers

```python
from search import PaperSearcher

# Local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')

# OpenAI model (better, requires API key)
# export OPENAI_API_KEY='your-key'
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')

searcher.compute_embeddings()

# Search with example papers that you are interested in
examples = [
    {
        "title": "Your paper title",
        "abstract": "Your paper abstract..."
    }
]

results = searcher.search(examples=examples, top_k=100)

# Or search with a text query
results = searcher.search(query="interesting topics", top_k=100)

searcher.display(results, n=10)
searcher.save(results, 'results.json')
```

## How It Works

1. Paper titles and abstracts are converted to embeddings
2. Embeddings are cached automatically
3. Your query is embedded using the same model
4. Cosine similarity finds the most similar papers
5. Results are ranked by similarity score
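
Concretely, the ranking step boils down to a few lines of numpy/scikit-learn. A minimal, self-contained sketch (with random stand-in vectors; the real embeddings come from `all-MiniLM-L6-v2` or `text-embedding-3-large`, and multiple example papers are averaged into one query vector, as in `search.py`):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

paper_embs = np.random.rand(1000, 384)  # one row per paper (stand-in for real embeddings)
example_embs = np.random.rand(3, 384)   # embeddings of 3 example papers

# Average the examples into a single query vector, then rank by cosine similarity
query_emb = np.mean(example_embs, axis=0).reshape(1, -1)
sims = cosine_similarity(query_emb, paper_embs)[0]
top_indices = np.argsort(sims)[::-1][:10]  # indices of the 10 most similar papers
print([(int(i), round(float(sims[i]), 4)) for i in top_indices])
```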

## Cache

Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete to recompute.
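
The cache name mirrors `PaperSearcher._get_cache_file` in `search.py`: the stem of the papers file, a short md5 of its path, and the model type, stored next to the papers file. A sketch of the same logic:

```python
import hashlib
from pathlib import Path

def cache_file_for(papers_file: str, model_type: str) -> str:
    # Same derivation as search.py: stem + short md5 of the path + model type
    base_name = Path(papers_file).stem
    file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
    return str(Path(papers_file).parent / f"cache_{base_name}_{file_hash}_{model_type}.npy")

print(cache_file_for("iclr2026_papers.json", "local"))
# e.g. cache_iclr2026_papers_ab12cd34_local.npy (hash value illustrative)
```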

## Example Output

```
================================================================================
Top 100 Results (showing 10)
================================================================================

1. [0.8456] Paper a
   #12345 | foundation or frontier models, including LLMs
   https://openreview.net/forum?id=xxx

2. [0.8234] Paper b
   #12346 | applications to robotics, autonomy, planning
   https://openreview.net/forum?id=yyy
```

## Tips

- Use 1-5 example papers for best results, or a paragraph describing the topic you are interested in
- The local model is good enough for most cases
- Use the OpenAI model for critical searches (~$1 for 18k queries)

If it's useful, please consider giving it a star~
66
mcp/SearchPaperByEmbedding/crawl.py
Normal file
@@ -0,0 +1,66 @@
import requests
import json
import time


def fetch_submissions(venue_id, offset=0, limit=1000):
    url = "https://api2.openreview.net/notes"
    params = {
        "content.venueid": venue_id,
        "details": "replyCount,invitation",
        "limit": limit,
        "offset": offset,
        "sort": "number:desc"
    }
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, params=params, headers=headers)
    response.raise_for_status()
    return response.json()


def crawl_papers(venue_id, output_file):
    all_papers = []
    offset = 0
    limit = 1000

    print(f"Fetching papers from {venue_id}...")

    while True:
        data = fetch_submissions(venue_id, offset, limit)
        notes = data.get("notes", [])

        if not notes:
            break

        for note in notes:
            paper = {
                "id": note.get("id"),
                "number": note.get("number"),
                "title": note.get("content", {}).get("title", {}).get("value", ""),
                "authors": note.get("content", {}).get("authors", {}).get("value", []),
                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
            }
            all_papers.append(paper)

        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")

        if len(notes) < limit:
            break

        offset += limit
        time.sleep(0.5)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_papers, f, ensure_ascii=False, indent=2)

    print(f"\nTotal: {len(all_papers)} papers")
    print(f"Saved to {output_file}")
    return all_papers


if __name__ == "__main__":
    crawl_papers(
        venue_id="ICLR.cc/2026/Conference/Submission",
        output_file="iclr2026_papers.json"
    )
22
mcp/SearchPaperByEmbedding/demo.py
Normal file
@@ -0,0 +1,22 @@
from search import PaperSearcher

# Use local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')

# Or use OpenAI (better quality)
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')

searcher.compute_embeddings()

examples = [
    {
        "title": "Improving Developer Emotion Classification via LLM-Based Augmentation",
        "abstract": "Detecting developer emotion in the informative data stream of technical commit messages..."
    },
]

results = searcher.search(examples=examples, top_k=100)

searcher.display(results, n=10)
searcher.save(results, 'results.json')
6
mcp/SearchPaperByEmbedding/requirements.txt
Normal file
@@ -0,0 +1,6 @@
requests
numpy
scikit-learn
sentence-transformers
openai
156
mcp/SearchPaperByEmbedding/search.py
Normal file
@@ -0,0 +1,156 @@
import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity


class PaperSearcher:
    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)

        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None

        if model_type == "openai":
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"

        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        base_name = Path(papers_file).stem
        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_file).parent / cache_name)

    def _load_cache(self):
        if os.path.exists(self.cache_file):
            try:
                self.embeddings = np.load(self.cache_file)
                if len(self.embeddings) == len(self.papers):
                    print(f"Loaded cache: {self.embeddings.shape}")
                    return True
                self.embeddings = None
            except Exception:
                self.embeddings = None
        return False

    def _save_cache(self):
        np.save(self.cache_file, self.embeddings)
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        batch_size = 100

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend([item.embedding for item in response.data])

        return np.array(embeddings)

    def _embed_local(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def compute_embeddings(self, force=False):
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings

        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]

        if self.model_type == "openai":
            self.embeddings = self._embed_openai(texts)
        else:
            self.embeddings = self._embed_local(texts)

        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        if self.embeddings is None:
            self.compute_embeddings()

        if examples:
            texts = []
            for ex in examples:
                text = f"Title: {ex['title']}"
                if ex.get('abstract'):
                    text += f" Abstract: {ex['abstract']}"
                texts.append(text)

            if self.model_type == "openai":
                embs = self._embed_openai(texts)
            else:
                embs = self._embed_local(texts)

            # Average the example embeddings into a single query vector
            query_emb = np.mean(embs, axis=0).reshape(1, -1)

        elif query:
            if self.model_type == "openai":
                query_emb = self._embed_openai(query).reshape(1, -1)
            else:
                query_emb = self._embed_local(query).reshape(1, -1)
        else:
            raise ValueError("Provide either examples or query")

        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")

        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']

            print(f"{i}. [{sim:.4f}] {paper['title']}")
            print(f"   #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f"   {paper['forum_url']}\n")

    def save(self, results, output_file):
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")
@@ -10,7 +10,7 @@ from starlette.routing import Mount

from system_tools import create_system_mcp
from materialproject_mcp import create_materials_mcp
-from softBV import create_softbv_mcp
+from softBV_remake import create_softbv_mcp
# Create MCP instances
system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()

@@ -121,15 +121,26 @@ class BVPairItem(BaseModel):
    occ2: float
    bv: float = Field(description="Bond valence")


class CalBVResult(BaseModel):
-    """Full structured parse result of the '--cal-bv' command"""
+    """
+    Optimized structured parse result of the '--cal-bv' command.
+    To avoid overly large responses, only a key summary and data previews are returned.
+    """
    global_instability_index: float | None = Field(None, description="Global instability index (GII)")
    suggested_stability: str | None = Field(None, description="Suggested stability (e.g. 'stable')")
-    bv_sums: list[BVSumItem] = Field(description="List of bond valence sums for each ion")
-    bv_pairs: list[BVPairItem] = Field(description="List of bond valences between ion pairs")
-    raw_output: str = Field(description="Full raw stdout of the command")
+
+    # Summary information
+    total_bv_sums: int = Field(0, description="Total number of bond valence sum (BVS) entries")
+    total_bv_pairs: int = Field(0, description="Total number of bond valence pair (BV pair) entries")
+
+    # Data preview (e.g. the first 20 entries)
+    bv_sums_preview: list[BVSumItem] = Field(description="Preview of the bond valence sum list")
+    bv_pairs_preview: list[BVPairItem] = Field(description="Preview of the bond valence pair list")
+
+    # The raw output is still kept in case the AI needs to parse it itself
+    raw_output_head: str = Field(description="First 50 lines of the command's raw stdout")
+    raw_output_tail: str = Field(description="Last 50 lines of the command's raw stdout")

class CubeAtomInfo(BaseModel):
    """Information for a single atom parsed from a .cube file"""
    index: int
@@ -198,6 +209,125 @@ class CalTotEnResult(BaseModel):
    total_energy_eV: float | None = Field(None, description="Computed total energy (eV)")
    screening_factor_used: float | None = Field(None, description="Screening factor (sf) used in the calculation")
    raw_output: str = Field(description="Full raw stdout of the command")


class AnalyzePathwayArgs(BaseModel):
    """Input arguments for the '--gh' (pathway analysis) command"""
    input_cif: str = Field(description="Remote CIF file path used to label the sites")
    input_cube: str = Field(description="Remote cube file path containing the potential data")
    type: str = Field(description="Conducting ion type (e.g. 'Li')")
    os: int = Field(description="Conducting ion oxidation state (e.g. 1)")
    barrier_max: float | None = Field(None, description="Maximum barrier; only pathways below this value are shown")
    periodic: bool = Field(True, description="Whether to treat the cube file as periodic")


class PercolationThresholds(BaseModel):
    """Percolation thresholds parsed from the '--gh' command output"""
    e_total: float | None = Field(None, alias="e", description="Threshold of the total network")
    e_1D: float | None = Field(None, description="Threshold of the 1D conduction network")
    e_2D: float | None = Field(None, description="Threshold of the 2D conduction network")
    e_3D: float | None = Field(None, description="Threshold of the 3D conduction network")
    e_loop: float | None = Field(None, description="Threshold of looping pathways")


class AnalyzePathwayResult(BaseModel):
    """Full structured parse result of the '--gh' command"""
    command_used: str
    working_directory: str
    exit_status: int
    thresholds: PercolationThresholds | None = Field(None, description="Computed percolation thresholds")
    new_files: list[str] = Field(description="Files newly created after execution (e.g. *.gh.cif)")
    raw_output: str = Field(description="Full raw stdout of the command")


class MDDefaultArgs(BaseModel):
    """Default-argument mode for '--md'"""
    input_cif: str = Field(description="Remote CIF file path")
    type: str = Field(description="Conducting ion type (e.g. 'Li')")
    os: int = Field(description="Conducting ion oxidation state (e.g. 1)")


class MDFullArgs(BaseModel):
    """Full-argument mode for '--md'"""
    input_cif: str = Field(description="Remote CIF file path")
    type: str = Field(description="Conducting ion type")
    os: int = Field(description="Conducting ion oxidation state")
    sf: float = Field(description="Custom screening factor")
    temperature: float = Field(description="Temperature (K)")
    t_end: float = Field(description="Production time (ps)")
    t_equil: float = Field(description="Equilibration time (ps)")
    dt: float = Field(description="Time step (ps)")
    t_log: float = Field(description="Sampling interval (ps)")


class MDResultProperties(BaseModel):
    """Final physical properties parsed from the '--md' output"""
    displacement: float | None = None
    diffusivity: float | None = None
    mobility: float | None = None
    conductivity: float | None = None


class MDResult(BaseModel):
    """Full structured execution result of the '--md' command"""
    command_used: str
    working_directory: str
    exit_status: int
    final_properties: MDResultProperties | None = Field(None, description="Final computed properties such as the conductivity")
    log_file: str = Field(description="Full log file path on the remote server")
    output_head: str = Field(description="First 40 lines of the log file")
    output_tail: str = Field(description="Last 40 lines of the log file, containing the final results")
    new_files: list[str] = Field(description="Files newly created after execution")
    elapsed_seconds: int = Field(description="Total task run time (seconds)")


class KMCDefaultArgs(BaseModel):
    """Default-argument mode for '--kmc'"""
    input_cif: str = Field(description="Remote CIF file path")
    input_cube: str = Field(description="Remote cube file path")
    type: str = Field(description="Conducting ion type (e.g. 'Li')")
    os: int = Field(description="Conducting ion oxidation state (e.g. 1)")


class KMCFullArgs(BaseModel):
    """Full-argument mode for '--kmc'"""
    input_cif: str = Field(description="Remote CIF file path")
    input_cube: str = Field(description="Remote cube file path")
    type: str = Field(description="Conducting ion type")
    os: int = Field(description="Conducting ion oxidation state")
    supercell: tuple[int, int, int] | None = Field(None, description="Supercell dimensions (s0, s1, s2)")
    sf: float | None = Field(None, description="Custom screening factor")
    temperature: float | None = Field(None, description="Temperature (K)")
    t_limit: float | None = Field(None, description="Time limit (ps)")
    step_limit: int | None = Field(None, description="Step limit")
    step_log: int | None = Field(None, description="Logging interval in steps")
    cutoff: float | None = Field(None, description="Coulomb repulsion cutoff radius")


class KMCResultVector(BaseModel):
    """Represents a scalar plus a vector component (total, (x, y, z))"""
    total: float
    vector: tuple[float, float, float]


class KMCSiteOccupancy(BaseModel):
    """Occupancy information for a single site in the KMC results"""
    site_name: str
    occupancy: float
    multiplicity: int


class KMCResultProperties(BaseModel):
    """Final physical properties parsed from the '--kmc' output"""
    temperature: float | None = None
    displacement: KMCResultVector | None = None
    diffusivity: KMCResultVector | None = None
    mobility: KMCResultVector | None = None
    conductivity: KMCResultVector | None = None
    site_occupancy_summary: list[KMCSiteOccupancy] = Field(default_factory=list)


class KMCResult(BaseModel):
    """Full structured execution result of the '--kmc' command"""
    command_used: str
    working_directory: str
    exit_status: int
    final_properties: KMCResultProperties | None = Field(None, description="Final computed KMC properties")
    log_file: str = Field(description="Full log file path on the remote server")
    output_head: str = Field(description="Head of the log file")
    output_tail: str = Field(description="Tail of the log file, containing the final results")
    new_files: list[str] = Field(description="Files newly created after execution")
    elapsed_seconds: int = Field(description="Total task run time (seconds)")


# ========= 2. Helper functions =========

def shell_quote(arg: str) -> str:
@@ -342,7 +472,7 @@ async def _stat_size_safe(conn: asyncssh.SSHClientConnection, path: str) -> int

def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
    """
-    Parse the raw stdout of '--cal-bv' into a structured CalBVResult model.
+    Parse the raw stdout of '--cal-bv' into an optimized CalBVResult model containing a summary and previews.
    """
    lines = raw_text.splitlines()

@@ -351,10 +481,9 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
    bv_sums = []
    bv_pairs = []

-    # ...(GII and table parsing logic unchanged)...
    in_sum_section = False
    in_pair_section = False

    # Parse GII and stability
    for line in lines:
        if line.startswith("GII: Global instability index ="):
            try:
@@ -363,12 +492,10 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
                pass
        elif line.startswith("GII: suggested stability:"):
            try:
-                stability = line.split(":", 1)[-1].strip()
+                stability = line.split(":", 1)[-1].strip().rstrip('.')
            except IndexError:
                pass

    # Parse the tables
    for line in lines:
        if line.startswith("BV: name type occ"):
            in_sum_section = True
            in_pair_section = False
@@ -379,52 +506,45 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
                continue
            if line.startswith("BV:="):
                continue

        if not line.startswith("BV:"):
            in_sum_section = False
            in_pair_section = False

        # Use regular expressions for more robust parsing, to handle variable-width columns
        if in_sum_section:
            match = re.match(r"BV:\s+(\S+)\s+(\S+)\s+([\d.]+)\s+([-\d.]+)", line)
            if match:
                try:
                    bv_sums.append(BVSumItem(
-                        name=match.group(1),
-                        type=match.group(2),
-                        occ=float(match.group(3)),
-                        bv_sum=float(match.group(4))
+                        name=match.group(1), type=match.group(2),
+                        occ=float(match.group(3)), bv_sum=float(match.group(4))
                    ))
                except (ValueError, IndexError):
                    continue

        elif in_pair_section:
            # This is a more complex line; split it carefully
            try:
                # Strip the "BV:" prefix and split
                parts = line[3:].strip().split()
                if len(parts) == 7:
                    bv_pairs.append(BVPairItem(
-                        name1=parts[0],
-                        type1=parts[1],
-                        occ1=float(parts[2]),
-                        name2=parts[3],
-                        type2=parts[4],
-                        occ2=float(parts[5]),
+                        name1=parts[0], type1=parts[1], occ1=float(parts[2]),
+                        name2=parts[3], type2=parts[4], occ2=float(parts[5]),
                        bv=float(parts[6])
                    ))
            except (ValueError, IndexError):
                continue

+    # --- Key change: return a summary and previews instead of the full lists ---
    return CalBVResult(
        global_instability_index=gii,
        suggested_stability=stability,
-        bv_sums=bv_sums,
-        bv_pairs=bv_pairs,
-        raw_output=raw_text
+        total_bv_sums=len(bv_sums),
+        total_bv_pairs=len(bv_pairs),
+        bv_sums_preview=bv_sums[:20],  # only the first 20 entries as a preview
+        bv_pairs_preview=bv_pairs[:20],  # only the first 20 entries as a preview
+        raw_output_head="\n".join(lines[:50]),
+        raw_output_tail="\n".join(lines[-50:])
    )


def _parse_print_cube_output(raw_text: str) -> PrintCubeResult:
    """
    Parse the raw stdout of '--print-cube' into a structured PrintCubeResult model.
@@ -579,6 +699,91 @@ def _parse_cal_tot_en_output(raw_text: str) -> CalTotEnResult:
    )


def _parse_gh_output(raw_text: str) -> PercolationThresholds | None:
    """Parse the SOFTBV_GRAPH_NETWORK lines from the '--gh' stdout."""
    thresholds = {}
    # Regex matching the "e_type = value" format
    pattern = re.compile(r"SOFTBV_GRAPH_NETWORK: (e(?:_1D|_2D|_3D|_loop)?)\s*=\s*([-\d.]+)")

    for line in raw_text.splitlines():
        match = pattern.search(line)
        if match:
            key = match.group(1)
            value = float(match.group(2))

            # Keep 'e' unchanged so it maps onto the alias in the Pydantic model
            if key == "e":
                thresholds["e"] = value
            else:
                thresholds[key] = value

    if not thresholds:
        return None

    return PercolationThresholds.model_validate(thresholds)


def _parse_md_output(raw_text: str) -> MDResultProperties:
    """Parse the final physical properties from the '--md' stdout."""
    properties = {}
    # Regex matching the "MD: key = value" format
    pattern = re.compile(r"MD: (displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)")

    for line in raw_text.splitlines():
        match = pattern.search(line)
        if match:
            key = match.group(1)
            value = float(match.group(2))
            properties[key] = value

    return MDResultProperties.model_validate(properties)

# Add this in the helper-function area of softbv_mcp_refactored.py

def _parse_kmc_output(raw_text: str) -> KMCResultProperties:
    """Parse the final physical properties and site-occupancy info from the '--kmc' stdout."""
    properties = {"site_occupancy_summary": []}
    in_occupancy_summary = False

    # Match lines of the form "key = total, (x, y, z)" or "key = total"
    vector_pattern = re.compile(
        r"KMC: (temperature|displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)(?:,\s*\(([-\d.eE]+),\s*([-\d.eE]+),\s*([-\d.eE]+)\))?"
    )
    # Match site-occupancy lines
    occupancy_pattern = re.compile(
        r"KMC:\s*\[\s*([^\]]+)\]\s+([-\d.]+)\s+\(multiplicity\s*=\s*(\d+)\)"
    )

    for line in raw_text.splitlines():
        if "KMC: summary of site-averaged occupancy:" in line:
            in_occupancy_summary = True
            continue

        vec_match = vector_pattern.search(line)
        if vec_match:
            key, total_str, x_str, y_str, z_str = vec_match.groups()
            total = float(total_str)
            if key == 'temperature':
                properties[key] = total
            elif x_str is not None:
                properties[key] = {
                    "total": total,
                    "vector": (float(x_str), float(y_str), float(z_str))
                }
            continue

        if in_occupancy_summary:
            occ_match = occupancy_pattern.search(line)
            if occ_match:
                site, occ, mult = occ_match.groups()
                properties["site_occupancy_summary"].append({
                    "site_name": site.strip(),
                    "occupancy": float(occ),
                    "multiplicity": int(mult)
                })

    return KMCResultProperties.model_validate(properties)

# ========= 3. Lifecycle management =========

@dataclass
@@ -865,6 +1070,8 @@ async def softbv_gen_cube(

# In softbv_mcp_refactored.py, add this right after softbv_gen_cube

# In softbv_mcp_refactored.py, find and replace the softbv_calculate_bv function

@mcp.tool()
async def softbv_calculate_bv(
    cif_path: str,
@@ -872,15 +1079,8 @@ async def softbv_calculate_bv(
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> CalBVResult:
    """
-    Run 'softBV.x --cal-bv' to compute and return bond valence sums.
-
-    Args:
-        cif_path (str): CIF file path on the remote server (relative or absolute).
-        cwd (str | None): Remote working directory; defaults to the sandbox directory if not given.
-        ctx: MCP context, injected automatically by the framework.
-
-    Returns:
-        CalBVResult: a structured object containing the parsed bond valence sums, the global instability index, and the raw output.
+    Run 'softBV.x --cal-bv' to compute bond valence sums.
+    Returns a structured object with a key summary and data previews, to avoid overly long content.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
@@ -889,7 +1089,6 @@ async def softbv_calculate_bv(
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

-    # Build the path and command
    input_abs_path = cif_path
    if not input_abs_path.startswith("/"):
        input_abs_path = posixpath.normpath(posixpath.join(workdir, cif_path))
@@ -898,27 +1097,23 @@ async def softbv_calculate_bv(

    await ctx.info(f"Running command: {cmd} (working directory: {workdir})")

-    # Run the command
    proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)

    if proc.exit_status != 0:
        await ctx.warning(f"Command failed with exit code {proc.exit_status}")
-        # Even on failure, return a structure containing the error info
+        # --- Key change: on failure, return an empty structure matching the new model ---
+        raw_error = f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
        return CalBVResult(
            global_instability_index=None,
            suggested_stability=None,
-            bv_sums=[],
-            bv_pairs=[],
-            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
+            bv_sums_preview=[],
+            bv_pairs_preview=[],
+            raw_output_head=raw_error[:1000],  # truncate in case the error message is also too long
+            raw_output_tail=""
        )

-    # Parse the successful output and return structured data
+    # Parse the successful output and return structured data with a summary and previews
    parsed_result = _parse_cal_bv_output(proc.stdout)
    return parsed_result


# In softbv_mcp_refactored.py, add this right after softbv_calculate_bv

@mcp.tool()
async def softbv_print_cube(
    cube_path: str,
@@ -1144,6 +1339,310 @@ async def softbv_calculate_total_energy(
    # Parse and return the result
    return _parse_cal_tot_en_output(proc.stdout)


@mcp.tool()
async def softbv_analyze_pathway(
    args: AnalyzePathwayArgs,
    cwd: str | None = None,
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> AnalyzePathwayResult:
    """
    Run 'softBV.x --gh' to analyze conduction pathways and return percolation thresholds.

    Args:
        args (AnalyzePathwayArgs): Structured object containing all input parameters.
        cwd (str | None): Remote working directory.
        ctx: MCP context, injected automatically by the framework.

    Returns:
        AnalyzePathwayResult: Structured result with the percolation thresholds, the list of newly created files, and the raw output.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")

    app_ctx = ctx.request_context.lifespan_context
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    # --- 1. Build the command ---
    def get_abs(path: str) -> str:
        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))

    cmd_parts = [
        shell_quote(app_ctx.bin_path),
        shell_quote("--gh"),
        shell_quote(get_abs(args.input_cif)),
        shell_quote(get_abs(args.input_cube)),
        shell_quote(args.type),
        shell_quote(str(args.os)),
    ]
    if args.barrier_max is not None:
        cmd_parts.append(shell_quote(str(args.barrier_max)))

    periodic_flag = "t" if args.periodic else "f"
    cmd_parts.append(shell_quote(periodic_flag))

    cmd = " ".join(cmd_parts)

    await ctx.info(f"Running pathway analysis: {cmd}")

    # --- 2. Execute and collect results ---
    files_before = await _listdir_safe(conn, workdir)
    proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)
    files_after = await _listdir_safe(conn, workdir)

    new_files = sorted(set(files_after) - set(files_before))

    if proc.exit_status != 0:
        await ctx.warning(f"'--gh' command failed with exit code {proc.exit_status}")
        return AnalyzePathwayResult(
            command_used=cmd,
            working_directory=workdir,
            exit_status=proc.exit_status,
            thresholds=None,
            new_files=new_files,
            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
        )

    # --- 3. Parse and return the structured result ---
    parsed_thresholds = _parse_gh_output(proc.stdout)
    if not parsed_thresholds:
        await ctx.info("No percolation threshold information found in the output.")

    return AnalyzePathwayResult(
        command_used=cmd,
        working_directory=workdir,
        exit_status=proc.exit_status,
        thresholds=parsed_thresholds,
        new_files=new_files,
        raw_output=proc.stdout
    )


# Add this function in the tools area of softbv_mcp_refactored.py

@mcp.tool(name="softbv_run_md")
async def softbv_run_md(
    args: MDDefaultArgs | MDFullArgs,
    cwd: str | None = None,
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> MDResult:
    """
    Run 'softBV.x --md' for a molecular dynamics calculation. This is a long-running task.

    Two calling modes are supported:
    1. Default mode (MDDefaultArgs): provide only the CIF, ion type, and oxidation state.
    2. Full mode (MDFullArgs): provide all MD parameters.

    The tool reports a periodic heartbeat and, once the task finishes, returns a structured result with key quantities such as the conductivity.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")

    app_ctx = ctx.request_context.lifespan_context
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    # --- 1. Build the command ---
    input_abs_path = args.input_cif
    if not input_abs_path.startswith("/"):
        input_abs_path = posixpath.normpath(posixpath.join(workdir, args.input_cif))

    cmd_parts = [
        shell_quote(app_ctx.bin_path),
        shell_quote("--md"),
        shell_quote(input_abs_path),
        shell_quote(args.type),
        shell_quote(str(args.os)),
    ]

    if isinstance(args, MDFullArgs):
        cmd_parts.extend([
            shell_quote(str(args.sf)),
            shell_quote(str(args.temperature)),
            shell_quote(str(args.t_end)),
            shell_quote(str(args.t_equil)),
            shell_quote(str(args.dt)),
            shell_quote(str(args.t_log)),
        ])

    # --- 2. Prepare the long-running task ---
    import time
    log_name = f"softbv_md_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)

    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"

    await ctx.info(f"Starting long MD task: {final_cmd} (working directory: {workdir})")

    files_before = await _listdir_safe(conn, workdir)
    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
    start_time = time.monotonic()

    # --- 3. Heartbeat loop to prevent timeouts ---
    try:
        while proc.exit_status is None:
            await asyncio.sleep(10)  # report every 10 seconds

            elapsed = int(time.monotonic() - start_time)
            log_size = await _stat_size_safe(conn, log_path)

            # Use ctx.report_progress as the heartbeat mechanism
            await ctx.report_progress(
                progress=min(elapsed / (30 * 60), 0.99),  # approximate progress against a 30-minute baseline
                message=f"MD calculation in progress... {elapsed} s elapsed, log size: {log_size or 0} bytes."
            )
    finally:
        await proc.wait()  # make sure the process has finished

    # --- 4. Collect and return results ---
    elapsed_seconds = int(time.monotonic() - start_time)
    files_after = await _listdir_safe(conn, workdir)
    new_files = sorted(set(files_after) - set(files_before))

    # Read the head and tail of the log file
    log_head, log_tail, full_log = "", "", ""
    try:
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
                lines = await f.readlines()
                full_log = "".join(lines)
                log_head = "".join(lines[:40])
                log_tail = "".join(lines[-40:])
    except Exception as e:
        await ctx.warning(f"Failed to read log file '{log_path}': {e}")

    # Parse the final properties from the log
    final_properties = _parse_md_output(full_log)

    result = MDResult(
        command_used=final_cmd,
        working_directory=workdir,
        exit_status=proc.exit_status,
        final_properties=final_properties,
        log_file=log_path,
        output_head=log_head,
        output_tail=log_tail,
        new_files=new_files,
        elapsed_seconds=elapsed_seconds,
    )

    if result.exit_status == 0:
        await ctx.info(f"MD task completed successfully in {elapsed_seconds} s.")
    else:
        await ctx.error(f"MD task failed with exit code {result.exit_status}. Check the log: {log_path}")

    return result


# Add this function in the tools area of softbv_mcp_refactored.py

@mcp.tool(name="softbv_run_kmc")
async def softbv_run_kmc(
    args: KMCDefaultArgs | KMCFullArgs,
    cwd: str | None = None,
    ctx: Context[ServerSession, SoftBVContext] | None = None,
) -> KMCResult:
    """
    Run 'softBV.x --kmc' for a kinetic Monte Carlo calculation. This is a long-running task.

    Two calling modes are supported: default mode and full-argument mode.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")

    app_ctx = ctx.request_context.lifespan_context
    conn = app_ctx.ssh_connection
    workdir = cwd or app_ctx.workdir

    # --- 1. Build the command ---
    def get_abs(path: str) -> str:
        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))

    cmd_parts = [
        shell_quote(app_ctx.bin_path),
        shell_quote("--kmc"),
        shell_quote(get_abs(args.input_cif)),
        shell_quote(get_abs(args.input_cube)),
        shell_quote(args.type),
        shell_quote(str(args.os)),
    ]

    if isinstance(args, KMCFullArgs):
        if args.supercell:
            cmd_parts.extend(map(lambda x: shell_quote(str(x)), args.supercell))

        # Handle all the optional float/int parameters in one loop
        for param in ['sf', 'temperature', 't_limit', 'step_limit', 'step_log', 'cutoff']:
            value = getattr(args, param, None)
            if value is not None:
                cmd_parts.append(shell_quote(str(value)))

    # --- 2. Prepare the long-running task (same as gen-cube and md) ---
    import time
    log_name = f"softbv_kmc_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)

    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"

    await ctx.info(f"Starting long KMC task: {final_cmd} (working directory: {workdir})")

    files_before = await _listdir_safe(conn, workdir)
    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
    start_time = time.monotonic()

    # --- 3. Heartbeat loop (same as gen-cube and md) ---
    try:
        while proc.exit_status is None:
            await asyncio.sleep(10)
            elapsed = int(time.monotonic() - start_time)
            log_size = await _stat_size_safe(conn, log_path)
            await ctx.report_progress(
                progress=min(elapsed / (30 * 60), 0.99),
                message=f"KMC calculation in progress... {elapsed} s elapsed, log size: {log_size or 0} bytes."
            )
    finally:
        await proc.wait()

    # --- 4. Collect and return results ---
    elapsed_seconds = int(time.monotonic() - start_time)
    files_after = await _listdir_safe(conn, workdir)
    new_files = sorted(set(files_after) - set(files_before))

    log_head, log_tail, full_log = "", "", ""
    try:
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
                lines = await f.readlines()
                full_log = "".join(lines)
                log_head = "".join(lines[:40])
                log_tail = "".join(lines[-40:])
    except Exception as e:
        await ctx.warning(f"Failed to read log file '{log_path}': {e}")

    # Parse the final properties from the full log
    final_properties = _parse_kmc_output(full_log)

    result = KMCResult(
        command_used=final_cmd,
        working_directory=workdir,
        exit_status=proc.exit_status,
        final_properties=final_properties,
        log_file=log_path,
        output_head=log_head,
        output_tail=log_tail,
        new_files=new_files,
        elapsed_seconds=elapsed_seconds,
    )

    if result.exit_status == 0:
        await ctx.info(f"KMC task completed successfully in {elapsed_seconds} s.")
    else:
        await ctx.error(f"KMC task failed with exit code {result.exit_status}. Check the log: {log_path}")

    return result


# ========= 6. Factory function and main entry point =========

def create_softbv_mcp() -> FastMCP: