diff --git a/mcp/SearchPaperByEmbedding/LICENSE b/mcp/SearchPaperByEmbedding/LICENSE
new file mode 100644
index 0000000..67ad045
--- /dev/null
+++ b/mcp/SearchPaperByEmbedding/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 gyj155
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/mcp/SearchPaperByEmbedding/README.md b/mcp/SearchPaperByEmbedding/README.md
new file mode 100644
index 0000000..88c3496
--- /dev/null
+++ b/mcp/SearchPaperByEmbedding/README.md
@@ -0,0 +1,96 @@
+# Paper Semantic Search
+
+Find similar papers using semantic search. Supports both local models (free) and the OpenAI API (better quality).
+
+## Features
+
+- Fetch papers from OpenReview (e.g., ICLR 2026 submissions)
+- Semantic search with example papers or text queries
+- Automatic embedding caching
+- Embedding model support: open-source (e.g., all-MiniLM-L6-v2) or OpenAI
+
+## Quick Start
+
+```bash
+pip install -r requirements.txt
+```
+
+### 1. Prepare Papers
+
+```python
+from crawl import crawl_papers
+
+crawl_papers(
+    venue_id="ICLR.cc/2026/Conference/Submission",
+    output_file="iclr2026_papers.json"
+)
+```
+
+### 2. Search Papers
+
+```python
+from search import PaperSearcher
+
+# Local model (free)
+searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
+
+# OpenAI model (better quality, requires an API key)
+# export OPENAI_API_KEY='your-key'
+# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
+
+searcher.compute_embeddings()
+
+# Search with example papers you are interested in
+examples = [
+    {
+        "title": "Your paper title",
+        "abstract": "Your paper abstract..."
+    }
+]
+
+results = searcher.search(examples=examples, top_k=100)
+
+# Or search with a text query
+results = searcher.search(query="interesting topics", top_k=100)
+
+searcher.display(results, n=10)
+searcher.save(results, 'results.json')
+```
+
+## How It Works
+
+1. Paper titles and abstracts are converted to embeddings
+2. Embeddings are cached automatically
+3. Your query is embedded with the same model
+4. Cosine similarity finds the most similar papers
+5. Results are ranked by similarity score
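+
+As a condensed sketch of the ranking step (this mirrors what `PaperSearcher.search()` in `search.py` does internally; the model name matches the local default, while the texts are illustrative):
+
+```python
+import numpy as np
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+
+model = SentenceTransformer('all-MiniLM-L6-v2')
+
+# Papers and the query are embedded into the same vector space
+paper_embs = model.encode(["Title: Paper A ...", "Title: Paper B ..."])  # shape (n_papers, dim)
+query_emb = model.encode(["a paragraph describing your topic"])          # shape (1, dim)
+
+# Rank all papers by cosine similarity to the query
+scores = cosine_similarity(query_emb, paper_embs)[0]
+top_indices = np.argsort(scores)[::-1][:10]
+```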
+
+## Cache
+
+Embeddings are cached as `cache_<name>_<hash>_<model_type>.npy`. Delete the file to recompute.
+
+## Example Output
+
+```
+================================================================================
+Top 100 Results (showing 10)
+================================================================================
+
+1. [0.8456] Paper a
+   #12345 | foundation or frontier models, including LLMs
+   https://openreview.net/forum?id=xxx
+
+2. [0.8234] Paper b
+   #12346 | applications to robotics, autonomy, planning
+   https://openreview.net/forum?id=yyy
+```
+
+## Tips
+
+- Use 1-5 example papers for best results, or a paragraph describing the topic you are interested in
+- The local model is good enough for most cases
+- Use the OpenAI model for critical searches (~$1 for 18k queries)
+
+If it's useful, please consider giving a star~
\ No newline at end of file
diff --git a/mcp/SearchPaperByEmbedding/crawl.py b/mcp/SearchPaperByEmbedding/crawl.py
new file mode 100644
index 0000000..931e10d
--- /dev/null
+++ b/mcp/SearchPaperByEmbedding/crawl.py
@@ -0,0 +1,66 @@
+import requests
+import json
+import time
+
+def fetch_submissions(venue_id, offset=0, limit=1000):
+    url = "https://api2.openreview.net/notes"
+    params = {
+        "content.venueid": venue_id,
+        "details": "replyCount,invitation",
+        "limit": limit,
+        "offset": offset,
+        "sort": "number:desc"
+    }
+    headers = {"User-Agent": "Mozilla/5.0"}
+    response = requests.get(url, params=params, headers=headers)
+    response.raise_for_status()
+    return response.json()
+
+def crawl_papers(venue_id, output_file):
+    all_papers = []
+    offset = 0
+    limit = 1000
+
+    print(f"Fetching papers from {venue_id}...")
+
+    while True:
+        data = fetch_submissions(venue_id, offset, limit)
+        notes = data.get("notes", [])
+
+        if not notes:
+            break
+
+        for note in notes:
+            paper = {
+                "id": note.get("id"),
+                "number": note.get("number"),
+                "title": note.get("content", {}).get("title", {}).get("value", ""),
+                "authors": note.get("content", {}).get("authors", {}).get("value", []),
+                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
+                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
+                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
+                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
+            }
+            all_papers.append(paper)
+
+        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")
+
+        if len(notes) < limit:
+            break
+
+        offset += limit
+        time.sleep(0.5)
+
+    with open(output_file, "w", encoding="utf-8") as f:
+        json.dump(all_papers, f, ensure_ascii=False, indent=2)
+
+    print(f"\nTotal: {len(all_papers)} papers")
+    print(f"Saved to {output_file}")
+    return all_papers
+
+if __name__ == "__main__":
+    crawl_papers(
+        venue_id="ICLR.cc/2026/Conference/Submission",
+        output_file="iclr2026_papers.json"
+    )
diff --git a/mcp/SearchPaperByEmbedding/demo.py b/mcp/SearchPaperByEmbedding/demo.py
new file mode 100644
index 0000000..789ac97
--- /dev/null
+++ b/mcp/SearchPaperByEmbedding/demo.py
@@ -0,0 +1,22 @@
+from search import PaperSearcher
+
+# Use the local model (free)
+searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
+
+# Or use OpenAI (better quality)
+# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
+
+searcher.compute_embeddings()
+
+examples = [
+    {
+        "title": "Improving Developer Emotion Classification via LLM-Based Augmentation",
+        "abstract": "Detecting developer emotion in the informative data stream of technical commit messages..."
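+        # More example papers can be appended here; search() averages their embeddings into a single query vector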
+    },
+]
+
+results = searcher.search(examples=examples, top_k=100)
+
+searcher.display(results, n=10)
+searcher.save(results, 'results.json')
diff --git a/mcp/SearchPaperByEmbedding/requirements.txt b/mcp/SearchPaperByEmbedding/requirements.txt
new file mode 100644
index 0000000..6746271
--- /dev/null
+++ b/mcp/SearchPaperByEmbedding/requirements.txt
@@ -0,0 +1,6 @@
+requests
+numpy
+scikit-learn
+sentence-transformers
+openai
diff --git a/mcp/SearchPaperByEmbedding/search.py b/mcp/SearchPaperByEmbedding/search.py
new file mode 100644
index 0000000..3be06af
--- /dev/null
+++ b/mcp/SearchPaperByEmbedding/search.py
@@ -0,0 +1,156 @@
+import json
+import numpy as np
+import os
+import hashlib
+from pathlib import Path
+from sklearn.metrics.pairwise import cosine_similarity
+
+class PaperSearcher:
+    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
+        with open(papers_file, 'r', encoding='utf-8') as f:
+            self.papers = json.load(f)
+
+        self.model_type = model_type
+        self.cache_file = self._get_cache_file(papers_file, model_type)
+        self.embeddings = None
+
+        if model_type == "openai":
+            from openai import OpenAI
+            self.client = OpenAI(
+                api_key=api_key or os.getenv('OPENAI_API_KEY'),
+                base_url=base_url
+            )
+            self.model_name = "text-embedding-3-large"
+        else:
+            from sentence_transformers import SentenceTransformer
+            self.model = SentenceTransformer('all-MiniLM-L6-v2')
+            self.model_name = "all-MiniLM-L6-v2"
+
+        self._load_cache()
+
+    def _get_cache_file(self, papers_file, model_type):
+        base_name = Path(papers_file).stem
+        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
+        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
+        return str(Path(papers_file).parent / cache_name)
+
+    def _load_cache(self):
+        if os.path.exists(self.cache_file):
+            try:
+                self.embeddings = np.load(self.cache_file)
+                if len(self.embeddings) == len(self.papers):
+                    print(f"Loaded cache: {self.embeddings.shape}")
+                    return True
+                self.embeddings = None
+            except Exception:
+                self.embeddings = None
+        return False
+
+    def _save_cache(self):
+        np.save(self.cache_file, self.embeddings)
+        print(f"Saved cache: {self.cache_file}")
+
+    def _create_text(self, paper):
+        parts = []
+        if paper.get('title'):
+            parts.append(f"Title: {paper['title']}")
+        if paper.get('abstract'):
+            parts.append(f"Abstract: {paper['abstract']}")
+        if paper.get('keywords'):
+            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
+            parts.append(f"Keywords: {kw}")
+        return ' '.join(parts)
+
+    def _embed_openai(self, texts):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        embeddings = []
+        batch_size = 100
+
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+            response = self.client.embeddings.create(input=batch, model=self.model_name)
+            embeddings.extend([item.embedding for item in response.data])
+
+        return np.array(embeddings)
+
+    def _embed_local(self, texts):
+        if isinstance(texts, str):
+            texts = [texts]
+        return self.model.encode(texts, show_progress_bar=len(texts) > 100)
+
+    def compute_embeddings(self, force=False):
+        if self.embeddings is not None and not force:
+            print("Using cached embeddings")
+            return self.embeddings
+
+        print(f"Computing embeddings ({self.model_name})...")
+        texts = [self._create_text(p) for p in self.papers]
+
+        if self.model_type == "openai":
+            self.embeddings = self._embed_openai(texts)
+        else:
+            self.embeddings = self._embed_local(texts)
+
+        print(f"Computed: {self.embeddings.shape}")
+        self._save_cache()
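+        # The .npy cache written here is picked up by _load_cache on the next run,
+        # as long as the paper count still matches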
+        return self.embeddings
+
+    def search(self, examples=None, query=None, top_k=100):
+        if self.embeddings is None:
+            self.compute_embeddings()
+
+        if examples:
+            texts = []
+            for ex in examples:
+                text = f"Title: {ex['title']}"
+                if ex.get('abstract'):
+                    text += f" Abstract: {ex['abstract']}"
+                texts.append(text)
+
+            if self.model_type == "openai":
+                embs = self._embed_openai(texts)
+            else:
+                embs = self._embed_local(texts)
+
+            query_emb = np.mean(embs, axis=0).reshape(1, -1)
+
+        elif query:
+            if self.model_type == "openai":
+                query_emb = self._embed_openai(query).reshape(1, -1)
+            else:
+                query_emb = self._embed_local(query).reshape(1, -1)
+        else:
+            raise ValueError("Provide either examples or query")
+
+        similarities = cosine_similarity(query_emb, self.embeddings)[0]
+        top_indices = np.argsort(similarities)[::-1][:top_k]
+
+        return [{
+            'paper': self.papers[idx],
+            'similarity': float(similarities[idx])
+        } for idx in top_indices]
+
+    def display(self, results, n=10):
+        print(f"\n{'='*80}")
+        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
+        print(f"{'='*80}\n")
+
+        for i, result in enumerate(results[:n], 1):
+            paper = result['paper']
+            sim = result['similarity']
+
+            print(f"{i}. [{sim:.4f}] {paper['title']}")
+            print(f"   #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
+            print(f"   {paper['forum_url']}\n")
+
+    def save(self, results, output_file):
+        with open(output_file, 'w', encoding='utf-8') as f:
+            json.dump({
+                'model': self.model_name,
+                'total': len(results),
+                'results': results
+            }, f, ensure_ascii=False, indent=2)
+        print(f"Saved to {output_file}")
diff --git a/mcp/main.py b/mcp/main.py
index 85e9af9..498292d 100644
--- a/mcp/main.py
+++ b/mcp/main.py
@@ -10,7 +10,7 @@ from starlette.routing import Mount
 from system_tools import create_system_mcp
 from materialproject_mcp import create_materials_mcp
-from softBV import create_softbv_mcp
+from softBV_remake import create_softbv_mcp
 # Create the MCP instances
 system_mcp = create_system_mcp()
 materials_mcp = create_materials_mcp()
diff --git a/mcp/softBV_reramke.py b/mcp/softBV_remake.py
similarity index 64%
rename from mcp/softBV_reramke.py
rename to mcp/softBV_remake.py
index 902dfbb..ba2533d 100644
--- a/mcp/softBV_reramke.py
+++ b/mcp/softBV_remake.py
@@ -121,15 +121,26 @@ class BVPairItem(BaseModel):
     occ2: float
     bv: float = Field(description="bond valence")
+
 class CalBVResult(BaseModel):
-    """Complete structured parse result of the '--cal-bv' command"""
+    """
+    Optimized structured parse result of the '--cal-bv' command.
+    To avoid an overly large return payload, only a key summary and data previews are returned.
+    """
     global_instability_index: float | None = Field(None, description="global instability index (GII)")
     suggested_stability: str | None = Field(None, description="suggested stability (e.g. 'stable')")
-    bv_sums: list[BVSumItem] = Field(description="list of bond valence sums, one per ion")
-    bv_pairs: list[BVPairItem] = Field(description="list of bond valences between ion pairs")
-    raw_output: str = Field(description="full raw stdout of the command")
+    # Summary information
+    total_bv_sums: int = Field(0, description="total number of bond valence sum (BVS) entries")
+    total_bv_pairs: int = Field(0, description="total number of bond valence pair (BV pair) entries")
+    # Data previews (e.g. the first 20 entries)
+    bv_sums_preview: list[BVSumItem] = Field(description="preview of the bond valence sum list")
+    bv_pairs_preview: list[BVPairItem] = Field(description="preview of the bond valence pair list")
+
+    # The raw output is still kept so the AI can parse it itself when needed
+    raw_output_head: str = Field(description="first 50 lines of the command's raw stdout")
+    raw_output_tail: str = Field(description="last 50 lines of the command's raw stdout")
 
 class CubeAtomInfo(BaseModel):
     """Information about a single atom parsed from a .cube file"""
     index: int
@@ -198,6 +209,125 @@ class CalTotEnResult(BaseModel):
     total_energy_eV: float | None = Field(None, description="computed total energy (eV)")
     screening_factor_used: float | None = Field(None, description="screening factor (sf) used in the calculation")
     raw_output: str = Field(description="full raw stdout of the command")
+
+class AnalyzePathwayArgs(BaseModel):
+    """Input arguments for the '--gh' (pathway analysis) command"""
+    input_cif: str = Field(description="remote CIF file path used for site labelling")
+    input_cube: str = Field(description="remote cube file path containing the potential data")
+    type: str = Field(description="conducting ion type (e.g. 'Li')")
+    os: int = Field(description="conducting ion oxidation state (e.g. 1)")
+    barrier_max: float | None = Field(None, description="maximum barrier; only pathways below this value are shown")
+    periodic: bool = Field(True, description="whether to treat the cube file as periodic")
+
+class PercolationThresholds(BaseModel):
+    """Percolation thresholds parsed from the '--gh' command output"""
+    e_total: float | None = Field(None, alias="e", description="threshold of the total network")
+    e_1D: float | None = Field(None, description="threshold of the 1D conduction network")
+    e_2D: float | None = Field(None, description="threshold of the 2D conduction network")
+    e_3D: float | None = Field(None, description="threshold of the 3D conduction network")
+    e_loop: float | None = Field(None, description="threshold of looping pathways")
+
+class AnalyzePathwayResult(BaseModel):
+    """Complete structured parse result of the '--gh' command"""
+    command_used: str
+    working_directory: str
+    exit_status: int
+    thresholds: PercolationThresholds | None = Field(None, description="computed percolation thresholds")
+    new_files: list[str] = Field(description="files newly created by the run (e.g. *.gh.cif)")
+    raw_output: str = Field(description="full raw stdout of the command")
+
+
+class MDDefaultArgs(BaseModel):
+    """Default-argument mode for '--md'"""
+    input_cif: str = Field(description="remote CIF file path")
+    type: str = Field(description="conducting ion type (e.g. 'Li')")
+    os: int = Field(description="conducting ion oxidation state (e.g. 1)")
+
+class MDFullArgs(BaseModel):
+    """Full-argument mode for '--md'"""
+    input_cif: str = Field(description="remote CIF file path")
+    type: str = Field(description="conducting ion type")
+    os: int = Field(description="conducting ion oxidation state")
+    sf: float = Field(description="custom screening factor")
+    temperature: float = Field(description="temperature (K)")
+    t_end: float = Field(description="production time (ps)")
+    t_equil: float = Field(description="equilibration time (ps)")
+    dt: float = Field(description="time step (ps)")
+    t_log: float = Field(description="sampling interval (ps)")
+
+class MDResultProperties(BaseModel):
+    """Final physical properties parsed from the '--md' output"""
+    displacement: float | None = None
+    diffusivity: float | None = None
+    mobility: float | None = None
+    conductivity: float | None = None
+
+class MDResult(BaseModel):
+    """Complete structured execution result of the '--md' command"""
+    command_used: str
+    working_directory: str
+    exit_status: int
+    final_properties: MDResultProperties | None = Field(None, description="final computed properties such as the conductivity")
+    log_file: str = Field(description="path of the full log file on the remote server")
+    output_head: str = Field(description="first 40 lines of the log file")
+    output_tail: str = Field(description="last 40 lines of the log file, containing the final results")
+    new_files: list[str] = Field(description="files newly created by the run")
+    elapsed_seconds: int = Field(description="total task time (seconds)")
+
+
+class KMCDefaultArgs(BaseModel):
+    """Default-argument mode for '--kmc'"""
+    input_cif: str = Field(description="remote CIF file path")
+    input_cube: str = Field(description="remote cube file path")
+    type: str = Field(description="conducting ion type (e.g. 'Li')")
+    os: int = Field(description="conducting ion oxidation state (e.g. 1)")
+
+class KMCFullArgs(BaseModel):
+    """Full-argument mode for '--kmc'"""
+    input_cif: str = Field(description="remote CIF file path")
+    input_cube: str = Field(description="remote cube file path")
+    type: str = Field(description="conducting ion type")
+    os: int = Field(description="conducting ion oxidation state")
+    supercell: tuple[int, int, int] | None = Field(None, description="supercell dimensions (s0, s1, s2)")
+    sf: float | None = Field(None, description="custom screening factor")
+    temperature: float | None = Field(None, description="temperature (K)")
+    t_limit: float | None = Field(None, description="time limit (ps)")
+    step_limit: int | None = Field(None, description="step limit")
+    step_log: int | None = Field(None, description="logging interval in steps")
+    cutoff: float | None = Field(None, description="Coulomb repulsion cutoff radius")
+
+class KMCResultVector(BaseModel):
+    """Represents a scalar plus a vector component (total, (x, y, z))"""
+    total: float
+    vector: tuple[float, float, float]
+
+class KMCSiteOccupancy(BaseModel):
+    """Occupancy information for a single site in the KMC results"""
+    site_name: str
+    occupancy: float
+    multiplicity: int
+
+class KMCResultProperties(BaseModel):
+    """Final physical properties parsed from the '--kmc' output"""
+    temperature: float | None = None
+    displacement: KMCResultVector | None = None
+    diffusivity: KMCResultVector | None = None
+    mobility: KMCResultVector | None = None
+    conductivity: KMCResultVector | None = None
+    site_occupancy_summary: list[KMCSiteOccupancy] = Field(default_factory=list)
+
+class KMCResult(BaseModel):
+    """Complete structured execution result of the '--kmc' command"""
+    command_used: str
+    working_directory: str
+    exit_status: int
+    final_properties: KMCResultProperties | None = Field(None, description="final computed KMC properties")
+    log_file: str = Field(description="path of the full log file on the remote server")
+    output_head: str = Field(description="head of the log file")
+    output_tail: str = Field(description="tail of the log file, containing the final results")
+    new_files: list[str] = Field(description="files newly created by the run")
+    elapsed_seconds: int = Field(description="total task time (seconds)")
+
 # ========= 2. Helper functions =========
 
 def shell_quote(arg: str) -> str:
@@ -342,7 +472,7 @@ async def _stat_size_safe(conn: asyncssh.SSHClientConnection, path: str) -> int
 
 def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
     """
-    Parse the raw stdout of '--cal-bv' into a structured CalBVResult model.
+    Parse the raw stdout of '--cal-bv' into the optimized CalBVResult model with summary and previews.
     """
     lines = raw_text.splitlines()
@@ -351,10 +481,9 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
     bv_sums = []
     bv_pairs = []
 
+    # ... (GII and table parsing logic unchanged) ...
     in_sum_section = False
     in_pair_section = False
-
-    # Parse GII and stability
     for line in lines:
         if line.startswith("GII: Global instability index ="):
             try:
@@ -363,12 +492,10 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
                 pass
         elif line.startswith("GII: suggested stability:"):
             try:
-                stability = line.split(":", 1)[-1].strip()
+                stability = line.split(":", 1)[-1].strip().rstrip('.')
             except IndexError:
                 pass
 
-    # Parse the tables
-    for line in lines:
         if line.startswith("BV: name type occ"):
             in_sum_section = True
             in_pair_section = False
@@ -379,52 +506,45 @@ def _parse_cal_bv_output(raw_text: str) -> CalBVResult:
             continue
         if line.startswith("BV:="):
             continue
-
         if not line.startswith("BV:"):
             in_sum_section = False
             in_pair_section = False
 
-        # Use regular expressions for more robust parsing of variable-width columns
         if in_sum_section:
             match = re.match(r"BV:\s+(\S+)\s+(\S+)\s+([\d.]+)\s+([-\d.]+)", line)
             if match:
                 try:
                     bv_sums.append(BVSumItem(
-                        name=match.group(1),
-                        type=match.group(2),
-                        occ=float(match.group(3)),
-                        bv_sum=float(match.group(4))
+                        name=match.group(1), type=match.group(2),
+                        occ=float(match.group(3)), bv_sum=float(match.group(4))
                     ))
                 except (ValueError, IndexError):
                     continue
         elif in_pair_section:
-            # This is a more complex line; splitting with shlex is safer
             try:
-                # Strip the "BV:" prefix and split
                 parts = line[3:].strip().split()
                 if len(parts) == 7:
                     bv_pairs.append(BVPairItem(
-                        name1=parts[0],
-                        type1=parts[1],
-                        occ1=float(parts[2]),
-                        name2=parts[3],
-                        type2=parts[4],
-                        occ2=float(parts[5]),
+                        name1=parts[0], type1=parts[1], occ1=float(parts[2]),
+                        name2=parts[3], type2=parts[4], occ2=float(parts[5]),
                         bv=float(parts[6])
                     ))
             except (ValueError, IndexError):
                 continue
 
+    # --- Key change: return a summary and previews instead of the full lists ---
     return CalBVResult(
         global_instability_index=gii,
         suggested_stability=stability,
-        bv_sums=bv_sums,
-        bv_pairs=bv_pairs,
-        raw_output=raw_text
+        total_bv_sums=len(bv_sums),
+        total_bv_pairs=len(bv_pairs),
+        bv_sums_preview=bv_sums[:20],    # return only the first 20 entries as a preview
+        bv_pairs_preview=bv_pairs[:20],  # return only the first 20 entries as a preview
+        raw_output_head="\n".join(lines[:50]),
+        raw_output_tail="\n".join(lines[-50:])
     )
 
-
 def _parse_print_cube_output(raw_text: str) -> PrintCubeResult:
     """
     Parse the raw stdout of '--print-cube' into a structured PrintCubeResult model.
     """
@@ -579,6 +699,91 @@ def _parse_cal_tot_en_output(raw_text: str) -> CalTotEnResult:
     )
 
 
+
+def _parse_gh_output(raw_text: str) -> PercolationThresholds | None:
+    """Parse the SOFTBV_GRAPH_NETWORK lines from the '--gh' stdout."""
+    thresholds = {}
+    # Regular expression matching the "e_type = value" format
+    pattern = re.compile(r"SOFTBV_GRAPH_NETWORK: (e(?:_1D|_2D|_3D|_loop)?)\s*=\s*([-\d.]+)")
+
+    for line in raw_text.splitlines():
+        match = pattern.search(line)
+        if match:
+            key = match.group(1)
+            value = float(match.group(2))
+
+            # 'e' is kept as-is to match the alias in the Pydantic model
+            if key == "e":
+                thresholds["e"] = value
+            else:
+                thresholds[key] = value
+
+    if not thresholds:
+        return None
+
+    return PercolationThresholds.model_validate(thresholds)
+
+
+def _parse_md_output(raw_text: str) -> MDResultProperties:
+    """Parse the final physical properties from the '--md' stdout."""
+    properties = {}
+    # Regular expression matching the "MD: key = value" format
+    pattern = re.compile(r"MD: (displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)")
+
+    for line in raw_text.splitlines():
+        match = pattern.search(line)
+        if match:
+            key = match.group(1)
+            value = float(match.group(2))
+            properties[key] = value
+
+    return MDResultProperties.model_validate(properties)
+
+# Add this in the helper-function section of softbv_mcp_refactored.py
+
+def _parse_kmc_output(raw_text: str) -> KMCResultProperties:
+    """Parse the final physical properties and site-occupancy info from the '--kmc' stdout."""
+    properties = {"site_occupancy_summary": []}
+    in_occupancy_summary = False
+
+    # Match lines of the form "key = total, (x, y, z)" or "key = total"
+    vector_pattern = re.compile(
+        r"KMC: (temperature|displacement|diffusivity|mobility|conductivity)\s*=\s*([-\d.eE]+)(?:,\s*\(([-\d.eE]+),\s*([-\d.eE]+),\s*([-\d.eE]+)\))?"
+    )
+    # Match site-occupancy lines
+    occupancy_pattern = re.compile(
+        r"KMC:\s*\[\s*([^\]]+)\]\s+([-\d.]+)\s+\(multiplicity\s*=\s*(\d+)\)"
+    )
+
+    for line in raw_text.splitlines():
+        if "KMC: summary of site-averaged occupancy:" in line:
+            in_occupancy_summary = True
+            continue
+
+        vec_match = vector_pattern.search(line)
+        if vec_match:
+            key, total_str, x_str, y_str, z_str = vec_match.groups()
+            total = float(total_str)
+            if key == 'temperature':
+                properties[key] = total
+            elif x_str is not None:
+                properties[key] = {
+                    "total": total,
+                    "vector": (float(x_str), float(y_str), float(z_str))
+                }
+            continue
+
+        if in_occupancy_summary:
+            occ_match = occupancy_pattern.search(line)
+            if occ_match:
+                site, occ, mult = occ_match.groups()
+                properties["site_occupancy_summary"].append({
+                    "site_name": site.strip(),
+                    "occupancy": float(occ),
+                    "multiplicity": int(mult)
+                })
+
+    return KMCResultProperties.model_validate(properties)
 # ========= 3. Lifecycle management =========
 
 @dataclass
@@ -865,6 +1070,8 @@ async def softbv_gen_cube(
 
 # In softbv_mcp_refactored.py, add this right after softbv_gen_cube
 
+# In softbv_mcp_refactored.py, find and replace the softbv_calculate_bv function
+
 @mcp.tool()
 async def softbv_calculate_bv(
     cif_path: str,
@@ -872,15 +1079,8 @@ async def softbv_calculate_bv(
     ctx: Context[ServerSession, SoftBVContext] | None = None,
 ) -> CalBVResult:
     """
-    Run 'softBV.x --cal-bv' to compute and return the bond valence sums.
-
-    Args:
-        cif_path (str): Path to the CIF file on the remote server (relative or absolute).
-        cwd (str | None): Remote working directory. Defaults to the sandbox directory if not given.
-        ctx: MCP context, injected automatically by the framework.
-
-    Returns:
-        CalBVResult: A structured object containing the parsed bond valence sums, the global instability index and the raw output.
+    Run 'softBV.x --cal-bv' to compute the bond valence sums.
+    Returns a structured object with a key summary and data previews, to keep the payload small.
     """
     if ctx is None:
         raise ValueError("Context is required for this operation.")
@@ -889,7 +1089,6 @@ async def softbv_calculate_bv(
     conn = app_ctx.ssh_connection
     workdir = cwd or app_ctx.workdir
 
-    # Build the path and command
     input_abs_path = cif_path
     if not input_abs_path.startswith("/"):
         input_abs_path = posixpath.normpath(posixpath.join(workdir, cif_path))
@@ -898,27 +1097,23 @@ async def softbv_calculate_bv(
 
     await ctx.info(f"Running command: {cmd} (working directory: {workdir})")
 
-    # Run the command
     proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)
 
     if proc.exit_status != 0:
         await ctx.warning(f"Command failed with exit code {proc.exit_status}")
-        # Even on failure, return a structure containing the error information
+        # --- Key change: on failure, return an empty structure matching the new model ---
+        raw_error = f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
         return CalBVResult(
-            global_instability_index=None,
-            suggested_stability=None,
-            bv_sums=[],
-            bv_pairs=[],
-            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
+            bv_sums_preview=[],
+            bv_pairs_preview=[],
+            raw_output_head=raw_error[:1000],  # truncated in case the error message is also very long
+            raw_output_tail=""
         )
 
-    # Parse the successful output and return structured data
+    # Parse the successful output and return structured data with summary and previews
     parsed_result = _parse_cal_bv_output(proc.stdout)
     return parsed_result
 
-
-# In softbv_mcp_refactored.py, add this right after softbv_calculate_bv
-
 @mcp.tool()
 async def softbv_print_cube(
     cube_path: str,
@@ -1144,6 +1339,310 @@ async def softbv_calculate_total_energy(
     # Parse and return the result
     return _parse_cal_tot_en_output(proc.stdout)
 
+
+@mcp.tool()
+async def softbv_analyze_pathway(
+    args: AnalyzePathwayArgs,
+    cwd: str | None = None,
+    ctx: Context[ServerSession, SoftBVContext] | None = None,
+) -> AnalyzePathwayResult:
+    """
+    Run 'softBV.x --gh' to analyze conduction pathways and return the percolation thresholds.
+
+    Args:
+        args (AnalyzePathwayArgs): Structured object holding all input parameters.
+        cwd (str | None): Remote working directory.
+        ctx: MCP context, injected automatically by the framework.
+
+    Returns:
+        AnalyzePathwayResult: Structured result with the percolation thresholds, newly created files and the raw output.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+
+    app_ctx = ctx.request_context.lifespan_context
+    conn = app_ctx.ssh_connection
+    workdir = cwd or app_ctx.workdir
+
+    # --- 1. Build the command ---
+    def get_abs(path: str) -> str:
+        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))
+
+    cmd_parts = [
+        shell_quote(app_ctx.bin_path),
+        shell_quote("--gh"),
+        shell_quote(get_abs(args.input_cif)),
+        shell_quote(get_abs(args.input_cube)),
+        shell_quote(args.type),
+        shell_quote(str(args.os)),
+    ]
+    if args.barrier_max is not None:
+        cmd_parts.append(shell_quote(str(args.barrier_max)))
+
+    periodic_flag = "t" if args.periodic else "f"
+    cmd_parts.append(shell_quote(periodic_flag))
+
+    cmd = " ".join(cmd_parts)
+
+    await ctx.info(f"Running pathway analysis: {cmd}")
+
+    # --- 2. Run and collect the results ---
+    files_before = await _listdir_safe(conn, workdir)
+    proc = await run_in_softbv_env(conn, app_ctx.profile, cmd=cmd, cwd=workdir, check=False)
+    files_after = await _listdir_safe(conn, workdir)
+
+    new_files = sorted(set(files_after) - set(files_before))
+
+    if proc.exit_status != 0:
+        await ctx.warning(f"'--gh' command failed with exit code {proc.exit_status}")
+        return AnalyzePathwayResult(
+            command_used=cmd,
+            working_directory=workdir,
+            exit_status=proc.exit_status,
+            thresholds=None,
+            new_files=new_files,
+            raw_output=f"Exit Status: {proc.exit_status}\n\nSTDOUT:\n{proc.stdout}\n\nSTDERR:\n{proc.stderr}"
+        )
+
+    # --- 3. Parse and return the structured result ---
+    parsed_thresholds = _parse_gh_output(proc.stdout)
+    if not parsed_thresholds:
+        await ctx.info("No percolation threshold information was found in the output.")
+
+    return AnalyzePathwayResult(
+        command_used=cmd,
+        working_directory=workdir,
+        exit_status=proc.exit_status,
+        thresholds=parsed_thresholds,
+        new_files=new_files,
+        raw_output=proc.stdout
+    )
+
+
+# Add this function in the tool section of softbv_mcp_refactored.py
+
+@mcp.tool(name="softbv_run_md")
+async def softbv_run_md(
+    args: MDDefaultArgs | MDFullArgs,
+    cwd: str | None = None,
+    ctx: Context[ServerSession, SoftBVContext] | None = None,
+) -> MDResult:
+    """
+    Run 'softBV.x --md' for a molecular dynamics calculation. This is a long-running task.
+
+    Two invocation modes are supported:
+    1. Default mode (MDDefaultArgs): provide only the CIF, the ion type and the oxidation state.
+    2. Full-argument mode (MDFullArgs): provide all MD parameters.
+
+    The tool reports a periodic heartbeat and, once the task finishes, returns a structured
+    result containing the conductivity and other key information.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+
+    app_ctx = ctx.request_context.lifespan_context
+    conn = app_ctx.ssh_connection
+    workdir = cwd or app_ctx.workdir
+
+    # --- 1. Build the command ---
+    input_abs_path = args.input_cif
+    if not input_abs_path.startswith("/"):
+        input_abs_path = posixpath.normpath(posixpath.join(workdir, args.input_cif))
+
+    cmd_parts = [
+        shell_quote(app_ctx.bin_path),
+        shell_quote("--md"),
+        shell_quote(input_abs_path),
+        shell_quote(args.type),
+        shell_quote(str(args.os)),
+    ]
+
+    if isinstance(args, MDFullArgs):
+        cmd_parts.extend([
+            shell_quote(str(args.sf)),
+            shell_quote(str(args.temperature)),
+            shell_quote(str(args.t_end)),
+            shell_quote(str(args.t_equil)),
+            shell_quote(str(args.dt)),
+            shell_quote(str(args.t_log)),
+        ])
+
+    # --- 2. Prepare the long-running task ---
+    import time
+    log_name = f"softbv_md_{int(time.time())}.log"
+    log_path = posixpath.join(workdir, log_name)
+
+    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"
+
+    await ctx.info(f"Starting long-running MD task: {final_cmd} (working directory: {workdir})")
+
+    files_before = await _listdir_safe(conn, workdir)
+    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
+    start_time = time.monotonic()
+
+    # --- 3. Heartbeat loop to avoid timeouts ---
+    try:
+        while proc.exit_status is None:
+            await asyncio.sleep(10)  # report every 10 seconds
+
+            elapsed = int(time.monotonic() - start_time)
+            log_size = await _stat_size_safe(conn, log_path)
+
+            # Use ctx.report_progress as a heartbeat mechanism
+            await ctx.report_progress(
+                progress=min(elapsed / (30 * 60), 0.99),  # approximate progress against a 30-minute baseline
+                message=f"MD calculation running... {elapsed} s elapsed, log size: {log_size or 0} bytes."
+            )
+    finally:
+        await proc.wait()  # make sure the process has finished
+
+    # --- 4. Collect and return the results ---
+    elapsed_seconds = int(time.monotonic() - start_time)
+    files_after = await _listdir_safe(conn, workdir)
+    new_files = sorted(set(files_after) - set(files_before))
+
+    # Read the head and tail of the log file
+    log_head, log_tail, full_log = "", "", ""
+    try:
+        async with conn.start_sftp_client() as sftp:
+            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
+                lines = await f.readlines()
+                full_log = "".join(lines)
+                log_head = "".join(lines[:40])
+                log_tail = "".join(lines[-40:])
+    except Exception as e:
+        await ctx.warning(f"Failed to read log file '{log_path}': {e}")
+
+    # Parse the final properties from the log
+    final_properties = _parse_md_output(full_log)
+
+    result = MDResult(
+        command_used=final_cmd,
+        working_directory=workdir,
+        exit_status=proc.exit_status,
+        final_properties=final_properties,
+        log_file=log_path,
+        output_head=log_head,
+        output_tail=log_tail,
+        new_files=new_files,
+        elapsed_seconds=elapsed_seconds,
+    )
+
+    if result.exit_status == 0:
+        await ctx.info(f"MD task completed successfully in {elapsed_seconds} s.")
+    else:
+        await ctx.error(f"MD task failed with exit code {result.exit_status}. Check the log: {log_path}")
+
+    return result
+
+
+# Add this function in the tool section of softbv_mcp_refactored.py
+
+@mcp.tool(name="softbv_run_kmc")
+async def softbv_run_kmc(
+    args: KMCDefaultArgs | KMCFullArgs,
+    cwd: str | None = None,
+    ctx: Context[ServerSession, SoftBVContext] | None = None,
+) -> KMCResult:
+    """
+    Run 'softBV.x --kmc' for a kinetic Monte Carlo calculation. This is a long-running task.
+
+    Two invocation modes are supported: default mode and full-argument mode.
+    """
+    if ctx is None:
+        raise ValueError("Context is required for this operation.")
+
+    app_ctx = ctx.request_context.lifespan_context
+    conn = app_ctx.ssh_connection
+    workdir = cwd or app_ctx.workdir
+
+    # --- 1. Build the command ---
+    def get_abs(path: str) -> str:
+        return path if path.startswith("/") else posixpath.normpath(posixpath.join(workdir, path))
+
+    cmd_parts = [
+        shell_quote(app_ctx.bin_path),
+        shell_quote("--kmc"),
+        shell_quote(get_abs(args.input_cif)),
+        shell_quote(get_abs(args.input_cube)),
+        shell_quote(args.type),
+        shell_quote(str(args.os)),
+    ]
+
+    if isinstance(args, KMCFullArgs):
+        if args.supercell:
+            cmd_parts.extend(map(lambda x: shell_quote(str(x)), args.supercell))
+
+        # Handle all the optional float/int parameters with a single loop
+        for param in ['sf', 'temperature', 't_limit', 'step_limit', 'step_log', 'cutoff']:
+            value = getattr(args, param, None)
+            if value is not None:
+                cmd_parts.append(shell_quote(str(value)))
+
+    # --- 2. Prepare the long-running task (same as gen-cube and md) ---
+    import time
+    log_name = f"softbv_kmc_{int(time.time())}.log"
+    log_path = posixpath.join(workdir, log_name)
+
+    final_cmd = " ".join(cmd_parts) + f" > {shell_quote(log_path)} 2>&1"
+
+    await ctx.info(f"Starting long-running KMC task: {final_cmd} (working directory: {workdir})")
+
+    files_before = await _listdir_safe(conn, workdir)
+    proc = await _run_in_softbv_env_stream(conn, app_ctx.profile, cmd=final_cmd, cwd=workdir)
+    start_time = time.monotonic()
+
+    # --- 3. Heartbeat loop (same as gen-cube and md) ---
+    try:
+        while proc.exit_status is None:
+            await asyncio.sleep(10)
+            elapsed = int(time.monotonic() - start_time)
+            log_size = await _stat_size_safe(conn, log_path)
+            await ctx.report_progress(
+                progress=min(elapsed / (30 * 60), 0.99),
+                message=f"KMC calculation running... {elapsed} s elapsed, log size: {log_size or 0} bytes."
+            )
+    finally:
+        await proc.wait()
+
+    # --- 4. Collect and return the results ---
+    elapsed_seconds = int(time.monotonic() - start_time)
+    files_after = await _listdir_safe(conn, workdir)
+    new_files = sorted(set(files_after) - set(files_before))
+
+    log_head, log_tail, full_log = "", "", ""
+    try:
+        async with conn.start_sftp_client() as sftp:
+            async with sftp.open(log_path, "r", encoding="utf-8", errors="replace") as f:
+                lines = await f.readlines()
+                full_log = "".join(lines)
+                log_head = "".join(lines[:40])
+                log_tail = "".join(lines[-40:])
+    except Exception as e:
+        await ctx.warning(f"Failed to read log file '{log_path}': {e}")
+
+    # Parse the final properties from the full log
+    final_properties = _parse_kmc_output(full_log)
+
+    result = KMCResult(
+        command_used=final_cmd,
+        working_directory=workdir,
+        exit_status=proc.exit_status,
+        final_properties=final_properties,
+        log_file=log_path,
+        output_head=log_head,
+        output_tail=log_tail,
+        new_files=new_files,
+        elapsed_seconds=elapsed_seconds,
+    )
+
+    if result.exit_status == 0:
+        await ctx.info(f"KMC task completed successfully in {elapsed_seconds} s.")
+    else:
+        await ctx.error(f"KMC task failed with exit code {result.exit_status}. Check the log: {log_path}")
+
+    return result
+
+
 # ========= 6. Factory function and main entry point =========
 
 def create_softbv_mcp() -> FastMCP: