Compare commits

...

19 Commits

Author SHA1 Message Date
95d719cc1e 对比学习法增改 2025-10-29 11:39:30 +08:00
1f8667ae51 新增了zeo++的部分 2025-10-23 17:08:57 +08:00
c0b2ec5983 sofvBV_mcp重构v2
Embedding copy
2025-10-22 23:59:23 +08:00
b9ba79d7a8 sofvBV_mcp重构 2025-10-22 16:52:56 +08:00
bd4bd3a645 LiYCl
朱一舟
2025-10-16 16:49:33 +08:00
c20d752faa LiYCl
朱一舟
2025-10-16 16:22:13 +08:00
4e21954471 mcp-softBV 2025-10-16 14:57:27 +08:00
fc7716507a mcp-system-fix 2025-10-15 16:19:41 +08:00
b11bf2417b mcp-gen-material 2025-10-15 14:43:57 +08:00
9b0f835575 mcp
成功使用安全版本实现系统相关操作
2025-10-14 09:52:02 +08:00
b19352382b mcp-python 2025-10-13 15:37:31 +08:00
28639f9cbf Merge remote-tracking branch 'origin/master' 2025-10-13 13:09:22 +08:00
41a6038e50 mcp-uvicorn 2025-10-13 13:09:06 +08:00
f625154aee mcp-python 2025-10-13 10:08:59 +08:00
f585d76cac mcp-change 2025-10-09 15:40:49 +08:00
c8629619ee mcp 2025-10-09 13:33:26 +08:00
5d1a4d04f2 mcp 2025-10-09 09:43:34 +08:00
e6141689c1 softBV-calulate-first 2025-09-24 11:57:52 +08:00
efcdacffd0 Corner-sharing 2025-09-23 19:39:54 +08:00
34 changed files with 5568 additions and 11 deletions

95
Screen/process_txt.py Normal file
View File

@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
import os
import re
import sys
import csv
def extract_data_from_folder(folder_path):
    """Scan every .txt file in *folder_path*, extract three metrics and
    write them to a CSV named after the folder (in the current directory).

    Extracted per file (empty string when a pattern does not match):
      1. "Percolation diameter (A): <float>"
      2. "the minium of d <float> #"  (sic -- matches the tool's output)
      3. "Maximum node length detected: <float> A"

    :param folder_path: directory containing the .txt files to process.
    """
    if not os.path.isdir(folder_path):
        print(f"错误:文件夹 '{folder_path}' 不存在。")
        return
    # Pre-compile the three extraction patterns (hoisted out of the loop).
    pattern1 = re.compile(r"Percolation diameter \(A\): ([\d\.]+)")
    # NOTE: "minium" is a typo in the upstream tool's output; it must be
    # kept so the regex still matches the actual file content.
    pattern2 = re.compile(r"the minium of d\s*([\d\.]+)\s*#")
    pattern3 = re.compile(r"Maximum node length detected: ([\d\.]+) A")
    all_data = []
    # sorted() keeps the processing (and CSV row) order deterministic.
    for filename in sorted(os.listdir(folder_path)):
        if filename.endswith(".txt"):
            txt_path = os.path.join(folder_path, filename)
            try:
                with open(txt_path, 'r', encoding='utf-8') as file:
                    content = file.read()
                match1 = pattern1.search(content)
                match2 = pattern2.search(content)
                match3 = pattern3.search(content)
                # Missing values are recorded as empty strings.
                val1 = match1.group(1) if match1 else ''
                val2 = match2.group(1) if match2 else ''
                val3 = match3.group(1) if match3 else ''
                # File name without the .txt suffix identifies the row.
                base_filename = os.path.splitext(filename)[0]
                all_data.append([base_filename, val1, val2, val3])
            except Exception as e:
                # BUGFIX: the message previously printed a placeholder
                # instead of saying WHICH file failed.
                print(f"处理文件 {txt_path} 时发生错误: {e}")
    # No .txt file / no data at all: do not create an empty CSV.
    if not all_data:
        print("未在文件夹中找到任何 .txt 文件或未能提取任何数据。")
        return
    # The CSV is named after the (normalized) folder.
    folder_name = os.path.basename(os.path.normpath(folder_path))
    csv_filename = f"{folder_name}.csv"
    try:
        with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            headers = ['filename', 'Percolation_diameter_A', 'minium_of_d', 'Max_node_length_A']
            writer.writerow(headers)
            writer.writerows(all_data)
        print(f"数据成功写入到文件: {csv_filename}")
    except Exception as e:
        print(f"写入CSV文件 {csv_filename} 时发生错误: {e}")
if __name__ == "__main__":
    # Script entry point: expect exactly one argument, the folder path.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("用法: python your_script_name.py <folder_path>")
        sys.exit(1)
    extract_data_from_folder(cli_args[0])

151
contrast learning/copy.py Normal file
View File

@@ -0,0 +1,151 @@
import shutil
from pathlib import Path
def find_element_column_index(cif_lines: list) -> int:
    """Locate the column index of ``_atom_site_type_symbol`` in a CIF loop_.

    :param cif_lines: lines read from a CIF file.
    :return: zero-based index of the element-symbol column, or -1 when no
             loop_ header declaring ``_atom_site_type_symbol`` is found.
    """
    in_loop_header = False
    column_index = -1
    current_column = 0
    for line in cif_lines:
        line_stripped = line.strip()
        if not line_stripped:
            continue
        if line_stripped.startswith('loop_'):
            # A new loop_ resets the header bookkeeping.
            in_loop_header = True
            column_index = -1
            current_column = 0
            continue
        if in_loop_header:
            if line_stripped.startswith('_'):
                if line_stripped.startswith('_atom_site_type_symbol'):
                    column_index = current_column
                current_column += 1
            else:
                # Header definitions ended; the data rows begin here.
                return column_index
    return -1  # no loop_ / _atom_site_type_symbol found at all


def copy_cif_with_O_or_S_robust(source_dir: str, target_dir: str, dry_run: bool = False):
    """Copy CIF files whose atom list contains O or S into *target_dir*.

    The element column is located via :func:`find_element_column_index`;
    for CIFs without an atom loop_ a fallback check on the chemical-formula
    tags is used.

    :param source_dir: directory holding the candidate .cif files.
    :param target_dir: directory that receives the matching files
                       (created on demand unless *dry_run*).
    :param dry_run: when True, only report what would be copied.
    """
    # 1. Path handling and validation.
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    if not source_path.is_dir():
        print(f"错误:源文件夹 '{source_dir}' 不存在或不是一个文件夹。")
        return
    if not dry_run and not target_path.exists():
        target_path.mkdir(parents=True, exist_ok=True)
        print(f"目标文件夹 '{target_dir}' 已创建。")
    print(f"源文件夹: {source_path}")
    print(f"目标文件夹: {target_path}")
    if dry_run:
        print("\n--- *** 模拟运行模式 (Dry Run) *** ---")
        print("--- 不会执行任何实际的文件复制操作 ---")
    # 2. Scan and filter. Switch to rglob('*.cif') to recurse into
    #    sub-directories if needed.
    print("\n开始扫描源文件夹中的CIF文件...")
    copied_count = 0
    checked_files = 0
    error_files = 0
    for file_path in source_path.glob('*.cif'):
        if file_path.is_file():
            checked_files += 1
            try:
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    lines = f.readlines()
                # Step A: locate the element-symbol column.
                element_col_idx = find_element_column_index(lines)
                # BUGFIX: found_simple was previously only assigned inside
                # the element_col_idx == -1 branch, so the reference below
                # raised NameError (caught and miscounted as a file error)
                # for every CIF that DOES have a proper atom loop_.
                found_simple = False
                if element_col_idx == -1:
                    # Some CIFs have no loop_ block, just key-value tags;
                    # fall back to a simple formula check.
                    found_simple = any(
                        line.strip().startswith(('_chemical_formula_sum', '_chemical_formula_structural')) and (
                            ' O' in line or ' S' in line) for line in lines)
                    if not found_simple:
                        continue  # neither method found anything; skip file
                # Step B: check the element column for 'O' or 'S'.
                found = False
                for line in lines:
                    line_stripped = line.strip()
                    # Skip blanks, comments and tag/loop definition lines.
                    if not line_stripped or line_stripped.startswith(('#', '_', 'loop_')):
                        continue
                    parts = line_stripped.split()
                    # Make sure the row has enough columns.
                    if len(parts) > element_col_idx:
                        # Symbols may carry a charge (e.g. O2-) elsewhere;
                        # here an exact 'O' / 'S' match is required.
                        atom_symbol = parts[element_col_idx].strip()
                        if atom_symbol == 'O' or atom_symbol == 'S':
                            found = True
                            break
                # Fallback path already matched -> treat as found.
                if found_simple:
                    found = True
                if found:
                    target_file_path = target_path / file_path.name
                    print(f"找到匹配: '{file_path.name}' (含有 O 或 S 元素)")
                    if not dry_run:
                        shutil.copy2(file_path, target_file_path)
                    copied_count += 1
            except Exception as e:
                error_files += 1
                print(f"!! 处理文件 '{file_path.name}' 时发生错误: {e}")
    # 3. Final report.
    print("\n--- 操作总结 ---")
    print(f"共检查了 {checked_files} 个.cif文件。")
    if error_files > 0:
        print(f"处理过程中有 {error_files} 个文件发生错误。")
    if dry_run:
        print(f"模拟运行结束:如果实际运行,将会有 {copied_count} 个文件被复制。")
    else:
        print(f"成功复制了 {copied_count} 个文件到目标文件夹。")
if __name__ == '__main__':
    # !! Adjust these two paths to your own machine before running.
    src_dir = "D:/download/2025-10/data_all/input/input"
    dst_dir = "D:/download/2025-10/data_all/output"
    # Pass 1: dry run, report only.
    print("================ 第一次运行: 模拟模式 ================")
    copy_cif_with_O_or_S_robust(src_dir, dst_dir, dry_run=True)
    print("\n\n=======================================================")
    input("检查上面的模拟运行结果。如果符合预期,按回车键继续执行实际复制操作...")
    print("=======================================================")
    # Pass 2: the actual copy.
    print("\n================ 第二次运行: 实际复制模式 ================")
    copy_cif_with_O_or_S_robust(src_dir, dst_dir, dry_run=False)

111
contrast learning/delete.py Normal file
View File

@@ -0,0 +1,111 @@
import shutil
from pathlib import Path
def delete_duplicates_from_second_folder(source_dir: str, target_dir: str, dry_run: bool = False):
    """Remove items in *target_dir* whose names also appear in *source_dir*.

    Files are unlinked; directories are removed recursively.

    :param source_dir: reference folder; its entry names drive the deletion.
    :param target_dir: folder that is cleaned up.
    :param dry_run: when True, only report what would be deleted.
    """
    # Work with Path objects from here on.
    src = Path(source_dir)
    dst = Path(target_dir)
    # Validate both directories up front.
    if not src.is_dir():
        print(f"错误:源文件夹 '{source_dir}' 不存在或不是一个文件夹。")
        return
    if not dst.is_dir():
        print(f"错误:目标文件夹 '{target_dir}' 不存在或不是一个文件夹。")
        return
    print(f"源文件夹: {src}")
    print(f"目标文件夹: {dst}")
    if dry_run:
        print("\n--- *** 模拟运行模式 (Dry Run) *** ---")
        print("--- 不会执行任何实际的删除操作 ---")
    # Collect the names (files and sub-folders) present in the source.
    reference_names = {entry.name for entry in src.iterdir()}
    if not reference_names:
        print("\n源文件夹为空,无需执行任何操作。")
        return
    print(f"\n在源文件夹中找到 {len(reference_names)} 个项目。")
    print("开始检查并删除目标文件夹中的同名项目...")
    removed = 0
    for name in reference_names:
        # The same-named entry that may exist in the target folder.
        candidate = dst / name
        if not candidate.exists():
            continue
        try:
            if candidate.is_file():
                print(f"准备删除文件: {candidate}")
                if not dry_run:
                    candidate.unlink()
                    print(" -> 已删除。")
                removed += 1
            elif candidate.is_dir():
                # Remove the whole directory tree.
                print(f"准备删除文件夹及其所有内容: {candidate}")
                if not dry_run:
                    shutil.rmtree(candidate)
                    print(" -> 已删除。")
                removed += 1
        except Exception as e:
            print(f"!! 删除 '{candidate}' 时发生错误: {e}")
    if removed == 0:
        print("\n操作完成:在目标文件夹中没有找到需要删除的同名项目。")
    elif dry_run:
        print(f"\n模拟运行结束:如果实际运行,将会有 {removed} 个项目被删除。")
    else:
        print(f"\n操作完成:总共删除了 {removed} 个项目。")
# --- 使用示例 ---
# 在运行前,请创建以下文件夹和文件结构进行测试:
# /your/path/folder1/
# ├── file_a.txt
# ├── file_b.log
# └── subfolder_x/
# └── test.txt
# /your/path/folder2/
# ├── file_a.txt (将被删除)
# ├── file_c.md
# └── subfolder_x/ (将被删除)
# └── another.txt
if __name__ == '__main__':
    # !! Adjust these two paths to your own machine before running.
    reference_dir = "D:/download/2025-10/after_step5/after_step5/S"  # source
    cleanup_dir = "D:/download/2025-10/input/input"  # target
    # Pass 1: dry run -- strongly recommended; shows what WOULD be deleted
    # without touching anything.
    print("================ 第一次运行: 模拟模式 ================")
    delete_duplicates_from_second_folder(reference_dir, cleanup_dir, dry_run=True)
    print("\n\n=======================================================")
    input("检查上面的模拟运行结果。如果符合预期,按回车键继续执行实际删除操作...")
    print("=======================================================")
    # Pass 2: the actual deletion, once the dry run looked right.
    print("\n================ 第二次运行: 实际删除模式 ================")
    delete_duplicates_from_second_folder(reference_dir, cleanup_dir, dry_run=False)

92
corner-sharing/0923_CS.py Normal file
View File

@@ -0,0 +1,92 @@
import os
import csv
from pymatgen.core import Structure
from tqdm import tqdm # 引入tqdm来显示进度条如果未安装请运行 pip install tqdm
# --- 导入您自定义的分析函数 ---
# 假设您的函数存放在 'utils/CS_analyse.py' 文件中
# 并且您已经将它们重命名
# from calculate_polyhedra_sharing import calculate_polyhedra_sharing as CS_catulate
# from check_only_corner_sharing import check_only_corner_sharing
# 注意:根据您的描述,您会将函数放在 utils 文件夹中,因此导入方式如下:
from utils.CS_analyse import CS_catulate, check_only_corner_sharing
def process_cif_folder(cif_folder_path: str, output_csv_path: str):
    """Evaluate corner-sharing for every CIF in a folder and write a CSV.

    For each ``.cif`` file the structure is loaded, the polyhedra-sharing
    statistics are computed with ``CS_catulate`` and reduced to a 0/1
    "only corner-sharing" flag by ``check_only_corner_sharing``.

    :param cif_folder_path: directory containing the .cif files.
    :param output_csv_path: path of the CSV report to write.
    """
    if not os.path.isdir(cif_folder_path):
        print(f"错误: 文件夹 '{cif_folder_path}' 不存在。")
        return
    # One [filename, flag-or-'Error'] row per CIF file.
    results = []
    try:
        cif_files = [f for f in os.listdir(cif_folder_path) if f.endswith('.cif')]
        if not cif_files:
            print(f"警告: 在文件夹 '{cif_folder_path}' 中没有找到任何 .cif 文件。")
            return
    except FileNotFoundError:
        print(f"错误: 无法访问文件夹 '{cif_folder_path}'")
        return
    print(f"开始处理 {len(cif_files)} 个CIF文件...")
    # tqdm gives a progress bar over the (potentially long) file list.
    for filename in tqdm(cif_files, desc="Processing CIFs"):
        file_path = os.path.join(cif_folder_path, filename)
        try:
            # 1. Load the structure from the CIF file.
            struct = Structure.from_file(file_path)
            # 2. Compute detailed sharing statistics for Li polyhedra
            #    against the listed anions.
            sharing_details = CS_catulate(struct, sp='Li', anion=['O', 'S', 'Cl', 'F', 'Br'])
            # 3. Reduce to the final 0/1 corner-sharing-only verdict.
            is_only_corner = check_only_corner_sharing(sharing_details)
            results.append([filename, is_only_corner])
        except Exception as e:
            # BUGFIX: the message previously printed a placeholder instead
            # of the failing file name; keep going with the next file.
            print(f"\n处理文件 '{filename}' 时发生错误: {e}")
            results.append([filename, 'Error'])  # flag the error in the CSV
    # 4. Write the report.
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['CIF_File', 'Is_Only_Corner_Sharing'])
            writer.writerows(results)
        print(f"\n处理完成!结果已保存到 '{output_csv_path}'")
    except IOError as e:
        print(f"\n错误: 无法写入CSV文件 '{output_csv_path}': {e}")
# --- Script entry point ---
if __name__ == "__main__":
    # ----- Configuration -----
    # Folder holding the input CIF files; adjust to your local layout.
    CIF_DIRECTORY = "data/0921"
    # Name of the CSV report to produce.
    OUTPUT_CSV = "corner_sharing_results.csv"
    # -------------------------
    process_cif_folder(CIF_DIRECTORY, OUTPUT_CSV)

View File

@@ -1,3 +1,5 @@
from typing import Dict, List, Optional

import numpy as np
from pymatgen.analysis.local_env import VoronoiNN
from pymatgen.core.structure import Structure
@@ -24,7 +26,132 @@ def special_check_for_3(site, nearest):
return real_nearest
def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=None):
def CS_catulate(
        struct,
        sp: str = 'Li',
        anion: Optional[List[str]] = None,
        tol: float = 0,
        cutoff: float = 3.0,
        notice: bool = False
) -> Dict[str, Dict[str, int]]:
    """Count shared anions (corner/edge/face sharing) between cation polyhedra.

    Three pairings are tallied separately:
      1. target vs target       (e.g. Li-Li) -> "sp_vs_sp"
      2. target vs other cation (e.g. Li-X)  -> "sp_vs_other"
      3. other vs other cation  (e.g. X-Y)   -> "other_vs_other"

    Sharing 1 anion is corner-, 2 edge-, 3 face-sharing.

    :param struct: pymatgen Structure to analyse.
    :param sp: target element symbol, default 'Li'.
    :param anion: anion element symbols; defaults to ['O'].
    :param tol: VoronoiNN tolerance (0 is the usual choice for Li).
    :param cutoff: VoronoiNN cutoff distance (3.0 is the usual choice for Li).
    :param notice: print each detected sharing pair when True.
    :return: dict of the three pairing categories, each mapping the
             shared-anion count ("1".."4") to its number of occurrences,
             e.g. {'sp_vs_sp': {'1': 10, '2': 4, ...}, 'sp_vs_other': ..., 'other_vs_other': ...}
    """
    # BUGFIX: mutable default argument (['O']) replaced by a None sentinel;
    # the effective default is unchanged.
    if anion is None:
        anion = ['O']
    voro_nn = VoronoiNN(tol=tol, cutoff=cutoff)
    # 1. Collect each cation's neighbouring-anion index set, split into
    #    target (sp) sites and all other cation sites.
    target_sites_info = []
    other_cation_sites_info = []
    for index, site in enumerate(struct.sites):
        # Skip anion sites themselves.
        if site.species.chemical_system in anion:
            continue
        try:
            # get_nn_info is the most direct way to obtain the neighbours.
            nn_info = voro_nn.get_nn_info(struct, index)
            nearest_anions = [
                nn["site"] for nn in nn_info
                if nn["site"].species.chemical_system in anion
            ]
        except Exception as e:
            print(f"Warning: Could not get neighbors for site {index} ({site.species_string}): {e}")
            continue
        if not nearest_anions:
            continue
        site_info = {
            'index': index,
            'element': site.species.chemical_system,
            'nearest_anion_indices': {nn.index for nn in nearest_anions}
        }
        # Classify by whether this is the target element.
        if site.species.chemical_system == sp:
            target_sites_info.append(site_info)
        else:
            other_cation_sites_info.append(site_info)
    # 2. Result buckets; keys are shared-anion counts (1=corner, 2=edge, 3=face).
    results = {
        "sp_vs_sp": {"1": 0, "2": 0, "3": 0, "4": 0},
        "sp_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
        "other_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
    }
    # 3.1 target vs target (sp_vs_sp)
    for i in range(len(target_sites_info)):
        for j in range(i + 1, len(target_sites_info)):
            atom_i = target_sites_info[i]
            atom_j = target_sites_info[j]
            shared_anions = atom_i['nearest_anion_indices'].intersection(atom_j['nearest_anion_indices'])
            shared_count = len(shared_anions)
            if shared_count > 0 and str(shared_count) in results["sp_vs_sp"]:
                results["sp_vs_sp"][str(shared_count)] += 1
                if notice:
                    print(
                        f"[Li-Li] Atom {atom_i['index']} and {atom_j['index']} share {shared_count} anions: {shared_anions}")
    # 3.2 target vs other cations (sp_vs_other)
    for atom_sp in target_sites_info:
        for atom_other in other_cation_sites_info:
            shared_anions = atom_sp['nearest_anion_indices'].intersection(atom_other['nearest_anion_indices'])
            shared_count = len(shared_anions)
            if shared_count > 0 and str(shared_count) in results["sp_vs_other"]:
                results["sp_vs_other"][str(shared_count)] += 1
                if notice:
                    print(
                        f"[Li-Other] Atom {atom_sp['index']} and {atom_other['index']} share {shared_count} anions: {shared_anions}")
    # 3.3 other cations vs other cations (other_vs_other)
    for i in range(len(other_cation_sites_info)):
        for j in range(i + 1, len(other_cation_sites_info)):
            atom_i = other_cation_sites_info[i]
            atom_j = other_cation_sites_info[j]
            shared_anions = atom_i['nearest_anion_indices'].intersection(atom_j['nearest_anion_indices'])
            shared_count = len(shared_anions)
            if shared_count > 0 and str(shared_count) in results["other_vs_other"]:
                results["other_vs_other"][str(shared_count)] += 1
                if notice:
                    print(
                        f"[Other-Other] Atom {atom_i['index']} and {atom_j['index']} share {shared_count} anions: {shared_anions}")
    return results
def CS_catulate_old(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=None):
"""
计算结构中目标元素与最近阴离子的共享关系。
@@ -51,10 +178,10 @@ def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=
# 遍历结构中的每个位点
for index,site in enumerate(struct.sites):
# 跳过阴离子位点
if site.specie.symbol in anion:
if site.species.chemical_system in anion:
continue
# 跳过Li原子
if site.specie.symbol == sp:
if site.species.chemical_system == sp:
continue
# 获取 Voronoi 多面体信息
voro_info = voro_nn.get_voronoi_polyhedra(struct, index)
@@ -62,14 +189,14 @@ def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=
# 找到最近的阴离子位点
nearest_anions = [
nn_info["site"] for nn_info in voro_info.values()
if nn_info["site"].specie.symbol in anion
if nn_info["site"].species.chemical_system in anion
]
# 如果没有找到最近的阴离子,跳过
if not nearest_anions:
print(f"No nearest anions found for {ID} site {index}.")
continue
if site.specie.symbol == 'B' or site.specie.symbol == 'N':
if site.species.chemical_system == 'B' or site.species.chemical_system == 'N':
nearest_anions = special_check_for_3(site,nearest_anions)
nearest_anions = check_real(nearest_anions)
# 将结果添加到 atom_dice 列表中
@@ -110,10 +237,62 @@ def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=
return shared_count
def CS_count(struct, shared_count, sp='Li'):
def CS_count(struct, sharing_results: Dict[str, Dict[str, int]], sp: str = 'Li') -> float:
    """Average number of shared anions per target (*sp*) atom.

    Companion of the polyhedra-sharing calculation: "sp_vs_sp" pairs
    involve two target atoms (weight 2), "sp_vs_other" pairs involve one
    (weight 1); "other_vs_other" pairs involve no target atom and are
    ignored.

    :param struct: pymatgen Structure, used to count the target atoms.
    :param sharing_results: output of the sharing calculation.
    :param sp: target element symbol, default 'Li'.
    :return: weighted shared-anion total divided by the target-atom count;
             0.0 when the structure has no target atoms (avoids div by 0).
    """
    # 1. Number of target atoms in the structure.
    n_targets = sum(
        1 for site in struct.sites
        if site.species.chemical_system == sp
    )
    if n_targets == 0:
        return 0.0
    # 2. Weighted total: shared-anion count * target atoms involved * occurrences.
    weighted_total = 0
    for bucket, weight in (("sp_vs_sp", 2), ("sp_vs_other", 1)):
        for shared_str, pair_count in sharing_results.get(bucket, {}).items():
            weighted_total += int(shared_str) * weight * pair_count
    # 3. Normalise by the number of target atoms.
    return weighted_total / n_targets
def CS_count_old(struct, shared_count, sp='Li'):
count = 0
for site in struct.sites:
if site.specie.symbol == sp:
if site.species.chemical_system == sp:
count += 1 # 累加符合条件的原子数量
CS_count = 0
@@ -128,7 +307,50 @@ def CS_count(struct, shared_count, sp='Li'):
return CS_count
structure = Structure.from_file("../data/0921/wjy_001.cif")
a = CS_catulate(structure,notice=True)
b = CS_count(structure,a)
print(f"{a}\n{b}")
def check_only_corner_sharing(sharing_results: Dict[str, Dict[str, int]]) -> int:
    """Return 1 when the target atoms share anions by corners only.

    Companion of the polyhedra-sharing calculation. Edge (2), face (3)
    or higher sharings involving the target atom must all be absent, and
    at least one corner sharing (1) must be present.

    :param sharing_results: output dict of the sharing calculation.
    :return: 1 for "corner sharing only"; 0 when any edge/face sharing
             exists, or when there is no sharing at all.
    """
    # Only the buckets that involve the target atom matter here.
    sp_buckets = (
        sharing_results.get("sp_vs_sp", {}),
        sharing_results.get("sp_vs_other", {}),
    )
    # 1. Any edge/face/... sharing (count key > 1) disqualifies immediately.
    for bucket in sp_buckets:
        if any(int(key) > 1 and occurrences > 0 for key, occurrences in bucket.items()):
            return 0
    # 2. No higher-order sharing left; require at least one corner contact,
    #    otherwise there is no sharing at all.
    has_corner = any(bucket.get("1", 0) > 0 for bucket in sp_buckets)
    return 1 if has_corner else 0
# structure = Structure.from_file("../data/0921/wjy_001.cif")
# a = CS_catulate(structure,notice=True)
# b = CS_count(structure,a)
# print(f"{a}\n{b}")
# print(check_only_corner_sharing(a))

View File

@@ -0,0 +1,50 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 -0.065(3) 1.0000 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================

View File

@@ -0,0 +1,50 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 0.488(1) 1.0000 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================

View File

@@ -0,0 +1,51 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 0.488(1) 0.7500 1.10(3)
Y3 Y 2d 0.3333 0.6666 -0.065(3) 0.2500 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================

View File

@@ -0,0 +1,60 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
# --- Y位点 (基于Model 3, Occ: M1=1, M2=0.75, M3=0.25) ---
Y1_main Y 0.0000 0.0000 0.0000 0.9444 # 1a位
Li_on_Y1 Li 0.0000 0.0000 0.0000 0.0556
Y2_main Y 0.3333 0.6666 0.4880 0.7083 # 2d位
Li_on_Y2 Li 0.3333 0.6666 0.4880 0.0417
Y3_main Y 0.3333 0.6666 -0.0650 0.2361 # 2d位
Li_on_Y3 Li 0.3333 0.6666 -0.0650 0.0139
# --- Li位点 (基于标准结构, Occ: 6g=1, 6h=0.5) ---
Li1_main Li 0.3397 0.3397 0.0000 0.9815 # 6g位
Y_on_Li1 Y 0.3397 0.3397 0.0000 0.0185
Li2_main Li 0.3397 0.3397 0.5000 0.4907 # 6h位
Y_on_Li2 Y 0.3397 0.3397 0.5000 0.0093
# --- Cl位点 (保持不变) ---
Cl1 Cl 0.1131 -0.1131 0.7717 1.0000
Cl2 Cl 0.2182 -0.2182 0.2606 1.0000
Cl3 Cl 0.4436 -0.4436 0.7604 1.0000
#==============================================================================
# END OF DATA
#==============================================================================

View File

@@ -0,0 +1,51 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 0.488(1) 0.8269 1.10(3)
Y3 Y 2d 0.3333 0.6666 -0.065(3) 0.1730 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================

View File

@@ -0,0 +1,53 @@
#------------------------------------------------------------------------------
# CIF (Crystallographic Information File) for Li3YCl6
# Data source: Table S1 from the provided image.
# Rietveld refinement result of the neutron diffraction pattern for the 450 °C-annealed sample.
#------------------------------------------------------------------------------
data_Li3YCl6
_chemical_name_systematic 'Lithium Yttrium Chloride'
_chemical_formula_sum 'Li3 Y1 Cl6'
_chemical_formula_structural 'Li3YCl6'
_symmetry_space_group_name_H-M 'P n m a'
_symmetry_Int_Tables_number 62
_symmetry_cell_setting orthorhombic
loop_
_symmetry_equiv_pos_as_xyz
'x, y, z'
'-x+1/2, y+1/2, -z+1/2'
'-x, -y, -z'
'x+1/2, -y+1/2, z+1/2'
'-x, y+1/2, -z'
'x-1/2, -y-1/2, z-1/2'
'x, -y, z'
'-x-1/2, y-1/2, -z-1/2'
_cell_length_a 12.92765(13)
_cell_length_b 11.19444(10)
_cell_length_c 6.04000(12)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 90.0
_cell_volume 874.15
_cell_formula_units_Z 4
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_Wyckoff_symbol
_atom_site_U_iso_or_equiv
Li1 Li 0.11730(7) 0.09640(7) 0.04860(10) 0.750(13) 8d 4.579(2)
Li2 Li 0.13270(9) 0.07900(10) 0.48600(2) 0.750(19) 8d 9.554(4)
Cl1 Cl 0.21726(7) 0.58920(7) 0.26362(11) 1.0 8d 0.797(17)
Cl2 Cl 0.45948(8) 0.08259(8) 0.23831(13) 1.0 8d 1.548(2)
Cl3 Cl 0.04505(10) 0.25000 0.74110(2) 1.0 4c 1.848(3)
Cl4 Cl 0.20205(9) 0.25000 0.24970(2) 1.0 4c 0.561(2)
Y1 Y 0.37529(10) 0.25000 0.01870(3) 1.0 4c 1.121(17)
#------------------------------------------------------------------------------

151
dpgen/plus.py Normal file
View File

@@ -0,0 +1,151 @@
import random
from typing import List
from pymatgen.core import Structure
from pymatgen.io.vasp import Poscar
def _is_close_frac(z, target, tol=2e-2):
t = target % 1.0
return min(abs(z - t), abs(z - (t + 1)), abs(z - (t - 1))) < tol
def make_model3_poscar_from_cif(cif_path: str,
                                out_poscar: str = "POSCAR_model3_supercell",
                                seed: int = 42,
                                tol: float = 2e-2):
    """
    Expand model3.cif into a [[3,0,0],[2,4,0],[0,0,6]] supercell (2160 atoms) and
    explicitly order the partially occupied sites (Y2=0.75, Y3=0.25, Li2=0.5)
    before writing out a POSCAR.

    The random selection of which partial sites to keep is reproducible via *seed*.
    Sites are classified by their fractional z coordinate; *tol* is the matching
    tolerance used by _is_close_frac.
    """
    random.seed(seed)
    # 1) Read the CIF.
    s = Structure.from_file(cif_path)
    # 2) Supercell: a_s=3a0, b_s=2a0+4b0, c_s=6c0.
    T = [[3, 0, 0],
         [2, 4, 0],
         [0, 0, 6]]
    s.make_supercell(T)
    # 3) Identify the three families of sites to be rounded: Y2, Y3, Li2.
    y2_idx: List[int] = []
    y3_idx: List[int] = []
    li2_idx: List[int] = []
    for i, site in enumerate(s.sites):
        # Stay compatible with different pymatgen versions.
        try:
            el = site.species.elements[0].symbol
        except Exception:
            ss = site.species_string
            el = "Li" if ss.startswith("Li") else ("Y" if ss.startswith("Y") else ("Cl" if ss.startswith("Cl") else ss))
        z = site.frac_coords[2]
        if el == "Y":
            if _is_close_frac(z, 0.488, tol):
                y2_idx.append(i)
            elif _is_close_frac(z, -0.065, tol) or _is_close_frac(z, 0.935, tol):
                y3_idx.append(i)
        elif el == "Li":
            if _is_close_frac(z, 0.5, tol):
                li2_idx.append(i)
    def choose_keep(idxs, frac_keep):
        # Randomly keep round(n * frac_keep) of the candidate indices; returns (keep, drop).
        n = len(idxs)
        k = int(round(n * frac_keep))
        if k < 0: k = 0
        if k > n: k = n
        keep = set(random.sample(idxs, k)) if 0 < k < n else set(idxs if k == n else [])
        drop = [i for i in idxs if i not in keep]
        return keep, drop
    keep_y2, drop_y2 = choose_keep(y2_idx, 0.75)
    keep_y3, drop_y3 = choose_keep(y3_idx, 0.25)
    keep_li2, drop_li2 = choose_keep(li2_idx, 0.50)
    # 4) Kept sites get occupancy 1; the rest are removed (in descending index order
    #    so earlier removals do not shift later indices).
    for i in keep_y2 | keep_y3:
        s.replace(i, "Y")
    for i in keep_li2:
        s.replace(i, "Li")
    to_remove = sorted(drop_y2 + drop_y3 + drop_li2, reverse=True)
    for i in to_remove:
        s.remove_sites([i])
    # 5) Final cleanup: eliminate any leftover partial occupancies (prevents POSCAR write errors).
    #    If site.is_ordered is False, replace the site with its majority element at occupancy 1.
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()  # e.g. {'Li': 0.5} or {'Li': 0.5, 'Y': 0.5}
            elem = max(d.items(), key=lambda kv: kv[1])[0]
            s.replace(i, elem)
    # 6) Sort by element and write the POSCAR.
    order = {"Li": 0, "Y": 1, "Cl": 2}
    s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
    Poscar(s).write_file(out_poscar)
    # Report what was generated.
    comp = {k: int(v) for k, v in s.composition.as_dict().items()}
    print(f"写出 {out_poscar};总原子数 = {len(s)}")
    print(f"Y2识别={len(y2_idx)}Y3识别={len(y3_idx)}Li2识别={len(li2_idx)};组成={comp}")
import random
from typing import List
from pymatgen.core import Structure
from pymatgen.io.vasp import Poscar
def make_pnma_poscar_from_cif(cif_path: str,
                              out_poscar: str = "POSCAR_pnma_supercell",
                              seed: int = 42,
                              supercell=(3,3,6),
                              tol: float = 1e-6):
    """
    Read a Pnma CIF (e.g. origin.cif), expand it to a 2160-atom supercell, and
    explicitly round the partially occupied Li sites (occupancy 0.75) before
    writing out a POSCAR.

    The default supercell is (3,3,6): volume factor 54, and 40 atoms/cell x 54 = 2160 [1][3].
    Random rounding of the Li sublattice is reproducible via *seed*.
    """
    random.seed(seed)
    s = Structure.from_file(cif_path)
    # Supercell: the Pnma cell is already orthorhombic, so a diagonal scaling suffices.
    s.make_supercell(supercell)
    # Collect every partially occupied Li site.
    partial_li_idx: List[int] = []
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()  # e.g. {'Li': 0.75}
            # Only handle sites whose majority element is Li with occupancy < 1.
            m_elem, m_occ = max(d.items(), key=lambda kv: kv[1])
            if m_elem == "Li" and m_occ < 1 - tol:
                partial_li_idx.append(i)
    # Round at occupancy 0.75: randomly keep 75%, delete the rest as vacancies.
    n = len(partial_li_idx)
    k = int(round(n * 0.75))
    keep = set(random.sample(partial_li_idx, k)) if 0 < k < n else set(partial_li_idx if k == n else [])
    drop = sorted([i for i in partial_li_idx if i not in keep], reverse=True)
    # Kept sites get occupancy 1; dropped sites are removed (descending order keeps indices valid).
    for i in keep:
        s.replace(i, "Li")
    for i in drop:
        s.remove_sites([i])
    # Safety net: if any partial occupancy remains, force the majority element.
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()
            elem = max(d.items(), key=lambda kv: kv[1])[0]
            s.replace(i, elem)
    # Sort by element and write the POSCAR.
    order = {"Li": 0, "Y": 1, "Cl": 2}
    s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
    Poscar(s).write_file(out_poscar)
    comp = {k: int(v) for k, v in s.composition.as_dict().items()}
    print(f"写出 {out_poscar};总原子数 = {len(s)};组成 = {comp}")
if __name__=="__main__":
# make_model3_poscar_from_cif("data/P3ma/model3.cif","data/P3ma/supercell_model4.poscar")
make_pnma_poscar_from_cif("data/Pnma/origin.cif","data/Pnma/supercell_pnma.poscar",seed=42)

View File

76
dpgen/transport.py Normal file
View File

@@ -0,0 +1,76 @@
from typing import List, Dict, Tuple
def add_li_y_antisite_average(
    x: float,
    # Y sites: [(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
    y_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
    # Li sites: [(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
    li_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
    # Optional per-element B factors; may be left unset
    b_iso: Dict[str, float] = None,
    # decimal places kept for occupancies
    ndigits: int = 4
) -> str:
    """
    Rewrite the structure as an "average occupancy" CIF block containing a
    fraction x (relative to the total Y content) of Li/Y antisite defects.

    Returns a string that can be pasted directly into the atom_site loop of a CIF.
    """
    # Defaults: Table S9 (P-3m1) coordinates and occupancies; multiplicities follow
    # that space group.
    if y_sites is None:
        y_sites = [
            ("Y1", "1a", (0.0000, 0.0000, 0.0000), 1.0000, 1),
            ("Y2", "2d", (0.3333, 0.6666, 0.4880), 0.8269, 2),
            ("Y3", "2d", (0.3333, 0.6666,-0.0650), 0.1730, 2),
        ]
    if li_sites is None:
        li_sites = [
            ("Li1", "6g", (0.3397, 0.3397, 0.0000), 1.0000, 6),
            ("Li2", "6h", (0.3397, 0.3397, 0.5000), 0.5000, 6),
        ]
    if b_iso is None:
        # Adjust as needed.
        b_iso = {"Y": 1.10, "Li": 5.00, "Cl": 1.99}
    # Atoms-per-cell contributed by each sublattice (expected: Y=3, Li=9).
    total_y = sum(mult * occ for _, _, _, occ, mult in y_sites)
    total_li = sum(mult * occ for _, _, _, occ, mult in li_sites)
    # Number of Y atoms per cell swapped onto the Li sublattice, and the matching
    # fraction of Y introduced on each Li site (preserves charge neutrality/stoichiometry).
    swapped_y = x * total_y
    li_fraction = swapped_y / total_li
    # Assemble the CIF loop text: two rows per coordinate, one for the host
    # element and one for the antisite element.
    out = [
        "loop_",
        " _atom_site_label",
        " _atom_site_type_symbol",
        " _atom_site_fract_x",
        " _atom_site_fract_y",
        " _atom_site_fract_z",
        " _atom_site_occupancy",
        " _atom_site_B_iso_or_equiv",
    ]
    b_y = b_iso.get('Y', 1.0)
    b_li = b_iso.get('Li', 5.0)
    # Y sites: Y keeps (1-x)*occ, Li_on_Y receives x*occ.
    for label, _wyck, (fx, fy, fz), occ, _mult in y_sites:
        host_occ = round(occ * (1.0 - x), ndigits)
        anti_occ = round(occ * x, ndigits)
        out.append(f" {label}_main Y {fx:.4f} {fy:.4f} {fz:.4f} {host_occ:.{ndigits}f} {b_y}")
        out.append(f" Li_on_{label} Li {fx:.4f} {fy:.4f} {fz:.4f} {anti_occ:.{ndigits}f} {b_li}")
    # Li sites: Li keeps (1-li_fraction)*occ, Y_on_Li receives li_fraction*occ.
    for label, _wyck, (fx, fy, fz), occ, _mult in li_sites:
        host_occ = round(occ * (1.0 - li_fraction), ndigits)
        anti_occ = round(occ * li_fraction, ndigits)
        out.append(f" {label}_main Li {fx:.4f} {fy:.4f} {fz:.4f} {host_occ:.{ndigits}f} {b_li}")
        out.append(f" Y_on_{label} Y {fx:.4f} {fy:.4f} {fz:.4f} {anti_occ:.{ndigits}f} {b_y}")
    return "\n".join(out)
# === Example ===
# x=0.0556 (matches the 5.56% used in the paper, defined relative to the total Y content):
print(add_li_y_antisite_average(0.0556))

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 gyj155
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,96 @@
# Paper Semantic Search
Find similar papers using semantic search. Supports both local models (free) and OpenAI API (better quality).
## Features
- Fetch papers from OpenReview (e.g., ICLR 2026 submissions)
- Semantic search with example papers or text queries
- Support embedding caching
- Embed model support: Open-source (e.g., all-MiniLM-L6-v2) or OpenAI
## Quick Start
```bash
pip install -r requirements.txt
```
### 1. Prepare Papers
```python
from crawl import crawl_papers
crawl_papers(
venue_id="ICLR.cc/2026/Conference/Submission",
output_file="iclr2026_papers.json"
)
```
### 2. Search Papers
```python
from search import PaperSearcher
# Local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
# OpenAI model (better, requires API key)
# export OPENAI_API_KEY='your-key'
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
searcher.compute_embeddings()
# Search with example papers that you are interested in
examples = [
{
"title": "Your paper title",
"abstract": "Your paper abstract..."
}
]
results = searcher.search(examples=examples, top_k=100)
# Or search with text query
results = searcher.search(query="interesting topics", top_k=100)
searcher.display(results, n=10)
searcher.save(results, 'results.json')
```
## How It Works
1. Paper titles and abstracts are converted to embeddings
2. Embeddings are cached automatically
3. Your query is embedded using the same model
4. Cosine similarity finds the most similar papers
5. Results are ranked by similarity score
## Cache
Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete to recompute.
## Example Output
```
================================================================================
Top 100 Results (showing 10)
================================================================================
1. [0.8456] Paper a
#12345 | foundation or frontier models, including LLMs
https://openreview.net/forum?id=xxx
2. [0.8234] Paper b
#12346 | applications to robotics, autonomy, planning
https://openreview.net/forum?id=yyy
```
## Tips
- Use 1-5 example papers for best results, or a paragraph describing your topic of interest
- Local model is good enough for most cases
- OpenAI model for critical search (~$1 for 18k queries)
If it's useful, please consider giving a star~

View File

@@ -0,0 +1,66 @@
import requests
import json
import time
def fetch_submissions(venue_id, offset=0, limit=1000):
    """Fetch one page of OpenReview submissions for *venue_id* and return the parsed JSON."""
    query = {
        "content.venueid": venue_id,
        "details": "replyCount,invitation",
        "limit": limit,
        "offset": offset,
        "sort": "number:desc",
    }
    # A browser-like User-Agent avoids being rejected by the API gateway.
    resp = requests.get(
        "https://api2.openreview.net/notes",
        params=query,
        headers={"User-Agent": "Mozilla/5.0"},
    )
    resp.raise_for_status()
    return resp.json()
def crawl_papers(venue_id, output_file):
    """
    Page through all submissions for *venue_id* via fetch_submissions and save
    them to *output_file* as a JSON list of flattened paper dicts.

    Returns the list of papers that was written.
    """
    all_papers = []
    offset = 0
    limit = 1000
    print(f"Fetching papers from {venue_id}...")
    while True:
        data = fetch_submissions(venue_id, offset, limit)
        notes = data.get("notes", [])
        if not notes:
            break
        for note in notes:
            # Flatten the OpenReview note into the fields we actually need.
            paper = {
                "id": note.get("id"),
                "number": note.get("number"),
                "title": note.get("content", {}).get("title", {}).get("value", ""),
                "authors": note.get("content", {}).get("authors", {}).get("value", []),
                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
            }
            all_papers.append(paper)
        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")
        if len(notes) < limit:
            # A short page means this was the last page.
            break
        offset += limit
        time.sleep(0.5)  # throttle to be polite to the API
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_papers, f, ensure_ascii=False, indent=2)
    print(f"\nTotal: {len(all_papers)} papers")
    print(f"Saved to {output_file}")
    return all_papers
if __name__ == "__main__":
crawl_papers(
venue_id="ICLR.cc/2026/Conference/Submission",
output_file="iclr2026_papers.json"
)

View File

@@ -0,0 +1,22 @@
from search import PaperSearcher
# Use local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
# Or use OpenAI (better quality)
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
searcher.compute_embeddings()
examples = [
{
"title": "Solid-State battery",
"abstract": "Solid-State battery"
},
]
results = searcher.search(examples=examples, top_k=100)
searcher.display(results, n=10)
searcher.save(results, 'results.json')

View File

@@ -0,0 +1,6 @@
requests
numpy
scikit-learn
sentence-transformers
openai

View File

@@ -0,0 +1,156 @@
import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
class PaperSearcher:
    """Semantic search over a JSON list of papers.

    Embeds papers with either the OpenAI API or a local sentence-transformers
    model, caches the embeddings on disk, and ranks papers against a query (or
    example papers) by cosine similarity.
    """

    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        """Load papers from *papers_file* and set up the chosen embedding backend."""
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)
        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None
        if model_type == "openai":
            # Imported lazily so the local backend does not require the openai package.
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"
        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        """Build a cache path unique to the papers file and embedding backend."""
        base_name = Path(papers_file).stem
        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_file).parent / cache_name)

    def _load_cache(self):
        """Load cached embeddings; discard them if they don't match the paper count."""
        if os.path.exists(self.cache_file):
            try:
                self.embeddings = np.load(self.cache_file)
                if len(self.embeddings) == len(self.papers):
                    print(f"Loaded cache: {self.embeddings.shape}")
                    return True
                # Paper list changed since the cache was written: recompute later.
                self.embeddings = None
            except Exception:
                # Fix: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. A corrupt cache now simply
                # falls through to recomputation.
                self.embeddings = None
        return False

    def _save_cache(self):
        """Persist the current embeddings next to the papers file."""
        np.save(self.cache_file, self.embeddings)
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        """Concatenate title/abstract/keywords into a single embedding input string."""
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        """Embed *texts* with the OpenAI API, batching 100 inputs per request."""
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        batch_size = 100
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend([item.embedding for item in response.data])
        return np.array(embeddings)

    def _embed_local(self, texts):
        """Embed *texts* with the local sentence-transformers model."""
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def compute_embeddings(self, force=False):
        """Compute (or reuse cached) embeddings for all papers and return the matrix."""
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings
        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]
        if self.model_type == "openai":
            self.embeddings = self._embed_openai(texts)
        else:
            self.embeddings = self._embed_local(texts)
        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        """Return the *top_k* papers most similar to example papers or a text query.

        Raises ValueError if neither *examples* nor *query* is given.
        """
        if self.embeddings is None:
            self.compute_embeddings()
        if examples:
            # Average the example embeddings into a single query vector.
            texts = []
            for ex in examples:
                text = f"Title: {ex['title']}"
                if ex.get('abstract'):
                    text += f" Abstract: {ex['abstract']}"
                texts.append(text)
            if self.model_type == "openai":
                embs = self._embed_openai(texts)
            else:
                embs = self._embed_local(texts)
            query_emb = np.mean(embs, axis=0).reshape(1, -1)
        elif query:
            if self.model_type == "openai":
                query_emb = self._embed_openai(query).reshape(1, -1)
            else:
                query_emb = self._embed_local(query).reshape(1, -1)
        else:
            raise ValueError("Provide either examples or query")
        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        """Pretty-print the first *n* results to stdout."""
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")
        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']
            print(f"{i}. [{sim:.4f}] {paper['title']}")
            print(f" #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f" {paper['forum_url']}\n")

    def save(self, results, output_file):
        """Write results (with model name and count) to *output_file* as JSON."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")

38
mcp/data/id_rsa.txt Normal file
View File

@@ -0,0 +1,38 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAYEA57lv3qJ4z66QO6uxFBnd5QFTsj9P70tO7aSEbgjczT0rgg9+48Ik
S/n2m8z4s9C4bk1mTyotJc7p13nveLo0/PAO2Y/6KiSDK4HPMEr8BeWe2RdSBVgfHNls08
2eQo/DhW5pbbybKPDI8YOyhijEOF2fDD5I5bA7QUb2Ue8cOo45aPFkFPl6E2j1u9Xaua4+
oE0syDzUvMhWZJdZqeQ//Qm1+RzB2+n4y41Ym/5YsQrL6zm4RBUrgSlx4DP6sAx1dPq2OX
5XEh6888/QVA55liNukOtOumLjMXhLe5Ut8rur3FyYHmI2jkVpBgAQOErcxvH5UeRCgIh2
vUdeAzOk0STzhKon7nIrTPek/SEaM2Kdn9y4+X4tgANTZWTf5M9ELlqthRiff2fHIf++11
v/ChqblDaPzSZ+y6myemiRLVouQbrj0Kokvqrv/lL5XzpQrAHQ1PWUB1DUXB5B8W2xsTnB
2EZQ7iH4A6VSyzrJb93xTWTjIzytn17PDH0l1JS3AAAFiOMdXKvjHVyrAAAAB3NzaC1yc2
EAAAGBAOe5b96ieM+ukDursRQZ3eUBU7I/T+9LTu2khG4I3M09K4IPfuPCJEv59pvM+LPQ
uG5NZk8qLSXO6dd573i6NPzwDtmP+iokgyuBzzBK/AXlntkXUgVYHxzZbNPNnkKPw4VuaW
28myjwyPGDsoYoxDhdnww+SOWwO0FG9lHvHDqOOWjxZBT5ehNo9bvV2rmuPqBNLMg81LzI
VmSXWankP/0Jtfkcwdvp+MuNWJv+WLEKy+s5uEQVK4EpceAz+rAMdXT6tjl+VxIevPPP0F
QOeZYjbpDrTrpi4zF4S3uVLfK7q9xcmB5iNo5FaQYAEDhK3Mbx+VHkQoCIdr1HXgMzpNEk
84SqJ+5yK0z3pP0hGjNinZ/cuPl+LYADU2Vk3+TPRC5arYUYn39nxyH/vtdb/woam5Q2j8
0mfsupsnpokS1aLkG649CqJL6q7/5S+V86UKwB0NT1lAdQ1FweQfFtsbE5wdhGUO4h+AOl
Uss6yW/d8U1k4yM8rZ9ezwx9JdSUtwAAAAMBAAEAAAGAMI2mZxvb/IgzKI2dGP0ihW11wA
+MDDPXYevq47NvsIF0sFfW2po/SLwjdBnKssK1IkeNfGD1/MoSLVgbWUyK9cTHF8cXP+VO
prsYUqIjlIi8c/hy8zO3sS/NocOfuYquCTNNW/T8/eMV96UErx+znavgO4yBcb8va0oXKq
vTWmGaneaWdd6gOZjwhF8W6XkdHjGNhJdabAP+Ni2QWAy/a6GxQ3VHGXE49E21l1n/83iz
qaH6fimBaBrrBXNev6ycObPIyyXpEbwKi6GbmMbPOiR/DrhTgptpc+TJwBLd4JnX1cqCgX
sDiSOog9bgV2okznrxAINMFnrBXD12CXZfdJCsZeDWCxnVngWGImzXk6TGbfvBbyRTIkF9
qmW1BdydGrMKQoHiphndWPlJfdRl7r2ASoUkjDSK/hXV/6iYBI5ZRmZSqihFMOQUpYxLu5
nz+WecLXZYVfAlIXlESQ3PQJ33/CnDCVqpzjtsQxRYWhA4kVaCMjPnt83LAMDheWlhAAAA
wGw2bgn9Ivw8QWSPckU7+TcmemjAVQjbcBmz9aCJlBxHtZiBXa9oQOjDh8Uw84jbiX/6sb
uzn2xArZOxWPCd2ZWKyZNodyvI6sQqb4D+xHt4aReWoU5wPDaIZpkuyWzDPSZARmy2k95z
Dq995Gl8rW2xkw/f9cTHNf6wvYdvclzKrg1mCdoBUwX1diNI2l7wsww6bDfNYMZcgX82O5
aRaIJUJltQ7CXbIow9G2BqquoEjSg6/9ZZ/B0ZWyW+5uuM+wAAAMEA/Z/HZmIuFbmNKC5m
tjXCaz9x7oTXl0v+4XMx6smQqklx1XqdXe1YSdbWxJZAhfbLmiOmQIncee/+H7m42zLsFs
kgbDtze7+qLi2+MYStd75FypvQ3h+mmYq3ppkBrAiDcJ9UrG7pWUfq+FY6CyOE5ub0mmhm
w/DW/I9so8wEi1VBzi0SqpUO6snx77yZoWJhJvlhbEGBvAS/wFIX9MBBefbf5vGMhUT+pW
xUIRvizKh/gtySXrj6WPBVoak01AatAAAAwQDp5SPKHHRO/53eC+nVSDK2fc2YWEFkSLQn
MCu+pQZv5izoyYPP8FZ4y5qw+16H2f3GbPH8xCDKlokKJlKggDhDV45eWz5UbItDt43okD
uB6v9EP4AtKKUNm+GxwhwyoY/C395fe8EvgsAlXNCAy3Wt6cVmbXW+ZSv5JRV9J0GX+5F2
K+LjNm4r/1BaLyUOf0eGTvMBc3XEBIKk7MsEBVnfxHmBJQ6fpAScimEM/VrZCbJ9OGKAiq
yRuCwKVgZviXMAAAANa29rbzEyNUBoZWFkMQECAwQFBg==
-----END OPENSSH PRIVATE KEY-----

50
mcp/lifespan.py Normal file
View File

@@ -0,0 +1,50 @@
import asyncssh
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from starlette.applications import Starlette
# --- 1. SSH 连接参数 ---
REMOTE_HOST = '202.121.182.208'
REMOTE_USER = 'koko125'
# 确保路径正确Windows 路径建议使用正斜杠 '/'
PRIVATE_KEY_PATH = 'D:/tool/tool/id_rsa.txt'
INITIAL_WORKING_DIRECTORY = f'/cluster/home/{REMOTE_USER}/sandbox'
# --- 2. 定义共享上下文 ---
# 这个数据类将持有所有模块需要共享的资源
@dataclass
class SharedAppContext:
    """Shared resources held for the entire server lifetime."""
    # Long-lived SSH connection reused by every MCP tool.
    ssh_connection: asyncssh.SSHClientConnection
    # Remote working directory all commands are sandboxed to.
    sandbox_path: str
# --- 3. 创建生命周期管理器 ---
# 这是整个架构的核心负责在服务器启动时连接SSH在关闭时断开
@asynccontextmanager
async def shared_lifespan(app: Starlette) -> AsyncIterator[dict]:
    """
    Manage the lifetime of resources shared across the whole application.

    Opens one SSH connection at startup, yields it (together with the sandbox
    path) to Starlette as lifespan state, and closes it on shutdown.

    Fix: the annotation is now AsyncIterator[dict] — the generator yields a
    state dict {"shared_context": SharedAppContext}, not a SharedAppContext,
    which the old AsyncIterator[SharedAppContext] annotation misstated.
    """
    print("主应用启动,正在建立共享 SSH 连接...")
    conn = None
    try:
        # Establish the SSH connection.
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH]
        )
        print(f"SSH 连接到 {REMOTE_HOST} 成功!")
        # Hand the shared context to the Starlette app; the server serves
        # requests while suspended at this yield.
        yield {"shared_context": SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)}
    finally:
        # Runs when the server shuts down.
        if conn:
            conn.close()
            await conn.wait_closed()
            print("共享 SSH 连接已关闭。")

55
mcp/main.py Normal file
View File

@@ -0,0 +1,55 @@
# main.py
# 将 Materials CIF MCP 与 System Tools MCP 一起挂载到 Starlette。
# 关键点:
# - 在 lifespan 中启动每个 MCP 的 session_manager.run()(参考 SDK README 的 Starlette 挂载示例与 streamable_http_app 用法 [1]
# - 通过 Mount 指定各自的子路径(如 /system 与 /materials
import contextlib
from starlette.applications import Starlette
from starlette.routing import Mount
from system_tools import create_system_mcp
from materialproject_mcp import create_materials_mcp
from softBV_remake import create_softbv_mcp
from paper_search_mcp import create_paper_search_mcp
from topological_analysis_models import create_topological_analysis_mcp
# Create the MCP instances.
system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()
softbv_mcp = create_softbv_mcp()
paper_search_mcp = create_paper_search_mcp()
topological_analysis_mcp = create_topological_analysis_mcp()
# Start each MCP's session manager inside Starlette's lifespan.
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    # Every session manager is entered on startup; the AsyncExitStack closes
    # them all again when the server shuts down.
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(system_mcp.session_manager.run())
        await stack.enter_async_context(materials_mcp.session_manager.run())
        await stack.enter_async_context(softbv_mcp.session_manager.run())
        await stack.enter_async_context(paper_search_mcp.session_manager.run())
        await stack.enter_async_context(topological_analysis_mcp.session_manager.run())
        yield  # server is running
        # cleanup happens automatically on exit
# Mount each MCP's Streamable HTTP app under its own sub-path.
app = Starlette(
    lifespan=lifespan,
    routes=[
        Mount("/system", app=system_mcp.streamable_http_app()),
        Mount("/materials", app=materials_mcp.streamable_http_app()),
        Mount("/softBV", app=softbv_mcp.streamable_http_app()),
        Mount("/papersearch",app=paper_search_mcp.streamable_http_app()),
        Mount("/topologicalAnalysis",app=topological_analysis_mcp.streamable_http_app()),
    ],
)
# Launch (run in a terminal):
# uvicorn main:app --host 0.0.0.0 --port 8000
# Endpoints:
# http://localhost:8000/system
# http://localhost:8000/materials
# http://localhost:8000/softBV
# http://localhost:8000/papersearch
# http://localhost:8000/topologicalAnalysis
# For browser clients (CORS, exposing Mcp-Session-Id), see the CORS example in the SDK README [1].

513
mcp/materialproject_mcp.py Normal file
View File

@@ -0,0 +1,513 @@
# materials_mp_cif_mcp.py
#
# 说明:
# - 这是一个基于 FastMCP 的“Materials Project CIF”服务器提供
# 1) search_cifs使用 Materials Project 的 mp-api 进行检索两步summary 过滤 + materials 获取结构)
# 2) get_cif按材料 ID 获取结构并导出 CIF 文本pymatgen
# 3) filter_cifs在本地对候选进行硬筛选
# 4) download_cif_bulk批量下载 CIF带并发与进度
# 5) 资源cif://mp/{id} 直接读取 CIF 文本
#
# - 依赖pip install "mcp[cli]" mp_api pymatgen
# - API Key
# 默认从环境变量 MP_API_KEY 读取;若未设置则回退为你提供的 Key仅用于演示不建议在生产中硬编码[3]
# - MCP/Context/Lifespan 的使用方法参考 SDK 文档(工具、结构化输出、进度等)[1]
#
# 引用:
# - MPRester/MP_API_KEY 的用法与推荐的会话管理方式见 Materials Project 文档 [3]
# - summary.search 与 materials.search 的查询与 fields 限制示例见文档 [4]
# - FastMCP 的工具、资源、Context、结构化输出见 README 示例 [1]
import os
import asyncio
from dataclasses import dataclass
from typing import Any, Literal
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import re
from pydantic import BaseModel, Field
from mp_api.client import MPRester # [3]
from pymatgen.core import Structure # 用于转 CIF
from mcp.server.fastmcp import FastMCP, Context # [1]
from mcp.server.session import ServerSession # [1]
# ========= 配置 =========
# 生产中请通过环境变量 MP_API_KEY 配置;此处为演示而提供回退到你给的 API Key不建议硬编码[3]
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW")
# ========= 数据模型 =========
class LatticeParameters(BaseModel):
    """Unit-cell lattice parameters: axis lengths in Å, angles in degrees."""
    a: float | None = Field(default=None, description="a 轴长度 (Å)")
    b: float | None = Field(default=None, description="b 轴长度 (Å)")
    c: float | None = Field(default=None, description="c 轴长度 (Å)")
    alpha: float | None = Field(default=None, description="α 角 (度)")
    beta: float | None = Field(default=None, description="β 角 (度)")
    gamma: float | None = Field(default=None, description="γ 角 (度)")
class CIFCandidate(BaseModel):
    """A Materials Project entry that may be exported as a CIF (search result row)."""
    id: str = Field(description="Materials Project ID如 mp-149")
    source: Literal["mp"] = Field(default="mp", description="来源库")
    chemical_formula: str | None = Field(default=None, description="化学式pretty")
    space_group: str | None = Field(default=None, description="空间群符号")
    cell_parameters: LatticeParameters | None = Field(default=None, description="晶格参数")
    bandgap: float | None = Field(default=None, description="带隙 (eV)")
    stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
    magnetic_ordering: str | None = Field(default=None, description="磁性信息")
    has_cif: bool = Field(default=True, description="是否存在可导出的 CIF只要有结构即可")
    extras: dict[str, Any] | None = Field(default=None, description="原始文档片段(健壮解析)")
class SearchQuery(BaseModel):
    """Search constraints: MP-side filters plus local second-pass filters, paging and sorting."""
    # Basic chemical constraints
    include_elements: list[str] | None = Field(default=None, description="包含元素集合(如 ['Si','O']") # [4]
    exclude_elements: list[str] | None = Field(default=None, description="排除元素集合(如 ['Li']")
    formula: str | None = Field(default=None, description="化学式片段(本地二次过滤)")
    # Structure/property constraints (summary can filter band gap and E_above_hull) [4]
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    # Space group and lattice ranges (filtered locally in a second pass)
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    magnetic_ordering: str | None = None  # filtered locally in a second pass
    page: int = Field(default=1, ge=1, description="页码仅本地分页MP 端点请自行扩展)")
    per_page: int = Field(default=50, ge=1, le=200, description="每页数量(本地分页)")
    sort_by: str | None = Field(default=None, description="排序字段bandgap/e_above_hull等本地排序")
    sort_order: Literal["asc", "desc"] | None = Field(default=None, description="排序方向")
class CIFText(BaseModel):
    """A CIF document exported for a single material, plus optional provenance metadata."""
    id: str
    source: Literal["mp"]
    text: str = Field(description="CIF 文件文本内容")
    metadata: dict[str, Any] | None = Field(default=None, description="附加元数据(如结构来源)")
class FilterCriteria(BaseModel):
    """Hard local filters applied to already-fetched candidates (see filter_cifs)."""
    include_elements: list[str] | None = None
    exclude_elements: list[str] | None = None
    formula: str | None = None
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    magnetic_ordering: str | None = None
class BulkDownloadItem(BaseModel):
    """Outcome of a single CIF download within a bulk request (text on success, error otherwise)."""
    id: str
    source: Literal["mp"] = "mp"
    success: bool
    text: str | None = None
    error: str | None = None
class BulkDownloadResult(BaseModel):
    """Aggregate result of a bulk CIF download: per-item outcomes plus summary counts."""
    items: list[BulkDownloadItem]
    total: int
    succeeded: int
    failed: int
# ========= Lifespan =========
@dataclass
class AppContext:
    """Lifespan context: the shared MPRester client and the API key it was created with."""
    mpr: MPRester
    api_key: str
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create one shared MPRester for the server's lifetime and expose it to all tools."""
    # Prefer the MP_API_KEY environment variable; falls back to the provided key string [3].
    api_key = DEFAULT_MP_API_KEY
    mpr = MPRester(api_key)  # session management is handled inside mp-api [3]
    try:
        yield AppContext(mpr=mpr, api_key=api_key)
    finally:
        # mp-api does not require an explicit close; add cleanup here if a future version provides one [3].
        pass
# ========= Server =========
mcp = FastMCP(
    name="Materials Project CIF",
    instructions="通过 Materials Project API 检索并导出 CIF 文件。",
    lifespan=lifespan,
    streamable_http_path="/",  # eases mounting under a sub-path in Starlette [1]
)
# ========= 工具实现 =========
def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
if text is None:
return not required and not excluded
s = text.upper()
if required:
for el in required:
if el.upper() not in s:
return False
if excluded:
for el in excluded:
if el.upper() in s:
return False
return True
def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
    """Convert a pymatgen Structure's lattice into a LatticeParameters model (None-safe)."""
    if struct is None:
        return None
    lattice = struct.lattice
    # Coerce every value to a plain float so the pydantic model serializes cleanly.
    params = {
        "a": lattice.a,
        "b": lattice.b,
        "c": lattice.c,
        "alpha": lattice.alpha,
        "beta": lattice.beta,
        "gamma": lattice.gamma,
    }
    return LatticeParameters(**{name: float(value) for name, value in params.items()})
def _sg_symbol_from_doc(doc: Any) -> str | None:
"""
尝试从 materials.search 返回的文档中提取空间群符号。
MP 文档中 symmetry 通常包含 spacegroup 信息(不同版本字段路径可能变化)[4]
"""
try:
sym = getattr(doc, "symmetry", None)
if sym and getattr(sym, "spacegroup", None):
sg = sym.spacegroup
# 常见字段symbol 或 international
return getattr(sg, "symbol", None) or getattr(sg, "international", None)
except Exception:
pass
return None
async def _to_thread(fn, *args, **kwargs):
    # Run a blocking call (mp-api is synchronous) in a worker thread so the event loop stays responsive.
    return await asyncio.to_thread(fn, *args, **kwargs)
def _parse_elements_from_formula(formula: str) -> list[str]:
"""从化学式中解析元素符号(简易版)。"""
if not formula:
return []
elems = re.findall(r"[A-Z][a-z]?", formula)
seen, ordered = set(), []
for e in elems:
if e not in seen:
seen.add(e)
ordered.append(e)
return ordered
@mcp.tool()
async def search_cifs(
    query: SearchQuery,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """
    First pass via summary (with a bounded result count), then fetch structures
    and symmetry for the current page via materials.

    Fixes baked into this implementation:
    - Only the current page is requested (chunk_size/per_page + num_chunks=1),
      avoiding a full-database download and timeouts.
    - mp-api document objects and Structures are never placed in the return
      value; extras carries only lightweight serializable fields.
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr
    # Derive the element set from the formula when the user did not provide one explicitly.
    include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")
    # Require at least one constraint to avoid a full-database scan.
    if not include_elements and not query.exclude_elements and query.bandgap_min is None \
            and query.bandgap_max is None and query.e_above_hull_max is None:
        raise ValueError("请至少提供一个过滤条件,例如 include_elements、bandgap 或 energy_above_hull 上限,以避免超时。")
    # 1) summary first pass: fetch only the current page's documents (lightweight fields).
    summary_kwargs: dict[str, Any] = {}
    if include_elements:
        summary_kwargs["elements"] = include_elements  # MP supports searching by element set
    if query.exclude_elements:
        summary_kwargs["exclude_elements"] = query.exclude_elements
    if query.bandgap_min is not None or query.bandgap_max is not None:
        summary_kwargs["band_gap"] = (
            float(query.bandgap_min) if query.bandgap_min is not None else None,
            float(query.bandgap_max) if query.bandgap_max is not None else None,
        )
    if query.e_above_hull_max is not None:
        summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))
    fields_summary = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull",
        "ordering",
    ]
    # Key point: bound the size of the summary response.
    try:
        summary_docs = await asyncio.to_thread(
            mpr.materials.summary.search,
            fields=fields_summary,
            chunk_size=query.per_page,  # size of each chunk (page)
            num_chunks=1,  # take only one chunk => only the current page's worth
            **summary_kwargs,
        )
    except TypeError:
        # If this mp-api version lacks these parameters, request an upgrade rather
        # than pulling back hundreds of thousands of rows and timing out.
        raise RuntimeError(
            "当前 mp-api 版本不支持 chunk_size/num_chunks请升级 mp-apipip install -U mp_api"
        )
    if not summary_docs:
        await ctx.debug("summary 检索结果为空")
        return []
    # Build lightweight candidates (local sorting happens below if requested).
    results_all: list[CIFCandidate] = []
    for sdoc in summary_docs:
        e_hull = getattr(sdoc, "energy_above_hull", None)
        bandgap = getattr(sdoc, "band_gap", None)
        ordering = getattr(sdoc, "ordering", None)
        # Do not put structure/doc objects into extras (avoids serialization errors).
        results_all.append(
            CIFCandidate(
                id=sdoc.material_id,
                source="mp",
                chemical_formula=getattr(sdoc, "formula_pretty", None),
                space_group=None,
                cell_parameters=None,
                bandgap=float(bandgap) if bandgap is not None else None,
                stability_energy=float(e_hull) if e_hull is not None else None,
                magnetic_ordering=ordering,
                has_cif=True,
                extras={
                    "summary": {
                        "material_id": sdoc.material_id,
                        "formula_pretty": getattr(sdoc, "formula_pretty", None),
                        "band_gap": float(bandgap) if bandgap is not None else None,
                        "energy_above_hull": float(e_hull) if e_hull is not None else None,
                        "ordering": ordering,
                    }
                },
            )
        )
    # Local pagination: summary already bounded the count; only sort and truncate safely here.
    if query.sort_by:
        reverse = query.sort_order == "desc"
        key_map = {
            "bandgap": lambda c: (c.bandgap if c.bandgap is not None else float("inf")),
            "e_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
            "energy_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
        }
        if query.sort_by in key_map:
            results_all.sort(key=key_map[query.sort_by], reverse=reverse)
    page_slice = results_all[: query.per_page]
    page_ids = [c.id for c in page_slice]
    # 2) Fetch structure and symmetry only for the current page's materials.
    if not page_ids:
        return []
    fields_materials = ["material_id", "structure", "symmetry"]
    materials_docs = await asyncio.to_thread(
        mpr.materials.search,
        material_ids=page_ids,
        fields=fields_materials,
    )
    mat_by_id = {doc.material_id: doc for doc in materials_docs}
    # Merge structure/space-group info; still no non-serializable objects in the result.
    merged: list[CIFCandidate] = []
    for c in page_slice:
        mdoc = mat_by_id.get(c.id)
        struct = getattr(mdoc, "structure", None) if mdoc else None
        # Space-group symbol (field paths vary between mp-api versions).
        sg = None
        try:
            sym = getattr(mdoc, "symmetry", None)
            if sym and getattr(sym, "spacegroup", None):
                sgobj = sym.spacegroup
                sg = getattr(sgobj, "symbol", None) or getattr(sgobj, "international", None)
        except Exception:
            pass
        c.space_group = sg
        c.cell_parameters = _lattice_from_structure(struct)
        c.has_cif = struct is not None
        # Keep extras as lightweight JSON.
        c.extras = {
            **(c.extras or {}),
            "materials": {"has_structure": struct is not None},
        }
        merged.append(c)
    await ctx.debug(f"返回 {len(merged)} 条(已限制每页 {query.per_page} 条)")
    return merged
@mcp.tool()
async def get_cif(
    id: str,
    ctx: Context[ServerSession, AppContext],
) -> CIFText:
    """
    Fetch one material's structure via materials.search and export it as CIF text.

    - Restricts fields to ["material_id", "structure"] to keep the query fast.
    - CIF conversion goes through pymatgen's Structure.to(fmt="cif").

    Raises:
        ValueError: if the material is not found or has no structure data.
    """
    lifespan_ctx = ctx.request_context.lifespan_context
    client = lifespan_ctx.mpr
    await ctx.info(f"获取结构并导出 CIF:{id}")
    found = await _to_thread(client.materials.search, material_ids=[id], fields=["material_id", "structure"])
    if not found:
        raise ValueError(f"未找到材料:{id}")
    structure: Structure | None = getattr(found[0], "structure", None)
    if structure is None:
        raise ValueError(f"材料 {id} 无结构数据,无法导出 CIF")
    # Emit the CIF text directly in the response payload.
    return CIFText(id=id, source="mp", text=structure.to(fmt="cif"), metadata={"from": "materials.search"})
@mcp.tool()
async def filter_cifs(
    candidates: list[CIFCandidate],
    criteria: FilterCriteria,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """
    Filter CIF candidates locally against FilterCriteria.

    Checks, in order: element include/exclude sets, formula substring,
    space-group substring, lattice-parameter ranges, bandgap range,
    energy-above-hull ceiling and magnetic-ordering substring.
    """
    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        # A missing value only passes when no bound was requested.
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        if vmax is not None and v > vmax:
            return False
        return True
    filtered: list[CIFCandidate] = []
    for c in candidates:
        # Approximate element-set match based on the formula string.
        if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
            continue
        # Formula substring check.
        if criteria.formula:
            if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper():
                continue
        # Space-group substring check.
        if criteria.space_group:
            spg = (c.space_group or "").lower()
            if criteria.space_group.lower() not in spg:
                continue
        # Lattice-parameter ranges.
        cp = c.cell_parameters or LatticeParameters()
        if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max):
            continue
        if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max):
            continue
        if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max):
            continue
        # Bandgap and stability.
        if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max):
            continue
        if criteria.e_above_hull_max is not None:
            if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max:
                continue
        # Magnetic ordering.
        if criteria.magnetic_ordering:
            # BUGFIX: compare against the candidate's own ordering (populated
            # by search_materials). The previous hard-coded "" made this
            # criterion reject every candidate whenever it was set.
            mo = c.magnetic_ordering or ""
            if criteria.magnetic_ordering.upper() not in mo.upper():
                continue
        filtered.append(c)
    await ctx.debug(f"筛选后:{len(filtered)} / 原始 {len(candidates)}")
    return filtered
@mcp.tool()
async def download_cif_bulk(
    ids: list[str],
    concurrency: int = 4,
    ctx: Context[ServerSession, AppContext] | None = None,
) -> BulkDownloadResult:
    """
    Download CIF text for several material ids concurrently.

    Concurrency is bounded by a semaphore; per-id failures are recorded in the
    result instead of aborting the batch, and progress is reported after each
    item when a context is available.
    """
    gate = asyncio.Semaphore(max(1, concurrency))
    items: list[BulkDownloadItem] = []

    async def fetch_one(material_id: str, position: int) -> None:
        try:
            async with gate:
                # get_cif needs the shared context; without one this item fails.
                assert ctx is not None
                payload = await get_cif(id=material_id, ctx=ctx)
                items.append(BulkDownloadItem(id=material_id, success=True, text=payload.text))
        except Exception as exc:
            items.append(BulkDownloadItem(id=material_id, success=False, error=str(exc)))
        if ctx is not None:
            done_fraction = (position + 1) / len(ids) if len(ids) > 0 else 1.0
            await ctx.report_progress(progress=done_fraction, total=1.0, message=f"{position + 1}/{len(ids)}")

    await asyncio.gather(*(fetch_one(mid, i) for i, mid in enumerate(ids)))
    ok_count = sum(1 for entry in items if entry.success)
    return BulkDownloadResult(items=items, total=len(items), succeeded=ok_count, failed=len(items) - ok_count)
# ========= 资源cif://mp/{id} =========
@mcp.resource("cif://mp/{id}")
async def cif_resource(id: str) -> str:
    """
    URI-based CIF content reader.

    Resource handlers cannot access the lifespan Context, so a one-off
    MPRester session is opened per call (acceptable for occasional reads).

    Raises:
        ValueError: if the material has no structure data.
    """
    # BUGFIX: the previous try/finally: pass never closed the MPRester HTTP
    # session; the context manager guarantees cleanup even on error.
    with MPRester(DEFAULT_MP_API_KEY) as mpr:
        docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"])
        if not docs or getattr(docs[0], "structure", None) is None:
            raise ValueError(f"材料 {id} 无结构数据")
        struct: Structure = docs[0].structure
        return struct.to(fmt="cif")
def create_materials_mcp() -> FastMCP:
    """Factory function so a Starlette app can mount this MCP server."""
    return mcp
if __name__ == "__main__":
    # Standalone run (not mounted): serve over the Streamable HTTP transport.
    mcp.run(transport="streamable-http")

172
mcp/paper_search_mcp.py Normal file
View File

@@ -0,0 +1,172 @@
# paper_search_mcp.py (重构版)
import os
import asyncio
from dataclasses import dataclass
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
# 假设您的 search.py 提供了 PaperSearcher 类
from SearchPaperByEmbedding.search import PaperSearcher
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# ========= 1. Configuration and constants (fixed version) =========
# Absolute path of the directory containing this script.
_CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Resource-file paths built relative to the script directory.
DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR,"SearchPaperByEmbedding")
PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json")
# Print the path for debugging so a wrong location is caught early.
print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}")
MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local")
# ========= 2. 数据模型 =========
class Paper(BaseModel):
    """Detailed information about a single paper returned by a search."""
    title: str
    authors: list[str]
    abstract: str
    # Link to the paper page (populated from the source's 'forum_url' key).
    pdf_url: str | None = None
    similarity_score: float = Field(description="与查询的相似度分数,越高越相关")
class SearchQuery(BaseModel):
    """Query structure for a paper search (title plus optional abstract)."""
    title: str
    abstract: str | None = Field(None, description="可选的摘要,以提供更丰富的上下文")
# ========= 3. Lifespan (资源加载与管理) =========
@dataclass
class PaperSearchContext:
    """Shared, pre-initialized PaperSearcher instance for the server's lifetime."""
    searcher: PaperSearcher
@asynccontextmanager
async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]:
    """
    FastMCP lifespan: build the PaperSearcher and compute embeddings once at
    server startup, so the expensive work never runs per request.
    """
    print(f"正在使用 '{MODEL_TYPE}' 模型初始化论文搜索引擎...")

    def _bootstrap() -> PaperSearcher:
        # Synchronous, CPU/IO-heavy work: constructor plus embedding computation.
        engine = PaperSearcher(
            PAPERS_JSON_PATH,
            model_type=MODEL_TYPE
        )
        print("正在计算或加载论文 Embeddings...")
        engine.compute_embeddings()
        print("Embeddings 加载完成。")
        return engine

    # Run the blocking initialization off the event loop.
    engine = await asyncio.to_thread(_bootstrap)
    print("论文搜索引擎已准备就绪!")
    try:
        # Hand the shared resource context over to MCP.
        yield PaperSearchContext(searcher=engine)
    finally:
        print("论文搜索引擎服务关闭。")
# ========= 4. MCP server instance =========
mcp = FastMCP(
    name="Paper Search",
    instructions="通过 embedding 相似度检索 ICLR 2026 论文的工具。",
    lifespan=paper_search_lifespan,
    streamable_http_path="/",
)
# ========= 5. 工具实现 =========
@mcp.tool()
async def search_papers(
    queries: list[SearchQuery],
    top_k: int = 3,
    ctx: Context[ServerSession, PaperSearchContext] | None = None,
) -> list[Paper]:
    """
    Semantic-similarity search over papers for one or more queries (title plus
    optional abstract). Returns the top_k most relevant results.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
    if not queries:
        return []
    searcher = ctx.request_context.lifespan_context.searcher

    # Normalize queries: mirror title/abstract so both keys are always present.
    prepared: list[dict] = []
    for query in queries:
        payload = query.model_dump(exclude_none=True)
        if "title" not in payload and "abstract" in payload:
            payload["title"] = payload["abstract"]
        elif "abstract" not in payload and "title" in payload:
            payload["abstract"] = payload["title"]
        prepared.append(payload)
    if not prepared:
        return []

    shown = queries[0].title or queries[0].abstract or "Empty Query"
    await ctx.info(f"正在以查询 '{shown}' 等内容搜索排名前 {top_k} 的论文...")

    # Run the underlying synchronous search off the event loop.
    raw_hits = await asyncio.to_thread(
        searcher.search,
        examples=prepared,
        top_k=top_k
    )

    # search() yields [{"paper": {...}, "similarity": float}, ...]
    papers: list[Paper] = []
    for hit in raw_hits:
        meta = hit.get('paper', {})
        papers.append(
            Paper(
                title=meta.get("title", "N/A"),
                authors=meta.get("authors", []),
                abstract=meta.get("abstract", ""),
                # In the raw data the PDF link lives under 'forum_url'.
                pdf_url=meta.get("forum_url"),
                similarity_score=hit.get('similarity', 0.0),
            )
        )
    await ctx.debug(f"找到 {len(papers)} 篇相关论文。")
    return papers
# ========= 6. Factory function (for Starlette integration) =========
def create_paper_search_mcp() -> FastMCP:
    """Factory function so a Starlette app can mount this MCP server."""
    return mcp
if __name__ == "__main__":
    # Running this file directly starts a standalone MCP server.
    mcp.run(transport="streamable-http")

221
mcp/server.py Normal file
View File

@@ -0,0 +1,221 @@
import os
import posixpath
import shlex
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass

import asyncssh
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# --- 1. Connection parameters ---
# NOTE(review): host/user/key are hard-coded; consider moving them to
# environment variables before distributing this file.
REMOTE_HOST = '202.121.182.208'
REMOTE_USER = 'koko125'
PRIVATE_KEY_PATH = 'D:/tool/tool/id_rsa.txt'  # On Windows, '/' separators are recommended
INITIAL_WORKING_DIRECTORY = '/cluster/home/koko125/sandbox'
# --- 2. Lifespan management (advanced feature) ---
@dataclass
class AppContext:
    """Context object holding resources shared for the whole server lifetime."""
    ssh_connection: asyncssh.SSHClientConnection
    # Sandbox root against which the safe tools resolve relative paths.
    current_path: str
@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
    """Open one SSH connection at server startup and close it on shutdown."""
    print("服务器启动中,正在建立SSH连接...")
    connection = None
    try:
        connection = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH]
        )
        print("SSH 连接成功!")
        # Source .bashrc once at login. Its effect (environment variables)
        # only applies to this single conn.run() call; later commands that
        # need it re-source the file themselves.
        print("正在加载 .bashrc...")
        await connection.run('source /cluster/home/koko125/.bashrc', check=True)
        print(".bashrc 加载完成。")
        print(f"初始工作目录设置为: {INITIAL_WORKING_DIRECTORY}")
        yield AppContext(ssh_connection=connection, current_path=INITIAL_WORKING_DIRECTORY)
    finally:
        if connection:
            connection.close()
            await connection.wait_closed()
            print("SSH 连接已关闭。")
# --- 3. Create the MCP server instance with lifespan management attached ---
mcp = FastMCP("远程服务器工具", lifespan=app_lifespan,port=8000)
# --- 4. 保留的调试工具 (不安全) ---
@mcp.tool()
async def execute_remote_command(command: str, ctx: Context) -> str:
    """
    [Debug only] Run an arbitrary shell command on the remote server and
    return its output. WARNING: this tool is unsafe; never use in production.
    """
    await ctx.info(f"准备在 {REMOTE_HOST} 上执行调试命令: '{command}'")
    try:
        # A fresh connection per call keeps the environment isolated.
        async with asyncssh.connect(REMOTE_HOST, username=REMOTE_USER, client_keys=[PRIVATE_KEY_PATH]) as conn:
            # Source .bashrc before the actual command.
            wrapped = f'source /cluster/home/koko125/.bashrc && {command}'
            completed = await conn.run(wrapped, check=True)
            stripped = completed.stdout.strip()
            await ctx.info("调试命令成功执行。")
            return stripped
    except asyncssh.ProcessError as e:
        message = f"命令执行失败: {e.stderr.strip()}"
        await ctx.error(message)
        return message
    except Exception as e:
        message = f"发生未知错误: {str(e)}"
        await ctx.error(message)
        return message
# --- 5. 新增的规范、安全的工具 ---
@mcp.tool()
async def list_files_in_sandbox(ctx: Context[ServerSession, AppContext], path: str = ".") -> str:
    """
    [Safe] List files and directories under the remote sandbox path.

    Args:
        path: Relative path to inspect; defaults to the sandbox root.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        # Block directory traversal.
        if ".." in path.split('/'):
            return "错误: 不允许使用 '..' 访问上级目录。"
        # BUGFIX: posixpath keeps remote paths '/'-separated even when this
        # server runs on Windows (os.path.join would insert backslashes).
        target_path = posixpath.join(app_ctx.current_path, path)
        await ctx.info(f"正在列出安全路径: {target_path}")
        # BUGFIX: SSHClientConnection has no .escape() method (it raised
        # AttributeError); quote with shlex so paths containing spaces or
        # shell metacharacters are passed safely.
        command_to_run = f'source /cluster/home/koko125/.bashrc && ls -l {shlex.quote(target_path)}'
        result = await conn.run(command_to_run, check=True)
        return result.stdout.strip()
    except Exception as e:
        await ctx.error(f"执行 ls 命令失败: {e}")
        return f"错误: {e}"
@mcp.tool()
async def create_directory_in_sandbox(ctx: Context[ServerSession, AppContext], directory_name: str) -> str:
    """
    [Safe] Create a new directory inside the remote sandbox.

    Args:
        directory_name: Name of the directory to create.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        # Safety: forbid slashes and '..' so no deep paths or traversal.
        if '/' in directory_name or ".." in directory_name:
            return "错误: 目录名无效。不能包含 '/' 或 '..'。"
        # BUGFIX: posixpath keeps remote paths '/'-separated even when this
        # server runs on Windows (os.path.join would insert backslashes).
        target_path = posixpath.join(app_ctx.current_path, directory_name)
        await ctx.info(f"正在创建远程目录: {target_path}")
        # BUGFIX: SSHClientConnection has no .escape(); quote with shlex.
        command_to_run = f'source /cluster/home/koko125/.bashrc && mkdir {shlex.quote(target_path)}'
        await conn.run(command_to_run, check=True)
        return f"目录 '{directory_name}' 在沙箱中成功创建。"
    except asyncssh.ProcessError as e:
        error_message = f"创建目录失败: {e.stderr.strip()}"
        await ctx.error(error_message)
        return error_message
    except Exception as e:
        await ctx.error(f"创建目录时发生未知错误: {e}")
        return f"错误: {e}"
@mcp.tool()
async def read_file_in_sandbox(ctx: Context[ServerSession, AppContext], file_path: str) -> str:
    """
    [Safe] Read the content of a file inside the remote sandbox.

    Args:
        file_path: Path relative to the sandbox directory.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        # Block directory traversal.
        if ".." in file_path.split('/'):
            return "错误: 不允许使用 '..' 访问上级目录。"
        # BUGFIX: posixpath keeps remote paths '/'-separated even when this
        # server runs on Windows (os.path.join would insert backslashes).
        target_path = posixpath.join(app_ctx.current_path, file_path)
        await ctx.info(f"正在读取远程文件: {target_path}")
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(target_path, 'r') as f:
                content = await f.read()
                return content
    except FileNotFoundError:
        error_message = f"读取文件失败: 文件 '{file_path}' 不存在。"
        await ctx.error(error_message)
        return error_message
    except Exception as e:
        await ctx.error(f"读取文件时发生未知错误: {e}")
        return f"错误: {e}"
@mcp.tool()
async def write_file_in_sandbox(ctx: Context[ServerSession, AppContext], file_path: str, content: str) -> str:
    """
    [Safe] Write content to a file inside the remote sandbox, overwriting any
    existing file.

    Args:
        file_path: Path relative to the sandbox directory.
        content: Text content to write.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        normalized_file_path = file_path.replace("\\", "/")
        # BUGFIX: validate the NORMALIZED path — checking the raw input let
        # backslash-separated '..' segments (e.g. '..\\x') bypass the guard.
        if ".." in normalized_file_path.split('/'):
            return "错误: 不允许使用 '..' 访问上级目录。"
        # BUGFIX: posixpath keeps remote paths '/'-separated even when this
        # server runs on Windows (os.path.join would insert backslashes).
        target_path = posixpath.join(app_ctx.current_path, normalized_file_path)
        await ctx.info(f"正在向远程文件写入: {target_path}")
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(target_path, 'w') as f:
                await f.write(content)
        return f"内容已成功写入到沙箱文件 '{file_path}'。"
    except Exception as e:
        await ctx.error(f"写入文件时发生未知错误: {e}")
        return f"错误: {e}"
# --- Run the server ---
if __name__ == "__main__":
    mcp.run(transport="streamable-http")

464
mcp/softBV.py Normal file
View File

@@ -0,0 +1,464 @@
# softbv_mcp.py
#
# 在远程服务器上激活 softBV 环境并执行计算(支持 --md 与 --gen-cube 两个专用工具)。
# - 生命周期:建立 SSH 连接并注入上下文
# - 激活环境source /cluster/home/koko125/script/softBV.sh
# - 工作目录:/cluster/home/koko125/sandbox
# - 可执行文件:/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x
#
# 依赖:
# pip install "mcp[cli]" asyncssh pydantic
#
# 用法Starlette 挂载示例见你现有 main.py导入 create_softbv_mcp 即可):
# from softbv_mcp import create_softbv_mcp
# softbv_mcp = create_softbv_mcp()
# Mount("/softbv", app=softbv_mcp.streamable_http_app())
#
# 可通过环境变量覆盖连接信息:
# REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH, SOFTBV_PROFILE, SOFTBV_BIN, DEFAULT_WORKDIR
import os
import posixpath
import asyncio
from dataclasses import dataclass
from socket import socket
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import asyncssh
from pydantic import BaseModel, Field
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
from lifespan import REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH
def shell_quote(arg: str) -> str:
    """Quote *arg* (coerced to str) as a single POSIX shell word."""
    escaped = str(arg).replace("'", "'\"'\"'")
    return f"'{escaped}'"
# If you already define these constants and paths elsewhere you can keep
# reusing them, or point them at your own configuration source.
# Fixed softBV environment info (overridable via environment variables).
SOFTBV_PROFILE = os.getenv("SOFTBV_PROFILE", "/cluster/home/koko125/script/softBV.sh")
SOFTBV_BIN = os.getenv("SOFTBV_BIN", "/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x")
DEFAULT_WORKDIR = os.getenv("DEFAULT_WORKDIR", "/cluster/home/koko125/sandbox")
@dataclass
class SoftBVContext:
    """Per-server shared state: SSH connection plus softBV paths."""
    ssh_connection: asyncssh.SSHClientConnection
    # Remote working directory used for all softBV runs.
    workdir: str
    # Shell script sourced to activate the softBV environment.
    profile: str
    # Absolute path of the softBV.x binary.
    bin_path: str
@asynccontextmanager
async def softbv_lifespan(_server) -> AsyncIterator[SoftBVContext]:
    """
    FastMCP lifespan: open the SSH connection and expose the softBV context.
    Only connects and cleans up; tools access the context through
    ctx.request_context.lifespan_context.
    """
    # Connection settings may be overridden via environment variables.
    host = os.getenv("REMOTE_HOST", REMOTE_HOST)
    user = os.getenv("REMOTE_USER", REMOTE_USER)
    key_path = os.getenv("PRIVATE_KEY_PATH", PRIVATE_KEY_PATH)
    connection: asyncssh.SSHClientConnection | None = None
    try:
        connection = await asyncssh.connect(
            host,
            username=user,
            client_keys=[key_path],
            known_hosts=None,  # remove to enforce host-key verification
            connect_timeout=15,  # avoid hanging forever on connect
        )
        yield SoftBVContext(
            ssh_connection=connection,
            workdir=DEFAULT_WORKDIR,
            profile=SOFTBV_PROFILE,
            bin_path=SOFTBV_BIN,
        )
    finally:
        if connection:
            connection.close()
            await connection.wait_closed()
async def run_in_softbv_env(
    conn: asyncssh.SSHClientConnection,
    profile_path: str,
    cmd: str,
    cwd: str | None = None,
    check: bool = True,
) -> asyncssh.SSHCompletedProcess:
    """
    Activate the softBV environment in a remote bash session and execute
    *cmd*, optionally cd-ing into *cwd* first.
    """
    steps: list[str] = []
    if cwd:
        steps.append(f"cd {shell_quote(cwd)}")
    steps.append(f"source {shell_quote(profile_path)}")
    steps.append(cmd)
    script = "; ".join(steps)
    # bash -lc gives a login shell so PATH/profile behave as expected.
    return await conn.run(f"bash -lc {shell_quote(script)}", check=check)
# ===== MCP server =====
mcp = FastMCP(
    name="softBV Tools",
    instructions="在远程服务器上激活 softBV 环境并执行相关计算的工具集。",
    lifespan=softbv_lifespan,
    streamable_http_path="/",
)
# ===== Helper: list a directory to detect newly generated files =====
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
    # NOTE(review): an identical _listdir is redefined further down this
    # module; the later definition shadows this one at import time.
    async with conn.start_sftp_client() as sftp:
        try:
            return await sftp.listdir(path)
        except Exception:
            # Missing/unreadable directory is treated as empty.
            return []
# ===== Structured input model: --md =====
class SoftBVMDArgs(BaseModel):
    input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --md 的输入")
    # Positional arguments in the order shown by the program's help text
    # (None means: omit it and let the program use its own default).
    type: str | None = Field(default=None, description="conducting ion 类型(例如 'Li')")
    os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
    sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
    temperature: float | None = Field(default=None, description="温度 K(非正值默认 300)")
    t_end: float | None = Field(default=None, description="生产时间 ps(非正值默认 10.0)")
    t_equil: float | None = Field(default=None, description="平衡时间 ps(非正值默认 2.0)")
    dt: float | None = Field(default=None, description="时间步长 ps(非正值默认 0.001)")
    t_log: float | None = Field(default=None, description="采样间隔 ps(非正值每 100 步)")
    cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
def _build_md_cmd(bin_path: str, args: SoftBVMDArgs, workdir: str) -> str:
    """Assemble the fully quoted 'softBV.x --md ...' command line."""
    cif_path = args.input_cif
    # Relative CIF paths are resolved against the remote working directory.
    if not cif_path.startswith("/"):
        cif_path = posixpath.normpath(posixpath.join(workdir, args.input_cif))
    pieces: list[str] = [shell_quote(bin_path), shell_quote("--md"), shell_quote(cif_path)]
    # Optional positional values are appended, in help order, when provided.
    optional = (args.type, args.os, args.sf, args.temperature, args.t_end, args.t_equil, args.dt, args.t_log)
    pieces.extend(shell_quote(str(value)) for value in optional if value is not None)
    return " ".join(pieces)
# ===== Structured input model: --gen-cube =====
class SoftBVGenCubeArgs(BaseModel):
    input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --gen-cube 的输入")
    type: str | None = Field(default=None, description="conducting ion 类型(如 'Li')")
    os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
    sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
    resolution: float | None = Field(default=None, description="体素分辨率(默认约 0.1)")
    ignore_conducting_ion: bool = Field(default=False, description="flag:ignore_conducting_ion")
    periodic: bool = Field(default=True, description="flag:periodic(默认 True)")
    output_name: str | None = Field(default=None, description="输出文件名前缀(可选)")
    cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str) -> str:
    # NOTE(review): dead code — shadowed by the later _build_gen_cube_cmd
    # (which adds a log_path parameter) defined further down this module.
    input_abs = args.input_cif
    # Resolve relative CIF paths against the remote working directory.
    if not input_abs.startswith("/"):
        input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
    parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
    # Optional positional values are appended, in help order, when provided.
    for val in [args.type, args.os, args.sf, args.resolution]:
        if val is not None:
            parts.append(shell_quote(str(val)))
    if args.ignore_conducting_ion:
        parts.append(shell_quote("--flag:ignore_conducting_ion"))
    if args.periodic:
        parts.append(shell_quote("--flag:periodic"))
    if args.output_name:
        parts.append(shell_quote(args.output_name))
    return " ".join(parts)
# ===== Tool: environment info check =====
# (fixed version that avoids client timeouts)
@mcp.tool()
async def softbv_info(ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Quick self-check:
    - via SFTP: verify workdir / activation script / binary exist and are
      executable (without actually running softBV.x)
    - after activating the environment, emit only markers and the current
      working directory, avoiding long output or blocking
    """
    app = ctx.request_context.lifespan_context
    conn = app.ssh_connection
    # 1) Fast file/directory status probe over SFTP (will not block for long).
    def stat_path_safe(path: str) -> dict[str, Any]:
        # NOTE(review): only builds the default template; the path argument is
        # unused and the dicts are filled in by the SFTP probes below.
        return {"exists": False, "is_exec": False, "size": None}
    workdir_info = stat_path_safe(app.workdir)
    profile_info = stat_path_safe(app.profile)
    bin_info = stat_path_safe(app.bin_path)
    try:
        async with conn.start_sftp_client() as sftp:
            # workdir
            try:
                attrs = await sftp.stat(app.workdir)
                workdir_info["exists"] = True
                workdir_info["size"] = int(attrs.size or 0)
            except Exception:
                pass
            # profile
            try:
                attrs = await sftp.stat(app.profile)
                profile_info["exists"] = True
                profile_info["size"] = int(attrs.size or 0)
                perms = int(attrs.permissions or 0)
                profile_info["is_exec"] = bool(perms & 0o111)
            except Exception:
                pass
            # bin
            try:
                attrs = await sftp.stat(app.bin_path)
                bin_info["exists"] = True
                bin_info["size"] = int(attrs.size or 0)
                perms = int(attrs.permissions or 0)
                bin_info["is_exec"] = bool(perms & 0o111)
            except Exception:
                pass
    except Exception as e:
        await ctx.warning(f"SFTP 检查失败: {e}")
    # 2) Activate the environment and run a minimal command (avoids the long
    #    output of softBV.x --help). Only report user, PWD and an
    #    executability marker; never actually run softBV.x.
    cmd = "echo __SOFTBV_READY__ && echo $USER && pwd && (test -x " + shell_quote(app.bin_path) + " && echo __BIN_OK__ || echo __BIN_NOT_EXEC__)"
    proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=app.workdir, check=False)
    # Parse the marker lines from stdout.
    lines = proc.stdout.splitlines() if proc.stdout else []
    ready = "__SOFTBV_READY__" in lines
    user = None
    pwd = None
    bin_ok = "__BIN_OK__" in lines
    # Locate user/pwd (the two lines following the ready marker).
    # NOTE(review): the parsed remote user is not included in the result dict.
    if ready:
        idx = lines.index("__SOFTBV_READY__")
        if len(lines) > idx + 1:
            user = lines[idx + 1].strip()
        if len(lines) > idx + 2:
            pwd = lines[idx + 2].strip()
    result = {
        "host": os.getenv("REMOTE_HOST", REMOTE_HOST),
        "user": os.getenv("REMOTE_USER", REMOTE_USER),
        "workdir": app.workdir,
        "profile": app.profile,
        "bin_path": app.bin_path,
        "sftp_check": {
            "workdir": workdir_info,
            "profile": profile_info,
            "bin": bin_info,
        },
        "activate_ready": ready,
        "pwd": pwd,
        "bin_is_executable": bin_ok or bin_info["is_exec"],
        "exit_status": proc.exit_status,
        "stderr_head": "\n".join(proc.stderr.splitlines()[:10]) if proc.stderr else "",
    }
    # Friendly diagnostics.
    if not ready:
        await ctx.warning("softBV 环境未就绪(可能 source 脚本路径问题或权限不足)")
    if not result["bin_is_executable"]:
        await ctx.warning("softBV 二进制不可执行或不存在,请检查 bin_path 与权限(chmod +x)")
    if proc.exit_status != 0 and not proc.stderr:
        await ctx.debug("命令非零退出但无 stderr,可能是某些子测试返回非零导致")
    return result
# ===== 工具:--md =====
@mcp.tool()
async def softbv_md(req: SoftBVMDArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run softBV.x --md and return a structured result:
    - cmd / cwd / exit_status / stdout / stderr
    - new_files: files created during the run, to locate outputs
    """
    app = ctx.request_context.lifespan_context
    connection = app.ssh_connection
    run_dir = req.cwd or app.workdir
    command = _build_md_cmd(app.bin_path, req, run_dir)
    await ctx.info(f"softBV --md 执行: {command} (cwd={run_dir})")
    # Snapshot the directory before/after to detect newly created files.
    before = set(await _listdir(connection, run_dir))
    proc = await run_in_softbv_env(connection, app.profile, cmd=command, cwd=run_dir, check=False)
    after = set(await _listdir(connection, run_dir))
    new_files = sorted(after - before)
    if proc.exit_status == 0:
        await ctx.debug(f"--md 成功,新文件 {len(new_files)} 个")
    else:
        await ctx.warning(f"--md 非零退出: {proc.exit_status}")
    return {
        "cmd": command,
        "cwd": run_dir,
        "exit_status": proc.exit_status,
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "new_files": new_files,
    }
async def run_in_softbv_env_stream(
    conn: asyncssh.SSHClientConnection,
    profile_path: str,
    cmd: str,
    cwd: str | None = None,
) -> asyncssh.SSHClientProcess:
    """
    Start *cmd* in the softBV environment without waiting for completion;
    returns the process handle so the caller can poll or await it.
    """
    steps: list[str] = []
    if cwd:
        steps.append(f"cd {shell_quote(cwd)}")
    # Activation output is silenced; a failed source does not abort the run.
    steps.append(f"source {shell_quote(profile_path)} >/dev/null 2>&1 || true")
    steps.append(cmd)
    script = "; ".join(steps)
    # Non-blocking: hand back the live process handle.
    return await conn.create_process(f"bash -lc {shell_quote(script)}")
# Poll a directory listing to spot newly generated files.
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
    # NOTE(review): duplicate of the _listdir defined earlier in this module;
    # this later definition is the one that wins at import time.
    async with conn.start_sftp_client() as sftp:
        try:
            return await sftp.listdir(path)
        except Exception:
            return []
# Poll the log file size (used as a heartbeat / rough progress signal).
async def _stat_size(conn: asyncssh.SSHClientConnection, path: str) -> int | None:
    """Return the remote file size in bytes, or None if it cannot be stat-ed."""
    async with conn.start_sftp_client() as sftp:
        try:
            info = await sftp.stat(path)
            return int(info.size or 0)
        except Exception:
            return None
# Build the gen-cube command line (same argument layout as before).
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str, log_path: str | None = None) -> str:
    """
    Assemble the quoted 'softBV.x --gen-cube ...' command line, optionally
    redirecting all output into log_path for polling.
    """
    cif_path = args.input_cif
    # Relative CIF paths are resolved against the remote working directory.
    if not cif_path.startswith("/"):
        cif_path = posixpath.normpath(posixpath.join(workdir, args.input_cif))
    pieces: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(cif_path)]
    pieces.extend(shell_quote(str(v)) for v in (args.type, args.os, args.sf, args.resolution) if v is not None)
    if args.ignore_conducting_ion:
        pieces.append(shell_quote("--flag:ignore_conducting_ion"))
    if args.periodic:
        pieces.append(shell_quote("--flag:periodic"))
    if args.output_name:
        pieces.append(shell_quote(args.output_name))
    command = " ".join(pieces)
    # Redirect stdout+stderr into the log so the caller can poll progress.
    if log_path:
        command = f"{command} > {shell_quote(log_path)} 2>&1"
    return command
# Fixed version: long-running softbv_gen_cube (with heartbeat and timeout guard)
@mcp.tool()
async def softbv_gen_cube(req: SoftBVGenCubeArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run softBV.x --gen-cube with heartbeats for long jobs, so clients with a
    <25 min limit do not force a timeout.
    - every 10 s a progress heartbeat reports elapsed time / log size /
      new-file count
    - on completion returns stdout_head (log excerpt), stderr_head (if any),
      exit_status and the list of newly created files
    """
    app = ctx.request_context.lifespan_context
    conn = app.ssh_connection
    workdir = req.cwd or app.workdir
    # Snapshot the directory so new files can be diffed afterwards.
    before = await _listdir(conn, workdir)
    # Remote log file path (timestamped name).
    import time
    log_name = f"softbv_gencube_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)
    # Launch the long-running job without blocking this coroutine.
    cmd = _build_gen_cube_cmd(app.bin_path, req, workdir, log_path=log_path)
    await ctx.info(f"启动 --gen-cube: {cmd}")
    proc = await run_in_softbv_env_stream(conn, app.profile, cmd=cmd, cwd=workdir)
    # Heartbeat loop: keep reporting until the process exits.
    start_ts = time.time()
    heartbeat_sec = 10  # send a heartbeat every 10 seconds
    max_guard_min = 60  # safety ceiling (server never kills the job; raise if clients allow)
    try:
        while True:
            # Has the process exited?
            if proc.exit_status is not None:
                break
            # Collect status: elapsed time, log size, number of new files.
            elapsed = time.time() - start_ts
            log_size = await _stat_size(conn, log_path)
            now_files = await _listdir(conn, workdir)
            new_files_count = len(set(now_files) - set(before))
            # No true percentage is available; report a heartbeat/elapsed hint.
            await ctx.report_progress(
                progress=min(elapsed / (25 * 60), 0.99),  # approximate scale against a 25 min target
                total=1.0,
                message=f"gen-cube 运行中: 已用时 {int(elapsed)}s, 日志 {log_size or 0}B, 新文件 {new_files_count}",
            )
            # The steady heartbeat itself prevents client-side timeouts.
            await asyncio.sleep(heartbeat_sec)
            # Soft guard: past max_guard_min we only warn; the remote job keeps running.
            if elapsed > max_guard_min * 60:
                await ctx.warning("任务已超过守护上限时间,仍在运行(未强制中断)。如需更长时间,请增大上限。")
                # no break — keep waiting and let the remote job finish
    finally:
        # Wait for the process to really finish (returns fast if it already has).
        await proc.wait()
    # Gather results after completion.
    exit_status = proc.exit_status
    after = await _listdir(conn, workdir)
    new_files = sorted(set(after) - set(before))
    # Read head/tail excerpts of the log to help locate the output.
    # NOTE(review): tail is computed but not included in the result dict.
    async with conn.start_sftp_client() as sftp:
        head = ""
        tail = ""
        try:
            async with sftp.open(log_path, "rb") as f:
                content = await f.read()
                text = content.decode("utf-8", errors="replace")
                lines = text.splitlines()
                head = "\n".join(lines[:40])
                tail = "\n".join(lines[-40:])
        except Exception:
            pass
    # Structured result.
    result = {
        "cmd": cmd,
        "cwd": workdir,
        "exit_status": exit_status,
        "log_file": log_path,
        "stdout_head": head,  # log excerpt instead of one-shot stdout, avoiding huge payloads
        "stderr_head": "",  # everything goes to the log; stderr_head stays empty
        "new_files": new_files,
        "elapsed_sec": int(time.time() - start_ts),
    }
    if exit_status == 0:
        await ctx.info(f"gen-cube 完成,用时 {result['elapsed_sec']}s,新文件 {len(new_files)} 个")
    else:
        await ctx.warning(f"gen-cube 退出码 {exit_status},请查看日志 {log_path}")
    return result
def create_softbv_mcp() -> FastMCP:
    """Factory function for external (Starlette) imports."""
    return mcp
if __name__ == "__main__":
    # Standalone run over the Streamable HTTP transport.
    mcp.run(transport="streamable-http")

1659
mcp/softBV_remake.py Normal file

File diff suppressed because it is too large Load Diff

407
mcp/system_tools.py Normal file
View File

@@ -0,0 +1,407 @@
# system_tools.py
import posixpath
import stat as pystat
from shlex import shlex
from typing import Any
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
import asyncssh
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# 在 system_tools.py 顶部添加
def shell_quote(arg: str) -> str:
    """
    Safely quote *arg* as one shell word for a remote Linux shell:
    - wrap the whole string in single quotes
    - escape each embedded single quote with the standard '"'"' sequence
    """
    escaped = arg.replace("'", "'\"'\"'")
    return f"'{escaped}'"
# 复用你的配置与数据类
from lifespan import (
SharedAppContext,
REMOTE_HOST,
REMOTE_USER,
PRIVATE_KEY_PATH,
INITIAL_WORKING_DIRECTORY,
)
# —— 1) 定义 FastMCP 的生命周期,在启动时建立 SSH 连接,关闭时断开 ——
@asynccontextmanager
async def system_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    FastMCP lifespan: open the shared SSH connection and expose it, together
    with the sandbox root, via ctx.request_context.lifespan_context.
    """
    connection: asyncssh.SSHClientConnection | None = None
    try:
        connection = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        # Inject the typed shared context into the MCP lifespan.
        yield SharedAppContext(ssh_connection=connection, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        # Always tear the connection down on shutdown.
        if connection:
            connection.close()
            await connection.wait_closed()
def create_system_mcp() -> FastMCP:
"""创建一个包含系统操作工具的 MCP 实例。"""
system_mcp = FastMCP(
name="System Tools",
instructions="用于在远程服务器上进行基本文件和目录操作的工具集。",
streamable_http_path="/",
lifespan=system_lifespan, # 关键:把生命周期传给 FastMCP [1]
)
    def _safe_join(sandbox_root: str, relative_path: str) -> str:
        """
        Map a user-supplied relative path to a normalized absolute path
        inside the sandbox root.
        - POSIX semantics throughout (the remote host is Linux)
        - absolute paths (leading '/') are demoted to relative
        - '..' escapes are rejected so the result always stays in the sandbox

        Raises:
            ValueError: if the resolved path escapes the sandbox.
        """
        rel = (relative_path or ".").strip()
        # Forbid absolute paths by stripping leading slashes.
        if rel.startswith("/"):
            rel = rel.lstrip("/")
        # Normalized join against the sandbox root.
        combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
        # Normalize the trailing slash so the boundary check cannot be fooled.
        root_norm = sandbox_root.rstrip("/")
        # The result must still be inside the sandbox.
        if combined != root_norm and not combined.startswith(root_norm + "/"):
            raise ValueError("路径越界:仅允许访问沙箱目录内部")
        # Extra hardening: no '..' component may survive normalization.
        parts = [p for p in combined.split("/") if p]
        if ".." in parts:
            raise ValueError("非法路径:不允许使用 '..' 跨目录")
        return combined
# —— 重构后的各工具:统一用类型安全的 ctx 与 app_ctx 访问共享资源 ——
    @system_mcp.tool()
    async def list_files(ctx: Context[ServerSession, SharedAppContext], path: str = ".") -> list[dict[str, Any]]:
        """
        List files and subdirectories of a remote sandbox directory
        (structured output).
        Returns:
            list[dict]: [{name, path, is_dir, size, permissions, mtime}],
            or [{"error": message}] on failure.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, path)
            await ctx.info(f"列出目录:{target}")
            items: list[dict[str, Any]] = []
            async with conn.start_sftp_client() as sftp:
                names = await sftp.listdir(target)  # list[str]
                for name in names:
                    item_abs = posixpath.join(target, name)
                    attrs = None
                    try:
                        attrs = await sftp.stat(item_abs)
                    except Exception:
                        # Entry may have vanished or be unreadable; keep defaults.
                        pass
                    perms = int(attrs.permissions or 0) if attrs else 0
                    is_dir = bool(pystat.S_ISDIR(perms))
                    size = int(attrs.size or 0) if attrs and attrs.size is not None else 0
                    mtime = int(attrs.mtime or 0) if attrs and attrs.mtime is not None else 0
                    items.append({
                        "name": name,
                        "path": posixpath.join(path.rstrip("/"), name) if path != "/" else name,
                        "is_dir": is_dir,
                        "size": size,
                        "permissions": perms,
                        "mtime": mtime,
                    })
            await ctx.debug(f"目录项数量:{len(items)}")
            return items
        except FileNotFoundError:
            msg = f"目录不存在或不可访问:{path}"
            await ctx.error(msg)
            return [{"error": msg}]
        except Exception as e:
            msg = f"list_files 失败:{e}"
            await ctx.error(msg)
            return [{"error": msg}]
    @system_mcp.tool()
    async def read_file(ctx: Context[ServerSession, SharedAppContext], file_path: str, encoding: str = "utf-8") -> str:
        """
        Read the content of a file inside the remote sandbox.

        Args:
            file_path: Path relative to the sandbox root.
            encoding: Text encoding used to decode the bytes (default utf-8).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, file_path)
            await ctx.info(f"读取文件:{target}")
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target, "rb") as f:
                    data = await f.read()
            try:
                content = data.decode(encoding)
            except Exception:
                # Fall back to lossy decoding rather than failing outright.
                content = data.decode(encoding, errors="replace")
            return content
        except FileNotFoundError:
            msg = f"读取失败:文件不存在 '{file_path}'"
            await ctx.error(msg)
            return msg
        except Exception as e:
            msg = f"read_file 失败:{e}"
            await ctx.error(msg)
            return msg
@system_mcp.tool()
async def write_file(
    ctx: Context[ServerSession, SharedAppContext],
    file_path: str,
    content: str,
    encoding: str = "utf-8",
    create_parents: bool = True,
) -> dict[str, Any]:
    """
    Write a file inside the remote sandbox (parent dirs created on demand).

    Args:
        ctx: Request context carrying the shared SSH connection and sandbox root.
        file_path: Sandbox-relative path of the file to write.
        content: Text content to write.
        encoding: Encoding used to serialize the content (default UTF-8).
        create_parents: When True, run 'mkdir -p' for the parent directory first.

    Returns:
        {"path": ..., "bytes_written": ...} on success, {"error": msg} on failure.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context  # typed SharedAppContext
        conn = app_ctx.ssh_connection
        sandbox_root = app_ctx.sandbox_path
        target = _safe_join(sandbox_root, file_path)
        await ctx.info(f"写入文件:{target}")
        if create_parents:
            parent = posixpath.dirname(target)
            # The sandbox root itself always exists; only create real subdirs.
            if parent and parent != sandbox_root:
                await conn.run(f"mkdir -p {shell_quote(parent)}", check=True)
        data = content.encode(encoding)
        async with conn.start_sftp_client() as sftp:
            async with sftp.open(target, "wb") as f:
                await f.write(data)
        await ctx.debug(f"写入完成:{len(data)} 字节")
        return {"path": file_path, "bytes_written": len(data)}
    except Exception as e:
        msg = f"write_file 失败:{e}"
        await ctx.error(msg)
        return {"error": msg}
@system_mcp.tool()
async def make_dir(ctx: Context[ServerSession, SharedAppContext], path: str, parents: bool = True) -> str:
    """Create a directory inside the remote sandbox.

    With parents=True the whole chain is created via 'mkdir -p'; otherwise a
    single SFTP mkdir is attempted (which fails if the parent is missing).
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        target = _safe_join(app_ctx.sandbox_path, path)
        if not parents:
            async with conn.start_sftp_client() as sftp:
                await sftp.mkdir(target)
        else:
            # One-shot shell command: mkdir -p
            await conn.run(f"mkdir -p {shell_quote(target)}", check=True)
        await ctx.info(f"目录已创建: {target}")
        return f"目录已创建: {path}"
    except Exception as e:
        await ctx.error(f"创建目录失败: {e}")
        return f"错误: {e}"
@system_mcp.tool()
async def delete_file(ctx: Context[ServerSession, SharedAppContext], file_path: str) -> str:
    """Delete a regular file (not a directory) inside the remote sandbox."""
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        target = _safe_join(app_ctx.sandbox_path, file_path)
        async with conn.start_sftp_client() as sftp:
            await sftp.remove(target)
        await ctx.info(f"文件已删除: {target}")
        return f"文件已删除: {file_path}"
    except FileNotFoundError:
        # Missing file gets a dedicated, friendlier message.
        msg = f"删除失败:文件不存在 '{file_path}'"
        await ctx.error(msg)
        return msg
    except Exception as e:
        await ctx.error(f"删除文件失败: {e}")
        return f"错误: {e}"
@system_mcp.tool()
async def delete_dir(
    ctx: Context[ServerSession, SharedAppContext],
    dir_path: str,
    recursive: bool = False,
) -> str:
    """
    Delete a directory inside the remote sandbox.

    Args:
        ctx: Request context carrying the shared SSH connection and sandbox root.
        dir_path: Sandbox-relative directory to delete.
        recursive: When True, delete with 'rm -rf'; otherwise use SFTP rmdir,
            which only succeeds on an empty directory.

    Returns:
        A human-readable status message.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context  # typed SharedAppContext
        conn = app_ctx.ssh_connection
        sandbox_root = app_ctx.sandbox_path
        target = _safe_join(sandbox_root, dir_path)
        if recursive:
            await conn.run(f"rm -rf {shell_quote(target)}", check=True)
        else:
            async with conn.start_sftp_client() as sftp:
                await sftp.rmdir(target)
        await ctx.info(f"目录已删除: {target}")
        return f"目录已删除: {dir_path}"
    except FileNotFoundError:
        # NOTE(review): rmdir on a *non-empty* directory typically raises an
        # SFTP error, not FileNotFoundError, so that case is handled by the
        # generic branch below — confirm the "或非空" wording here is intended.
        msg = f"删除失败:目录不存在 '{dir_path}' 或非空"
        await ctx.error(msg)
        return msg
    except Exception as e:
        await ctx.error(f"删除目录失败: {e}")
        return f"错误: {e}"
@system_mcp.tool()
async def move_path(ctx: Context[ServerSession, SharedAppContext], src: str, dst: str,
                    overwrite: bool = True) -> str:
    """
    Move (rename) a file or directory inside the remote sandbox.

    Args:
        ctx: Request context carrying the shared SSH connection and sandbox root.
        src: Sandbox-relative source path.
        dst: Sandbox-relative destination path.
        overwrite: When True, clobber an existing destination ('mv -f');
            when False, never overwrite ('mv -n').

    Returns:
        A human-readable status message.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        sandbox_root = app_ctx.sandbox_path
        src_abs = _safe_join(sandbox_root, src)
        dst_abs = _safe_join(sandbox_root, dst)
        # BUGFIX: plain 'mv' overwrites by default, so overwrite=False must
        # pass '-n' (no-clobber) instead of omitting the flag entirely.
        flag = "-f" if overwrite else "-n"
        cmd = f"mv {flag} {shell_quote(src_abs)} {shell_quote(dst_abs)}".strip()
        await conn.run(cmd, check=True)
        await ctx.info(f"已移动: {src_abs} -> {dst_abs}")
        return f"已移动: {src} -> {dst}"
    except FileNotFoundError:
        msg = f"移动失败:源不存在 '{src}'"
        await ctx.error(msg)
        return msg
    except Exception as e:
        await ctx.error(f"移动失败: {e}")
        return f"错误: {e}"
@system_mcp.tool()
async def copy_path(
    ctx: Context[ServerSession, SharedAppContext],
    src: str,
    dst: str,
    recursive: bool = True,
    overwrite: bool = True,
) -> str:
    """
    Copy a file or directory inside the remote sandbox.

    Args:
        ctx: Request context carrying the shared SSH connection and sandbox root.
        src: Sandbox-relative source path.
        dst: Sandbox-relative destination path.
        recursive: When True, pass '-r' so directories are copied too.
        overwrite: When True, clobber an existing destination ('cp -f');
            when False, never overwrite ('cp -n').

    Returns:
        A human-readable status message.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        sandbox_root = app_ctx.sandbox_path
        src_abs = _safe_join(sandbox_root, src)
        dst_abs = _safe_join(sandbox_root, dst)
        flags = []
        if recursive:
            flags.append("-r")
        # BUGFIX: 'cp' overwrites existing files by default, so overwrite=False
        # must pass '-n' (no-clobber) rather than just omitting '-f'.
        flags.append("-f" if overwrite else "-n")
        # Single-shot shell command: cp
        cmd = " ".join(["cp"] + flags + [shell_quote(src_abs), shell_quote(dst_abs)])
        await conn.run(cmd, check=True)
        await ctx.info(f"已复制: {src_abs} -> {dst_abs}")
        return f"已复制: {src} -> {dst}"
    except FileNotFoundError:
        msg = f"复制失败:源不存在 '{src}'"
        await ctx.error(msg)
        return msg
    except Exception as e:
        await ctx.error(f"复制失败: {e}")
        return f"错误: {e}"
@system_mcp.tool()
async def exists(ctx: Context[ServerSession, SharedAppContext], path: str) -> bool:
    """Return True when the sandbox-relative path (file or directory) exists."""
    try:
        app_ctx = ctx.request_context.lifespan_context
        target = _safe_join(app_ctx.sandbox_path, path)
        async with app_ctx.ssh_connection.start_sftp_client() as sftp:
            # stat() raises FileNotFoundError when the path is absent.
            await sftp.stat(target)
    except FileNotFoundError:
        return False
    except Exception as e:
        await ctx.error(f"exists 检查失败: {e}")
        return False
    return True
@system_mcp.tool()
async def stat_path(ctx: Context[ServerSession, SharedAppContext], path: str) -> dict:
    """
    Inspect the attributes of a remote path (structured output).

    Args:
        ctx: Request context carrying the shared SSH connection and sandbox root.
        path: Sandbox-relative path to stat.

    Returns:
        {"path", "size", "is_dir", "permissions", "mtime"} on success,
        {"error": msg} on failure.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        sandbox_root = app_ctx.sandbox_path
        target = _safe_join(sandbox_root, path)
        async with conn.start_sftp_client() as sftp:
            attrs = await sftp.stat(target)
        perms = attrs.permissions or 0
        return {
            "path": target,
            "size": int(attrs.size or 0),
            "is_dir": bool(pystat.S_ISDIR(perms)),
            "permissions": perms,  # octal permission bits, e.g. 0o755
            "mtime": int(attrs.mtime or 0),  # seconds since the epoch
        }
    except FileNotFoundError:
        msg = f"路径不存在: {path}"
        await ctx.error(msg)
        return {"error": msg}
    except Exception as e:
        await ctx.error(f"stat 失败: {e}")
        return {"error": str(e)}
return system_mcp

19
mcp/test_tools.py Normal file
View File

@@ -0,0 +1,19 @@
# test_tools.py
from mcp.server.fastmcp import FastMCP
def create_test_mcp() -> FastMCP:
    """Build a minimal MCP instance with one trivial tool for connectivity tests."""
    server = FastMCP(
        name="Test Tools",
        instructions="用于测试服务器连接是否通畅。",
        streamable_http_path="/",
    )

    @server.tool()
    def ping() -> str:
        """一个简单的工具,用于确认服务器是否响应。"""
        return "pong"

    return server

View File

@@ -0,0 +1,185 @@
# topological_analysis_tools.py
import posixpath
import re
from typing import Any, Literal
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
import asyncssh
from pydantic import BaseModel, Field
# 从 MCP SDK 导入核心组件 [1]
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# 从 lifespan.py 导入共享的配置和数据类 [3]
# 注意: 我们不再需要导入 system_lifespan
from lifespan import (
SharedAppContext,
REMOTE_HOST,
REMOTE_USER,
PRIVATE_KEY_PATH,
INITIAL_WORKING_DIRECTORY,
)
# ==============================================================================
# 1. 辅助函数 (为实现独立性而复制)
# ==============================================================================
def shell_quote(arg: str) -> str:
    """Escape *arg* so it is treated as a single POSIX shell word.

    Wraps the value in single quotes and replaces each embedded single
    quote with the close-reopen sequence '"'"' .
    """
    escaped = arg.replace("'", "'\"'\"'")
    return f"'{escaped}'"
def _safe_join(sandbox_root: str, relative_path: str) -> str:
"""安全地将用户提供的相对路径连接到沙箱根目录 [5]。"""
rel = (relative_path or ".").strip()
if rel.startswith("/"):
rel = rel.lstrip("/")
combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
root_norm = sandbox_root.rstrip("/")
if combined != root_norm and not combined.startswith(root_norm + "/"):
raise ValueError("路径越界: 仅允许访问沙箱目录内部")
if ".." in combined.split("/"):
raise ValueError("非法路径: 不允许使用 '..' 跨目录")
return combined
# ==============================================================================
# 2. 独立的生命周期管理器
# ==============================================================================
@asynccontextmanager
async def analysis_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    Manage an independent SSH connection lifecycle for the topological
    analysis tools.

    Opens a dedicated asyncssh connection on startup (not shared with the
    system tools), yields it wrapped in a SharedAppContext, and guarantees
    the connection is closed on shutdown.
    """
    conn: asyncssh.SSHClientConnection | None = None
    print("分析工具: 正在建立独立的 SSH 连接...")
    try:
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        print(f"分析工具: 独立的 SSH 连接到 {REMOTE_HOST} 成功!")
        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        # Always close the connection, even if the server is torn down mid-use.
        if conn:
            conn.close()
            await conn.wait_closed()
            print("分析工具: 独立的 SSH 连接已关闭。")
# ==============================================================================
# 3. 数据模型与解析函数 (与之前相同)
# ==============================================================================
# Pydantic models and the _parse_analysis_output helper are defined directly
# below (the earlier "omitted for brevity" placeholder note was stale).
# Closed set of YAML configuration files the analysis tool accepts.
VALID_CONFIGS = Literal["O.yaml", "S.yaml", "Cl.yaml", "Br.yaml"]
class SubmitTopologicalAnalysisParams(BaseModel):
    """Input parameters for submitting a topological analysis job."""
    cif_file_path: str = Field(description="远程服务器上 CIF 文件的相对路径。")
    config_files: list[VALID_CONFIGS] = Field(description="选择一个或多个用于分析的 YAML 配置文件。")
    output_file_path: str = Field(description="用于保存计算结果的输出文件相对路径。")
class AnalysisResult(BaseModel):
    """Structured result parsed from a topological analysis log file.

    Every numeric field is optional: None means the corresponding line was
    not found in the log; missing required values are listed in ``warnings``.
    """
    percolation_diameter_a: float | None = Field(None, description="渗透直径 (Percolation diameter), 单位 Å。")
    connectivity_distance: float | None = Field(None, description="连通距离 (the minium of d), 单位 Å。")
    max_node_length_a: float | None = Field(None, description="最大节点长度 (Maximum node length), 单位 Å。")
    min_cation_distance_a: float | None = Field(None, description="到阳离子的最小距离, 单位 Å。")
    total_time_s: float | None = Field(None, description="总计算耗时, 单位秒。")
    long_node_count: int | None = Field(None, description="长节点数量 (Long node number)。")
    short_node_count: int | None = Field(None, description="短节点数量 (Short node number)。")
    warnings: list[str] = Field(default_factory=list, description="解析过程中遇到的警告或缺失的关键值信息。")
def _parse_analysis_output(log_content: str) -> AnalysisResult:
    """Parse the analysis tool's log text into an AnalysisResult.

    Values that cannot be found are left as None; required misses (all but
    the cation distance and node counts) are reported via ``warnings``.
    """

    def _grab(pattern: str, cast, flags: int = 0):
        # cast(group 1) of the first match, or None when absent.
        hit = re.search(pattern, log_content, flags)
        return cast(hit.group(1)) if hit else None

    parsed = {
        "percolation_diameter_a": _grab(r"Percolation diameter \(A\):\s*([\d\.]+)", float),
        "max_node_length_a": _grab(r"Maximum node length detected:\s*([\d\.]+)\s*A", float),
        "min_cation_distance_a": _grab(r"minimum distance to cations is\s*([\d\.]+)\s*A", float),
        "total_time_s": _grab(r"Total used time:\s*([\d\.]+)", float),
        "long_node_count": _grab(r"Long node number:\s*(\d+)", int),
        "short_node_count": _grab(r"Short node number:\s*(\d+)", int),
    }
    parsed["connectivity_distance"] = _grab(
        r"the minium of d\s*([\d\.]+)", float, re.MULTILINE
    )
    optional_keys = ["min_cation_distance_a", "long_node_count", "short_node_count"]
    missing = [
        key + " 未找到"
        for key, value in parsed.items()
        if value is None and key not in optional_keys
    ]
    return AnalysisResult(**parsed, warnings=missing)
# ==============================================================================
# 4. MCP 服务器工厂函数
# ==============================================================================
def create_topological_analysis_mcp() -> FastMCP:
    """
    Factory function for mounting under Starlette.

    Creates the analysis MCP server with its own independent lifespan
    manager (its own SSH connection), decoupled from the system tools.
    """
    analysis_mcp = FastMCP(
        name="Topological Analysis Tools",
        instructions="用于在远程服务器上提交和分析拓扑计算任务的工具集。",
        streamable_http_path="/",
        lifespan=analysis_lifespan,  # key: use the independent lifespan defined in this file
    )

    @analysis_mcp.tool()
    async def submit_topological_analysis(
        params: SubmitTopologicalAnalysisParams,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> dict[str, Any]:
        """【步骤1/2】异步提交拓扑分析任务。"""
        # Builds a detached shell pipeline on the remote host and returns
        # immediately; results land in params.output_file_path.
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            # Fixed tool location under the remote user's home directory.
            home_dir = f"/cluster/home/{REMOTE_USER}"
            tool_dir = f"{home_dir}/tool/Topological_analyse"
            cif_abs_path = _safe_join(sandbox_root, params.cif_file_path)
            output_abs_path = _safe_join(sandbox_root, params.output_file_path)
            # One '-i' flag followed by the chosen (quoted) YAML config paths.
            config_args = " ".join([shell_quote(f"{tool_dir}/{cfg}") for cfg in params.config_files])
            config_flags = f"-i {config_args}"
            # nohup + trailing '&' detaches the job; stdout/stderr are
            # redirected into the output file for later analysis.
            command = f"""
nohup sh -c '
source {shell_quote(f"{home_dir}/.bashrc")} && \\
conda activate {shell_quote(f"{home_dir}/anaconda3/envs/zeo")} && \\
python {shell_quote(f"{tool_dir}/analyze_voronoi_nodes.py")} \\
{shell_quote(cif_abs_path)} \\
{config_flags}
' > {shell_quote(output_abs_path)} 2>&1 &
""".strip()
            await ctx.info("提交后台拓扑分析任务...")
            await conn.run(command, check=True)
            return {"success": True, "message": "拓扑分析任务已成功提交。", "output_file": params.output_file_path}
        except Exception as e:
            msg = f"提交后台任务失败: {e}"
            await ctx.error(msg)
            return {"success": False, "error": msg}

    @analysis_mcp.tool()
    async def analyze_topological_results(
        output_file_path: str,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> AnalysisResult:
        """【步骤2/2】读取并分析拓扑分析任务的输出文件。"""
        # Reads the log written by the submit step and parses it into a
        # structured AnalysisResult; failures are reported via warnings.
        try:
            await ctx.info(f"开始分析结果文件: {output_file_path}")
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target_path = _safe_join(sandbox_root, output_file_path)
            async with conn.start_sftp_client() as sftp:
                # Text-mode open: read() yields str here.
                async with sftp.open(target_path, "r") as f:
                    log_content = await f.read()
            if not log_content:
                raise ValueError("结果文件为空。")
            analysis_data = _parse_analysis_output(log_content)
            return analysis_data
        except FileNotFoundError:
            msg = f"分析失败: 结果文件 '{output_file_path}' 不存在。"
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])
        except Exception as e:
            msg = f"分析结果时发生错误: {e}"
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])

    return analysis_mcp

10
mcp/uvicorn/main.py Normal file
View File

@@ -0,0 +1,10 @@
# main.py — minimal FastAPI application used to sanity-check uvicorn setup.
from fastapi import FastAPI

app = FastAPI()


@app.get("/")
async def read_root():
    """Root endpoint: returns a static hello-world payload."""
    return {"message": "Hello, World!"}

19
softBV/out/Br.csv Normal file
View File

@@ -0,0 +1,19 @@
filename,formula,conductivity(e-3S/m)
1080005.cif,Cs3Li2Br5,21.23
1147619.cif,Li3YBr6,20.00
1147621.cif,Li3InBr6,7.23
1211043.cif,LiFeBr4,11.60
1222492.cif,Li3ErBr6,29.61
1222679.cif,Li2MnBr4,9.68
2033990.cif,CsLi2Br3,9.71
22967.cif,Li2FeBr4,11.21
23057.cif,CsLiBr2,5.12
2763849.cif,Cs3Li2Br5,121.34
28237.cif,RbLiBr2,3.46
28250.cif,Li2MnBr4,12.25
28326.cif,LiGaBr4,1.56
28327.cif,LiGaBr3,8.32
28829.cif,Li2ZnBr4,15.71
37873.cif,Li3ErBr6,10.90
580554.cif,CsLi3Br4,38.40
606680.cif,CsLi2Br3,20.63
1 filename formula conductivity(e-3S/m)
2 1080005.cif Cs3Li2Br5 21.23
3 1147619.cif Li3YBr6 20.00
4 1147621.cif Li3InBr6 7.23
5 1211043.cif LiFeBr4 11.60
6 1222492.cif Li3ErBr6 29.61
7 1222679.cif Li2MnBr4 9.68
8 2033990.cif CsLi2Br3 9.71
9 22967.cif Li2FeBr4 11.21
10 23057.cif CsLiBr2 5.12
11 2763849.cif Cs3Li2Br5 121.34
12 28237.cif RbLiBr2 3.46
13 28250.cif Li2MnBr4 12.25
14 28326.cif LiGaBr4 1.56
15 28327.cif LiGaBr3 8.32
16 28829.cif Li2ZnBr4 15.71
17 37873.cif Li3ErBr6 10.90
18 580554.cif CsLi3Br4 38.40
19 606680.cif CsLi2Br3 20.63

165
softBV/run_softBV.py Normal file
View File

@@ -0,0 +1,165 @@
import os
import sys
import subprocess
import re
import csv
import argparse
from pathlib import Path
# --- Configuration ---
# Path to the softBV.x executable; adjust to your local installation.
SOFTBV_EXECUTABLE = Path.home() / "tool/softBV-GUI_linux/bin/softBV.x"
# Fixed command arguments: the mobile ion species and its formal valence.
MOBILE_ION = "Li"
ION_VALENCE = "1"
# Name of the CSV file the aggregated results are written to.
OUTPUT_CSV_FILE = "conductivity_results.csv"
# --- End of configuration ---
def check_executable():
    """Abort the script unless softBV.x exists and is executable."""
    exe = SOFTBV_EXECUTABLE
    if not exe.is_file():
        print(f"错误: 可执行文件未找到: {exe}")
        print("请检查脚本中的 SOFTBV_EXECUTABLE 路径是否正确。")
        sys.exit(1)
    if not os.access(exe, os.X_OK):
        print(f"错误: 文件存在但不可执行: {exe}")
        print("请使用 'chmod +x' 命令赋予其执行权限。")
        sys.exit(1)
def get_formula_from_cif_line2(cif_path):
    """
    Extract the chemical formula from the second line of a CIF file.

    The expected format of line 2 is 'data_FORMULA'; the text after the
    'data_' prefix is returned as the formula.

    Args:
        cif_path: pathlib.Path of the CIF file (``.name`` is used in warnings).

    Returns:
        The formula string on success, or one of the sentinel strings
        'FormatError', 'FileTooShort', 'ReadError' on failure.
    """
    try:
        # Read at most two lines instead of loading the whole file with
        # readlines() — CIF files can be large.
        with open(cif_path, 'r', encoding='utf-8') as f:
            f.readline()  # line 1: ignored
            second_line = f.readline()
        if not second_line:
            # Fewer than two lines in the file.
            print(f"\n警告: 文件 {cif_path.name} 行数不足2行。", file=sys.stderr)
            return "FileTooShort"
        second_line = second_line.strip()
        if second_line.lower().startswith('data_'):
            # Everything after the 'data_' prefix is the formula.
            return second_line[5:]
        print(f"\n警告: 文件 {cif_path.name} 的第二行格式不正确 (不是以'data_'开头)。", file=sys.stderr)
        return "FormatError"
    except Exception as e:
        print(f"\n警告: 读取文件 {cif_path.name} 时发生错误: {e}", file=sys.stderr)
        return "ReadError"
def parse_conductivity(output_text):
    """Extract the conductivity from softBV.x output.

    Returns the value converted to 1e-3 S/m as a two-decimal string,
    None when no conductivity line is present, or 'ValueError' when the
    matched token cannot be parsed as a float.
    """
    match = re.search(r"^\s*MD: conductivity\s*=\s*([-\d.eE+]+)",
                      output_text, re.MULTILINE)
    if match is None:
        return None
    try:
        value = float(match.group(1))
    except ValueError:
        return "ValueError"
    # Convert S/m -> 1e-3 S/m (multiply by 1000), format to two decimals.
    return f"{value * 1000:.2f}"
def main():
    """Entry point: run softBV.x on every CIF in a folder and write a CSV.

    Command line: a positional folder path and an optional -t/--temperature
    (K). Per-file failures are recorded in the CSV with sentinel values
    ('NotFound', 'ExecError', 'Timeout', 'UnknownError') instead of
    aborting the whole batch.
    """
    parser = argparse.ArgumentParser(
        description="运行 softBV.x 计算电导率并汇总到CSV文件。",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("folder", type=str, help="包含CIF文件的目标文件夹路径。")
    parser.add_argument(
        "-t", "--temperature",
        type=int,
        default=1000,
        help="模拟温度 (K),对应 softBV.x 的最后一个参数。\n默认值: 1000"
    )
    args = parser.parse_args()
    target_folder = Path(args.folder)
    temperature = str(args.temperature)
    if not target_folder.is_dir():
        print(f"错误: 文件夹 '{target_folder}' 不存在。")
        sys.exit(1)
    check_executable()
    cif_files = sorted(list(target_folder.glob("*.cif")))
    if not cif_files:
        print(f"警告: 在文件夹 '{target_folder}' 中没有找到任何 .cif 文件。")
        return
    print(f"'{target_folder}' 中找到 {len(cif_files)} 个 .cif 文件,开始处理...")
    print(f"模拟温度设置为: {temperature} K")
    results = []
    for cif_file in cif_files:
        print(f" -> 正在处理: {cif_file} ...", end="", flush=True)
        # Pull the formula from line 2 of the CIF file.
        formula = get_formula_from_cif_line2(cif_file)
        command = [
            str(SOFTBV_EXECUTABLE), "--md", str(cif_file),
            MOBILE_ION, ION_VALENCE, temperature
        ]
        try:
            # 300 s timeout per structure so one stuck run cannot hang the batch.
            process_result = subprocess.run(
                command, capture_output=True, text=True, check=True, timeout=300
            )
            conductivity = parse_conductivity(process_result.stdout)
            if conductivity is not None:
                results.append([cif_file.name, formula, conductivity])
                print(f" 成功, Formula: {formula}, Conductivity: {conductivity} (x10^-3 S/m)")
            else:
                print(" 失败 (无法在输出中找到conductivity)")
                results.append([cif_file.name, formula, "NotFound"])
        except subprocess.CalledProcessError as e:
            print(f" 失败 (命令执行错误)")
            print(f" 错误信息: {e.stderr.strip()}", file=sys.stderr)
            results.append([cif_file.name, formula, "ExecError"])
        except subprocess.TimeoutExpired:
            print(f" 失败 (命令执行超时)")
            results.append([cif_file.name, formula, "Timeout"])
        except Exception as e:
            print(f" 失败 (发生未知错误: {e})", file=sys.stderr)
            results.append([cif_file.name, formula, "UnknownError"])
    if not results:
        print("没有成功处理任何文件不生成CSV。")
        return
    print(f"\n处理完成,正在将 {len(results)} 条结果写入 {OUTPUT_CSV_FILE} ...")
    try:
        with open(OUTPUT_CSV_FILE, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['filename', 'formula', 'conductivity(e-3S/m)'])
            writer.writerows(results)
        print("CSV文件已成功生成")
    except IOError as e:
        print(f"错误: 无法写入CSV文件: {e}")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()