Compare commits

...

13 Commits

Author SHA1 Message Date
95d719cc1e Contrast-learning additions and changes 2025-10-29 11:39:30 +08:00
1f8667ae51 Added the zeo++ section 2025-10-23 17:08:57 +08:00
c0b2ec5983 sofvBV_mcp refactor v2
Embedding copy
2025-10-22 23:59:23 +08:00
b9ba79d7a8 sofvBV_mcp refactor 2025-10-22 16:52:56 +08:00
bd4bd3a645 LiYCl
朱一舟
2025-10-16 16:49:33 +08:00
c20d752faa LiYCl
朱一舟
2025-10-16 16:22:13 +08:00
4e21954471 mcp-softBV 2025-10-16 14:57:27 +08:00
fc7716507a mcp-system-fix 2025-10-15 16:19:41 +08:00
b11bf2417b mcp-gen-material 2025-10-15 14:43:57 +08:00
9b0f835575 mcp
Implemented system-related operations successfully using the safe version
2025-10-14 09:52:02 +08:00
b19352382b mcp-python 2025-10-13 15:37:31 +08:00
28639f9cbf Merge remote-tracking branch 'origin/master' 2025-10-13 13:09:22 +08:00
41a6038e50 mcp-uvicorn 2025-10-13 13:09:06 +08:00
29 changed files with 4707 additions and 2 deletions

2
.idea/misc.xml generated

@@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.12" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (test1)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11" project-jdk-type="Python SDK" />
</project>


@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.12 (test1)" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

151
contrast learning/copy.py Normal file

@@ -0,0 +1,151 @@
import shutil
from pathlib import Path


def find_element_column_index(cif_lines: list) -> int:
    """
    Find the column index of _atom_site_type_symbol in the CIF file content.
    :param cif_lines: List of lines read from the CIF file.
    :return: Zero-based index of the element-symbol column, or -1 if not found.
    """
    in_loop_header = False
    column_index = -1
    current_column = 0
    for line in cif_lines:
        line_stripped = line.strip()
        if not line_stripped:
            continue
        if line_stripped.startswith('loop_'):
            in_loop_header = True
            column_index = -1
            current_column = 0
            continue
        if in_loop_header:
            if line_stripped.startswith('_'):
                if line_stripped.startswith('_atom_site_type_symbol'):
                    column_index = current_column
                current_column += 1
            else:
                # End of this loop_ header; data rows begin. If the element column was
                # in this loop we are done; otherwise keep scanning for the next loop_
                # (e.g. CIFs that list a symmetry loop before the atom-site loop).
                if column_index != -1:
                    return column_index
                in_loop_header = False
    return -1  # no loop_ with _atom_site_type_symbol found in the file


def copy_cif_with_O_or_S_robust(source_dir: str, target_dir: str, dry_run: bool = False):
    """
    Select the CIF files in the source folder whose contents include the elements 'O' or 'S'
    and copy them to the target folder.
    (Robust version: correctly parses the element-symbol column of the CIF.)
    :param source_dir: Source folder path (contains the CIF files).
    :param target_dir: Target folder path for the selected files.
    :param dry_run: If True, only print the files that would be copied without copying them.
    """
    # 1. Path handling and validation
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    if not source_path.is_dir():
        print(f"Error: source folder '{source_dir}' does not exist or is not a folder.")
        return
    if not dry_run and not target_path.exists():
        target_path.mkdir(parents=True, exist_ok=True)
        print(f"Target folder '{target_dir}' has been created.")
    print(f"Source folder: {source_path}")
    print(f"Target folder: {target_path}")
    if dry_run:
        print("\n--- *** Dry-run mode *** ---")
        print("--- No files will actually be copied ---")
    # 2. Scan and filter
    print("\nScanning the source folder for CIF files...")
    copied_count = 0
    checked_files = 0
    error_files = 0
    # Use rglob('*.cif') to include all subfolders; glob only scans the current folder.
    for file_path in source_path.glob('*.cif'):
        if file_path.is_file():
            checked_files += 1
            try:
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    lines = f.readlines()
                # Step A: locate the element-symbol column.
                element_col_idx = find_element_column_index(lines)
                # Some CIF files have no loop_ block, only simple key-value entries,
                # so keep a simplified formula-based check as a fallback.
                # (found_simple is initialized here so the final check below never
                # references an undefined name.)
                found_simple = False
                if element_col_idx == -1:
                    found_simple = any(
                        line.strip().startswith(('_chemical_formula_sum', '_chemical_formula_structural')) and (
                            ' O' in line or ' S' in line) for line in lines)
                    if not found_simple:
                        continue  # Neither method found anything; skip this file.
                # Step B: check whether the element column contains 'O' or 'S'.
                found = False
                if element_col_idx != -1:
                    for line in lines:
                        line_stripped = line.strip()
                        # Skip blank lines, comments, and definition lines.
                        if not line_stripped or line_stripped.startswith(('#', '_', 'loop_')):
                            continue
                        parts = line_stripped.split()
                        # Make sure the row has enough columns.
                        if len(parts) > element_col_idx:
                            # Element symbols may carry a charge, e.g. O2-, so strip digits and signs.
                            atom_symbol = parts[element_col_idx].strip().rstrip('0123456789+-')
                            if atom_symbol == 'O' or atom_symbol == 'S':
                                found = True
                                break
                # Compatibility: a file that passed the found_simple check also counts.
                if found_simple:
                    found = True
                if found:
                    target_file_path = target_path / file_path.name
                    print(f"Match found: '{file_path.name}' (contains O or S)")
                    if not dry_run:
                        shutil.copy2(file_path, target_file_path)
                        # print(f" -> copied to {target_file_path}")  # Uncomment for more detailed output.
                    copied_count += 1
            except Exception as e:
                error_files += 1
                print(f"!! Error while processing '{file_path.name}': {e}")
    # 3. Final report
    print("\n--- Summary ---")
    print(f"Checked {checked_files} .cif files.")
    if error_files > 0:
        print(f"{error_files} files raised errors during processing.")
    if dry_run:
        print(f"Dry run finished: {copied_count} files would be copied in a real run.")
    else:
        print(f"Successfully copied {copied_count} files to the target folder.")


if __name__ == '__main__':
    # !! Important: change the paths below to the actual paths on your machine.
    source_folder = "D:/download/2025-10/data_all/input/input"
    target_folder = "D:/download/2025-10/data_all/output"
    # --- First run: dry-run mode ---
    print("================ First run: dry-run mode ================")
    copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=True)
    print("\n\n=======================================================")
    input("Check the dry-run results above. If they look right, press Enter to perform the actual copy...")
    print("=======================================================")
    # --- Second run: actually copy ---
    print("\n================ Second run: copy mode ================")
    copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=False)

111
contrast learning/delete.py Normal file

@@ -0,0 +1,111 @@
import shutil
from pathlib import Path


def delete_duplicates_from_second_folder(source_dir: str, target_dir: str, dry_run: bool = False):
    """
    Delete the files or folders in the second folder whose names match items in the first folder.
    :param source_dir: Path of the first (source) folder.
    :param target_dir: Path of the second (target) folder; items are deleted from this folder.
    :param dry_run: If True, only print what would be deleted without actually deleting anything.
    """
    # 1. Convert the string paths to Path objects for easier handling.
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    # 2. Verify that both paths exist and are folders.
    if not source_path.is_dir():
        print(f"Error: source folder '{source_dir}' does not exist or is not a folder.")
        return
    if not target_path.is_dir():
        print(f"Error: target folder '{target_dir}' does not exist or is not a folder.")
        return
    print(f"Source folder: {source_path}")
    print(f"Target folder: {target_path}")
    if dry_run:
        print("\n--- *** Dry-run mode *** ---")
        print("--- Nothing will actually be deleted ---")
    # 3. Collect the names of all items (files and subfolders) in the source folder.
    # p.name returns the last component of the path, i.e. the file or folder name.
    source_item_names = {p.name for p in source_path.iterdir()}
    if not source_item_names:
        print("\nThe source folder is empty; nothing to do.")
        return
    print(f"\nFound {len(source_item_names)} items in the source folder.")
    print("Checking the target folder and deleting items with matching names...")
    deleted_count = 0
    # 4. Iterate over the item names from the source folder.
    for item_name in source_item_names:
        # Build the full path of the possibly existing item in the target folder.
        item_to_delete = target_path / item_name
        # 5. Check whether that item exists in the target folder.
        if item_to_delete.exists():
            try:
                if item_to_delete.is_file():
                    # A file can be deleted directly.
                    print(f"About to delete file: {item_to_delete}")
                    if not dry_run:
                        item_to_delete.unlink()
                        print(" -> deleted.")
                    deleted_count += 1
                elif item_to_delete.is_dir():
                    # For a folder, shutil.rmtree removes the folder and all of its contents.
                    print(f"About to delete folder and all of its contents: {item_to_delete}")
                    if not dry_run:
                        shutil.rmtree(item_to_delete)
                        print(" -> deleted.")
                    deleted_count += 1
            except Exception as e:
                print(f"!! Error while deleting '{item_to_delete}': {e}")
    if deleted_count == 0:
        print("\nDone: no items with matching names were found in the target folder.")
    else:
        if dry_run:
            print(f"\nDry run finished: {deleted_count} items would be deleted in a real run.")
        else:
            print(f"\nDone: {deleted_count} items deleted in total.")


# --- Usage example ---
# Before running, create the following folder and file structure for testing:
# /your/path/folder1/
# ├── file_a.txt
# ├── file_b.log
# └── subfolder_x/
#     └── test.txt
# /your/path/folder2/
# ├── file_a.txt        (will be deleted)
# ├── file_c.md
# └── subfolder_x/      (will be deleted)
#     └── another.txt
if __name__ == '__main__':
    # !! Important: change the paths below to the actual paths on your machine.
    folder1_path = "D:/download/2025-10/after_step5/after_step5/S"  # source folder
    folder2_path = "D:/download/2025-10/input/input"  # target folder
    # --- First run: dry-run mode, strongly recommended! ---
    # This tells you what the script would do without actually deleting anything.
    print("================ First run: dry-run mode ================")
    delete_duplicates_from_second_folder(folder1_path, folder2_path, dry_run=True)
    print("\n\n=======================================================")
    input("Check the dry-run results above. If they look right, press Enter to perform the actual deletion...")
    print("=======================================================")
    # --- Second run: actually delete ---
    # Only set dry_run to False (or drop the argument) after verifying the dry-run output.
    print("\n================ Second run: delete mode ================")
    delete_duplicates_from_second_folder(folder1_path, folder2_path, dry_run=False)


@@ -0,0 +1,50 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 -0.065(3) 1.0000 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================


@@ -0,0 +1,50 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 0.488(1) 1.0000 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================


@@ -0,0 +1,51 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 0.488(1) 0.7500 1.10(3)
Y3 Y 2d 0.3333 0.6666 -0.065(3) 0.2500 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================


@@ -0,0 +1,60 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
# --- Y sites (based on Model 3; occupancies: M1=1, M2=0.75, M3=0.25) ---
Y1_main Y 0.0000 0.0000 0.0000 0.9444 # 1a site
Li_on_Y1 Li 0.0000 0.0000 0.0000 0.0556
Y2_main Y 0.3333 0.6666 0.4880 0.7083 # 2d site
Li_on_Y2 Li 0.3333 0.6666 0.4880 0.0417
Y3_main Y 0.3333 0.6666 -0.0650 0.2361 # 2d site
Li_on_Y3 Li 0.3333 0.6666 -0.0650 0.0139
# --- Li sites (based on the standard structure; occupancies: 6g=1, 6h=0.5) ---
Li1_main Li 0.3397 0.3397 0.0000 0.9815 # 6g site
Y_on_Li1 Y 0.3397 0.3397 0.0000 0.0185
Li2_main Li 0.3397 0.3397 0.5000 0.4907 # 6h site
Y_on_Li2 Y 0.3397 0.3397 0.5000 0.0093
# --- Cl sites (unchanged) ---
Cl1 Cl 0.1131 -0.1131 0.7717 1.0000
Cl2 Cl 0.2182 -0.2182 0.2606 1.0000
Cl3 Cl 0.4436 -0.4436 0.7604 1.0000
#==============================================================================
# END OF DATA
#==============================================================================


@@ -0,0 +1,51 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================
data_Li3YCl6_annealed
_audit_creation_method 'Generated from published data table'
#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a 11.2001(2)
_cell_length_b 11.2001(2)
_cell_length_c 6.0441(2)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 120.0
_cell_volume 656.39
#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M 'P -3 m 1'
_symmetry_Int_Tables_number 164
#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
Y2 Y 2d 0.3333 0.6666 0.488(1) 0.8269 1.10(3)
Y3 Y 2d 0.3333 0.6666 -0.065(3) 0.1730 1.10(3)
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
#==============================================================================
# END OF DATA
#==============================================================================


@@ -0,0 +1,53 @@
#------------------------------------------------------------------------------
# CIF (Crystallographic Information File) for Li3YCl6
# Data source: Table S1 from the provided image.
# Rietveld refinement result of the neutron diffraction pattern for the 450 °C-annealed sample.
#------------------------------------------------------------------------------
data_Li3YCl6
_chemical_name_systematic 'Lithium Yttrium Chloride'
_chemical_formula_sum 'Li3 Y1 Cl6'
_chemical_formula_structural 'Li3YCl6'
_symmetry_space_group_name_H-M 'P n m a'
_symmetry_Int_Tables_number 62
_symmetry_cell_setting orthorhombic
loop_
_symmetry_equiv_pos_as_xyz
'x, y, z'
'-x+1/2, y+1/2, -z+1/2'
'-x, -y, -z'
'x+1/2, -y+1/2, z+1/2'
'-x, y+1/2, -z'
'x-1/2, -y-1/2, z-1/2'
'x, -y, z'
'-x-1/2, y-1/2, -z-1/2'
_cell_length_a 12.92765(13)
_cell_length_b 11.19444(10)
_cell_length_c 6.04000(12)
_cell_angle_alpha 90.0
_cell_angle_beta 90.0
_cell_angle_gamma 90.0
_cell_volume 874.15
_cell_formula_units_Z 4
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_Wyckoff_symbol
_atom_site_U_iso_or_equiv
Li1 Li 0.11730(7) 0.09640(7) 0.04860(10) 0.750(13) 8d 4.579(2)
Li2 Li 0.13270(9) 0.07900(10) 0.48600(2) 0.750(19) 8d 9.554(4)
Cl1 Cl 0.21726(7) 0.58920(7) 0.26362(11) 1.0 8d 0.797(17)
Cl2 Cl 0.45948(8) 0.08259(8) 0.23831(13) 1.0 8d 1.548(2)
Cl3 Cl 0.04505(10) 0.25000 0.74110(2) 1.0 4c 1.848(3)
Cl4 Cl 0.20205(9) 0.25000 0.24970(2) 1.0 4c 0.561(2)
Y1 Y 0.37529(10) 0.25000 0.01870(3) 1.0 4c 1.121(17)
#------------------------------------------------------------------------------

151
dpgen/plus.py Normal file

@@ -0,0 +1,151 @@
import random
from typing import List

from pymatgen.core import Structure
from pymatgen.io.vasp import Poscar


def _is_close_frac(z, target, tol=2e-2):
    t = target % 1.0
    return min(abs(z - t), abs(z - (t + 1)), abs(z - (t - 1))) < tol


def make_model3_poscar_from_cif(cif_path: str,
                                out_poscar: str = "POSCAR_model3_supercell",
                                seed: int = 42,
                                tol: float = 2e-2):
    """
    Expand model3.cif with the supercell matrix [[3,0,0],[2,4,0],[0,0,6]] into a 2160-atom
    supercell, explicitly order the partially occupied sites (Y2=0.75, Y3=0.25, Li2=0.5),
    and write out a POSCAR.
    """
    random.seed(seed)
    # 1) Read the CIF.
    s = Structure.from_file(cif_path)
    # 2) Build the supercell: a_s=3a0, b_s=2a0+4b0, c_s=6c0 [1]
    T = [[3, 0, 0],
         [2, 4, 0],
         [0, 0, 6]]
    s.make_supercell(T)
    # 3) Identify the three site families that need rounding: Y2, Y3, Li2.
    y2_idx: List[int] = []
    y3_idx: List[int] = []
    li2_idx: List[int] = []
    for i, site in enumerate(s.sites):
        # Compatible with different pymatgen versions.
        try:
            el = site.species.elements[0].symbol
        except Exception:
            ss = site.species_string
            el = "Li" if ss.startswith("Li") else ("Y" if ss.startswith("Y") else ("Cl" if ss.startswith("Cl") else ss))
        z = site.frac_coords[2]
        if el == "Y":
            if _is_close_frac(z, 0.488, tol):
                y2_idx.append(i)
            elif _is_close_frac(z, -0.065, tol) or _is_close_frac(z, 0.935, tol):
                y3_idx.append(i)
        elif el == "Li":
            if _is_close_frac(z, 0.5, tol):
                li2_idx.append(i)

    def choose_keep(idxs, frac_keep):
        n = len(idxs)
        k = int(round(n * frac_keep))
        if k < 0: k = 0
        if k > n: k = n
        keep = set(random.sample(idxs, k)) if 0 < k < n else set(idxs if k == n else [])
        drop = [i for i in idxs if i not in keep]
        return keep, drop

    keep_y2, drop_y2 = choose_keep(y2_idx, 0.75)
    keep_y3, drop_y3 = choose_keep(y3_idx, 0.25)
    keep_li2, drop_li2 = choose_keep(li2_idx, 0.50)
    # 4) Set the kept sites to full occupancy and delete the rest.
    for i in keep_y2 | keep_y3:
        s.replace(i, "Y")
    for i in keep_li2:
        s.replace(i, "Li")
    to_remove = sorted(drop_y2 + drop_y3 + drop_li2, reverse=True)
    for i in to_remove:
        s.remove_sites([i])
    # 5) Final cleanup: eliminate any remaining partial occupancy (otherwise writing the POSCAR fails).
    #    If site.is_ordered == False, replace the site with its majority element at occupancy 1.
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()  # e.g. {'Li': 0.5} or {'Li': 0.5, 'Y': 0.5}
            elem = max(d.items(), key=lambda kv: kv[1])[0]
            s.replace(i, elem)
    # 6) Sort and write the POSCAR.
    order = {"Li": 0, "Y": 1, "Cl": 2}
    s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
    Poscar(s).write_file(out_poscar)
    # Report.
    comp = {k: int(v) for k, v in s.composition.as_dict().items()}
    print(f"Wrote {out_poscar}; total atoms = {len(s)}")
    print(f"Y2 found={len(y2_idx)}, Y3 found={len(y3_idx)}, Li2 found={len(li2_idx)}; composition={comp}")


def make_pnma_poscar_from_cif(cif_path: str,
                              out_poscar: str = "POSCAR_pnma_supercell",
                              seed: int = 42,
                              supercell=(3, 3, 6),
                              tol: float = 1e-6):
    """
    Read a Pnma CIF (e.g. origin.cif), expand it to 2160 atoms, and explicitly round the
    partially occupied Li sites (0.75) before writing out a POSCAR.
    The default supercell is (3,3,6), a volume factor of 54 (40 atoms/unit cell x 54 = 2160) [1][3].
    """
    random.seed(seed)
    s = Structure.from_file(cif_path)
    # Build the supercell (the Pnma cell is already orthorhombic, so a diagonal scaling suffices).
    s.make_supercell(supercell)
    # Find all partially occupied Li sites.
    partial_li_idx: List[int] = []
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()  # e.g. {'Li': 0.75}
            # Only handle sites whose majority element is Li with occupancy < 1.
            m_elem, m_occ = max(d.items(), key=lambda kv: kv[1])
            if m_elem == "Li" and m_occ < 1 - tol:
                partial_li_idx.append(i)
    # Round at occupancy 0.75: keep 75% at random and delete the rest as vacancies.
    n = len(partial_li_idx)
    k = int(round(n * 0.75))
    keep = set(random.sample(partial_li_idx, k)) if 0 < k < n else set(partial_li_idx if k == n else [])
    drop = sorted([i for i in partial_li_idx if i not in keep], reverse=True)
    # Set the kept sites to occupancy 1 and delete the rest.
    for i in keep:
        s.replace(i, "Li")
    for i in drop:
        s.remove_sites([i])
    # Safety net: if any partial occupancy remains, force the majority element.
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()
            elem = max(d.items(), key=lambda kv: kv[1])[0]
            s.replace(i, elem)
    # Sort and write the POSCAR.
    order = {"Li": 0, "Y": 1, "Cl": 2}
    s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
    Poscar(s).write_file(out_poscar)
    comp = {k: int(v) for k, v in s.composition.as_dict().items()}
    print(f"Wrote {out_poscar}; total atoms = {len(s)}; composition = {comp}")


if __name__ == "__main__":
    # make_model3_poscar_from_cif("data/P3ma/model3.cif", "data/P3ma/supercell_model4.poscar")
    make_pnma_poscar_from_cif("data/Pnma/origin.cif", "data/Pnma/supercell_pnma.poscar", seed=42)


76
dpgen/transport.py Normal file

@@ -0,0 +1,76 @@
from typing import List, Dict, Tuple


def add_li_y_antisite_average(
    x: float,
    # Y sites: [(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
    y_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
    # Li sites: [(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
    li_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
    # Optional: B factors per element (dict); may be left empty
    b_iso: Dict[str, float] = None,
    # Number of decimal places to keep
    ndigits: int = 4
) -> str:
    """
    Rewrite the structure as an "average occupancy" CIF block containing a fraction x
    (relative to the total Y content) of Li/Y antisite defects.
    Returns a string that can be pasted directly into the atom_site loop of a CIF.
    """
    # Defaults use the Table S9 (P-3m1) coordinates and occupancies; multiplicities follow that space group.
    if y_sites is None:
        y_sites = [
            ("Y1", "1a", (0.0000, 0.0000, 0.0000), 1.0000, 1),
            ("Y2", "2d", (0.3333, 0.6666, 0.4880), 0.8269, 2),
            ("Y3", "2d", (0.3333, 0.6666, -0.0650), 0.1730, 2),
        ]
    if li_sites is None:
        li_sites = [
            ("Li1", "6g", (0.3397, 0.3397, 0.0000), 1.0000, 6),
            ("Li2", "6h", (0.3397, 0.3397, 0.5000), 0.5000, 6),
        ]
    if b_iso is None:
        # Adjust as needed.
        b_iso = {"Y": 1.10, "Li": 5.00, "Cl": 1.99}
    # 1) Total Y and total Li ("atom count contribution" per unit cell).
    y_total = sum(m * occ for _, _, _, occ, m in y_sites)    # expected: 3
    li_total = sum(m * occ for _, _, _, occ, m in li_sites)  # expected: 9
    # 2) Number of Y atoms to exchange (per unit cell).
    y_exchanged = x * y_total  # e.g. x=0.0556 -> ~0.1668 Y per cell
    # 3) Fraction assigned to the Li sublattice (preserves charge neutrality and stoichiometry).
    li_fraction = y_exchanged / li_total  # fraction of Y introduced on each Li site
    # 4) Generate the CIF loop text (two rows per coordinate: the main element and the antisite element).
    lines = []
    header = [
        "loop_",
        " _atom_site_label",
        " _atom_site_type_symbol",
        " _atom_site_fract_x",
        " _atom_site_fract_y",
        " _atom_site_fract_z",
        " _atom_site_occupancy",
        " _atom_site_B_iso_or_equiv",
    ]
    lines.extend(header)
    # 4a) Y sites: Y -> (1-x)*occ, Li_on_Y -> x*occ
    for label, wyck, (fx, fy, fz), occ, mult in y_sites:
        y_main_occ = round(occ * (1.0 - x), ndigits)
        li_on_y_occ = round(occ * x, ndigits)
        lines.append(f" {label}_main Y {fx:.4f} {fy:.4f} {fz:.4f} {y_main_occ:.{ndigits}f} {b_iso.get('Y', 1.0)}")
        lines.append(f" Li_on_{label} Li {fx:.4f} {fy:.4f} {fz:.4f} {li_on_y_occ:.{ndigits}f} {b_iso.get('Li', 5.0)}")
    # 4b) Li sites: Li -> (1-li_fraction)*occ, Y_on_Li -> li_fraction*occ
    for label, wyck, (fx, fy, fz), occ, mult in li_sites:
        li_main_occ = round(occ * (1.0 - li_fraction), ndigits)
        y_on_li_occ = round(occ * li_fraction, ndigits)
        lines.append(f" {label}_main Li {fx:.4f} {fy:.4f} {fz:.4f} {li_main_occ:.{ndigits}f} {b_iso.get('Li', 5.0)}")
        lines.append(f" Y_on_{label} Y {fx:.4f} {fy:.4f} {fz:.4f} {y_on_li_occ:.{ndigits}f} {b_iso.get('Y', 1.0)}")
    return "\n".join(lines)


# === Example ===
# x=0.0556 (matches the 5.56% used in the paper, defined relative to the total Y content):
print(add_li_y_antisite_average(0.0556))


@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 gyj155
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,96 @@
# Paper Semantic Search
Find similar papers using semantic search. Supports both local models (free) and OpenAI API (better quality).
## Features
- Fetch papers from OpenReview (e.g., ICLR 2026 submissions)
- Semantic search with example papers or text queries
- Automatic embedding caching
- Embedding model support: open-source (e.g., all-MiniLM-L6-v2) or OpenAI
## Quick Start
```bash
pip install -r requirements.txt
```
### 1. Prepare Papers
```python
from crawl import crawl_papers
crawl_papers(
    venue_id="ICLR.cc/2026/Conference/Submission",
    output_file="iclr2026_papers.json"
)
```
### 2. Search Papers
```python
from search import PaperSearcher
# Local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
# OpenAI model (better, requires API key)
# export OPENAI_API_KEY='your-key'
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
searcher.compute_embeddings()
# Search with example papers that you are interested in
examples = [
    {
        "title": "Your paper title",
        "abstract": "Your paper abstract..."
    }
]
results = searcher.search(examples=examples, top_k=100)
# Or search with text query
results = searcher.search(query="interesting topics", top_k=100)
searcher.display(results, n=10)
searcher.save(results, 'results.json')
```
## How It Works
1. Paper titles and abstracts are converted to embeddings
2. Embeddings are cached automatically
3. Your query is embedded using the same model
4. Cosine similarity finds the most similar papers
5. Results are ranked by similarity score (steps 4-5 are sketched below)
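
Conceptually, the ranking is just a normalized dot product over the cached embedding matrix. A minimal sketch of steps 4-5 (hypothetical standalone names: `papers_emb` for the cached `(N, d)` array, `query_emb` for the embedded query):

```python
import numpy as np

def rank_by_cosine(papers_emb: np.ndarray, query_emb: np.ndarray, top_k: int = 10):
    """Return (index, similarity) pairs for the top_k most similar papers."""
    # Normalize rows and the query so the dot product equals cosine similarity.
    papers_norm = papers_emb / np.linalg.norm(papers_emb, axis=1, keepdims=True)
    query_norm = query_emb / np.linalg.norm(query_emb)
    sims = papers_norm @ query_norm  # shape (N,)
    top = np.argsort(sims)[::-1][:top_k]  # best matches first
    return [(int(i), float(sims[i])) for i in top]
```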
## Cache
Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete to recompute.
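
For reference, the cache path is derived from the papers file name, an 8-character MD5 prefix of the path string, and the model type; this sketch mirrors `_get_cache_file` in search.py:

```python
import hashlib
from pathlib import Path

def cache_file_for(papers_file: str, model_type: str) -> str:
    """Reproduce the cache naming scheme cache_<filename>_<hash>_<model>.npy."""
    base_name = Path(papers_file).stem
    file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
    return str(Path(papers_file).parent / f"cache_{base_name}_{file_hash}_{model_type}.npy")
```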
## Example Output
```
================================================================================
Top 100 Results (showing 10)
================================================================================
1. [0.8456] Paper a
#12345 | foundation or frontier models, including LLMs
https://openreview.net/forum?id=xxx
2. [0.8234] Paper b
#12346 | applications to robotics, autonomy, planning
https://openreview.net/forum?id=yyy
```
## Tips
- Use 1-5 example papers for best results, or a paragraph describing the topic you are interested in
- The local model is good enough for most cases
- Use the OpenAI model for critical searches (~$1 for 18k queries)

If this is useful, please consider giving it a star~


@@ -0,0 +1,66 @@
import requests
import json
import time


def fetch_submissions(venue_id, offset=0, limit=1000):
    url = "https://api2.openreview.net/notes"
    params = {
        "content.venueid": venue_id,
        "details": "replyCount,invitation",
        "limit": limit,
        "offset": offset,
        "sort": "number:desc"
    }
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, params=params, headers=headers)
    response.raise_for_status()
    return response.json()


def crawl_papers(venue_id, output_file):
    all_papers = []
    offset = 0
    limit = 1000
    print(f"Fetching papers from {venue_id}...")
    while True:
        data = fetch_submissions(venue_id, offset, limit)
        notes = data.get("notes", [])
        if not notes:
            break
        for note in notes:
            paper = {
                "id": note.get("id"),
                "number": note.get("number"),
                "title": note.get("content", {}).get("title", {}).get("value", ""),
                "authors": note.get("content", {}).get("authors", {}).get("value", []),
                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
            }
            all_papers.append(paper)
        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")
        if len(notes) < limit:
            break
        offset += limit
        time.sleep(0.5)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_papers, f, ensure_ascii=False, indent=2)
    print(f"\nTotal: {len(all_papers)} papers")
    print(f"Saved to {output_file}")
    return all_papers


if __name__ == "__main__":
    crawl_papers(
        venue_id="ICLR.cc/2026/Conference/Submission",
        output_file="iclr2026_papers.json"
    )


@@ -0,0 +1,22 @@
from search import PaperSearcher

# Use local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
# Or use OpenAI (better quality)
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
searcher.compute_embeddings()

examples = [
    {
        "title": "Solid-State battery",
        "abstract": "Solid-State battery"
    },
]

results = searcher.search(examples=examples, top_k=100)
searcher.display(results, n=10)
searcher.save(results, 'results.json')


@@ -0,0 +1,6 @@
requests
numpy
scikit-learn
sentence-transformers
openai


@@ -0,0 +1,156 @@
import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity


class PaperSearcher:
    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)
        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None
        if model_type == "openai":
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"
        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        base_name = Path(papers_file).stem
        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_file).parent / cache_name)

    def _load_cache(self):
        if os.path.exists(self.cache_file):
            try:
                self.embeddings = np.load(self.cache_file)
                if len(self.embeddings) == len(self.papers):
                    print(f"Loaded cache: {self.embeddings.shape}")
                    return True
                self.embeddings = None
            except Exception:
                self.embeddings = None
        return False

    def _save_cache(self):
        np.save(self.cache_file, self.embeddings)
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        batch_size = 100
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend([item.embedding for item in response.data])
        return np.array(embeddings)

    def _embed_local(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def compute_embeddings(self, force=False):
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings
        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]
        if self.model_type == "openai":
            self.embeddings = self._embed_openai(texts)
        else:
            self.embeddings = self._embed_local(texts)
        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        if self.embeddings is None:
            self.compute_embeddings()
        if examples:
            texts = []
            for ex in examples:
                text = f"Title: {ex['title']}"
                if ex.get('abstract'):
                    text += f" Abstract: {ex['abstract']}"
                texts.append(text)
            if self.model_type == "openai":
                embs = self._embed_openai(texts)
            else:
                embs = self._embed_local(texts)
            query_emb = np.mean(embs, axis=0).reshape(1, -1)
        elif query:
            if self.model_type == "openai":
                query_emb = self._embed_openai(query).reshape(1, -1)
            else:
                query_emb = self._embed_local(query).reshape(1, -1)
        else:
            raise ValueError("Provide either examples or query")
        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")
        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']
            print(f"{i}. [{sim:.4f}] {paper['title']}")
            print(f" #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f" {paper['forum_url']}\n")

    def save(self, results, output_file):
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")

50
mcp/lifespan.py Normal file

@@ -0,0 +1,50 @@
import asyncssh
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from starlette.applications import Starlette

# --- 1. SSH connection parameters ---
REMOTE_HOST = '202.121.182.208'
REMOTE_USER = 'koko125'
# Make sure the path is correct; forward slashes '/' are recommended for Windows paths.
PRIVATE_KEY_PATH = 'D:/tool/tool/id_rsa.txt'
INITIAL_WORKING_DIRECTORY = f'/cluster/home/{REMOTE_USER}/sandbox'

# --- 2. Shared context definition ---
# This dataclass holds the resources shared by all modules.
@dataclass
class SharedAppContext:
    """Shared resources held for the entire server lifetime."""
    ssh_connection: asyncssh.SSHClientConnection
    sandbox_path: str

# --- 3. Lifespan manager ---
# The core of the architecture: connect over SSH when the server starts, disconnect on shutdown.
@asynccontextmanager
async def shared_lifespan(app: Starlette) -> AsyncIterator[dict]:
    """
    Manage the lifecycle of the shared resources for the whole application.
    """
    print("Main application starting; establishing the shared SSH connection...")
    conn = None
    try:
        # Establish the SSH connection.
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH]
        )
        print(f"SSH connection to {REMOTE_HOST} succeeded!")
        # Yield the shared context to the Starlette application;
        # the server pauses here and starts serving requests.
        yield {"shared_context": SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)}
    finally:
        # The code after yield runs when the server shuts down.
        if conn:
            conn.close()
            await conn.wait_closed()
            print("Shared SSH connection closed.")

55
mcp/main.py Normal file

@@ -0,0 +1,55 @@
# main.py
# Mount the Materials CIF MCP, the System Tools MCP, and the other MCP servers onto Starlette.
# Key points:
# - Start each MCP's session_manager.run() inside the lifespan (see the Starlette mounting
#   example and streamable_http_app usage in the SDK README [1]).
# - Use Mount to give each MCP its own sub-path (e.g. /system and /materials).
import contextlib
from starlette.applications import Starlette
from starlette.routing import Mount

from system_tools import create_system_mcp
from materialproject_mcp import create_materials_mcp
from softBV_remake import create_softbv_mcp
from paper_search_mcp import create_paper_search_mcp
from topological_analysis_models import create_topological_analysis_mcp

# Create the MCP instances.
system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()
softbv_mcp = create_softbv_mcp()
paper_search_mcp = create_paper_search_mcp()
topological_analysis_mcp = create_topological_analysis_mcp()

# Start the MCP session managers inside Starlette's lifespan.
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(system_mcp.session_manager.run())
        await stack.enter_async_context(materials_mcp.session_manager.run())
        await stack.enter_async_context(softbv_mcp.session_manager.run())
        await stack.enter_async_context(paper_search_mcp.session_manager.run())
        await stack.enter_async_context(topological_analysis_mcp.session_manager.run())
        yield  # server is running
        # cleanup happens automatically on exit

# Mount the Streamable HTTP app of each MCP.
app = Starlette(
    lifespan=lifespan,
    routes=[
        Mount("/system", app=system_mcp.streamable_http_app()),
        Mount("/materials", app=materials_mcp.streamable_http_app()),
        Mount("/softBV", app=softbv_mcp.streamable_http_app()),
        Mount("/papersearch", app=paper_search_mcp.streamable_http_app()),
        Mount("/topologicalAnalysis", app=topological_analysis_mcp.streamable_http_app()),
    ],
)

# Launch command (run in a terminal):
# uvicorn main:app --host 0.0.0.0 --port 8000
# Endpoints:
# http://localhost:8000/system
# http://localhost:8000/materials
# http://localhost:8000/softBV
# http://localhost:8000/papersearch
# http://localhost:8000/topologicalAnalysis
# For browser clients (CORS, exposing Mcp-Session-Id), see the CORS configuration example in the README [1].

513
mcp/materialproject_mcp.py Normal file

@@ -0,0 +1,513 @@
# materials_mp_cif_mcp.py
#
# Overview:
# - A FastMCP-based "Materials Project CIF" server providing:
#   1) search_cifs: search via the Materials Project mp-api (two steps: summary filter + materials fetch for structures)
#   2) get_cif: fetch the structure for a material ID and export it as CIF text (pymatgen)
#   3) filter_cifs: hard local filtering of candidates
#   4) download_cif_bulk: bulk CIF download (with concurrency and progress)
#   5) resource: cif://mp/{id} reads the CIF text directly
#
# - Dependencies: pip install "mcp[cli]" mp_api pymatgen
# - API key:
#   Read from the MP_API_KEY environment variable by default; if unset, fall back to the
#   provided key (demo only; hardcoding is not recommended in production) [3]
# - For MCP/Context/Lifespan usage see the SDK docs (tools, structured output, progress, etc.) [1]
#
# References:
# - MPRester/MP_API_KEY usage and recommended session management: Materials Project docs [3]
# - summary.search and materials.search query and fields-restriction examples: docs [4]
# - FastMCP tools, resources, Context, structured output: README examples [1]
import os
import asyncio
from dataclasses import dataclass
from typing import Any, Literal
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import re

from pydantic import BaseModel, Field
from mp_api.client import MPRester  # [3]
from pymatgen.core import Structure  # for CIF conversion
from mcp.server.fastmcp import FastMCP, Context  # [1]
from mcp.server.session import ServerSession  # [1]

# ========= Configuration =========
# Configure via the MP_API_KEY environment variable in production; the fallback to the
# provided key is for demonstration only (hardcoding is not recommended) [3]
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW")

# ========= Data models =========
class LatticeParameters(BaseModel):
    a: float | None = Field(default=None, description="a axis length (Å)")
    b: float | None = Field(default=None, description="b axis length (Å)")
    c: float | None = Field(default=None, description="c axis length (Å)")
    alpha: float | None = Field(default=None, description="α angle (degrees)")
    beta: float | None = Field(default=None, description="β angle (degrees)")
    gamma: float | None = Field(default=None, description="γ angle (degrees)")

class CIFCandidate(BaseModel):
    id: str = Field(description="Materials Project ID, e.g. mp-149")
    source: Literal["mp"] = Field(default="mp", description="Source database")
    chemical_formula: str | None = Field(default=None, description="Chemical formula (pretty)")
    space_group: str | None = Field(default=None, description="Space group symbol")
    cell_parameters: LatticeParameters | None = Field(default=None, description="Lattice parameters")
    bandgap: float | None = Field(default=None, description="Band gap (eV)")
    stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
    magnetic_ordering: str | None = Field(default=None, description="Magnetic ordering info")
    has_cif: bool = Field(default=True, description="Whether a CIF can be exported (any structure suffices)")
    extras: dict[str, Any] | None = Field(default=None, description="Raw document fragments (robust parsing)")

class SearchQuery(BaseModel):
    # Basic chemistry constraints
    include_elements: list[str] | None = Field(default=None, description="Elements to include (e.g. ['Si','O'])")  # [4]
    exclude_elements: list[str] | None = Field(default=None, description="Elements to exclude (e.g. ['Li'])")
    formula: str | None = Field(default=None, description="Formula fragment (local post-filtering)")
    # Structure and property constraints (summary can filter band gap and E_above_hull) [4]
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    # Space group and lattice ranges (filtered locally)
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    magnetic_ordering: str | None = None  # local post-filtering
    page: int = Field(default=1, ge=1, description="Page number (local paging only; extend for the MP endpoint yourself)")
    per_page: int = Field(default=50, ge=1, le=200, description="Items per page (local paging)")
    sort_by: str | None = Field(default=None, description="Sort field (bandgap/e_above_hull etc.; local sorting)")
    sort_order: Literal["asc", "desc"] | None = Field(default=None, description="Sort direction")

class CIFText(BaseModel):
    id: str
    source: Literal["mp"]
    text: str = Field(description="CIF file text content")
    metadata: dict[str, Any] | None = Field(default=None, description="Extra metadata (e.g. structure provenance)")

class FilterCriteria(BaseModel):
    include_elements: list[str] | None = None
    exclude_elements: list[str] | None = None
    formula: str | None = None
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    magnetic_ordering: str | None = None

class BulkDownloadItem(BaseModel):
    id: str
    source: Literal["mp"] = "mp"
    success: bool
    text: str | None = None
    error: str | None = None

class BulkDownloadResult(BaseModel):
    items: list[BulkDownloadItem]
    total: int
    succeeded: int
    failed: int

# ========= Lifespan =========
@dataclass
class AppContext:
    mpr: MPRester
    api_key: str

@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    # Prefer the MP_API_KEY environment variable; fall back to the provided key string here [3]
    api_key = DEFAULT_MP_API_KEY
    mpr = MPRester(api_key)  # session management is handled inside mp-api [3]
    try:
        yield AppContext(mpr=mpr, api_key=api_key)
    finally:
        # mp-api does not require manual closing; clean up here if a close method is added in the future [3]
        pass

# ========= Server =========
mcp = FastMCP(
    name="Materials Project CIF",
    instructions="Search the Materials Project API and export CIF files.",
    lifespan=lifespan,
    streamable_http_path="/",  # makes it easy to mount under a sub-path in Starlette [1]
)

# ========= Tool implementations =========
def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
    if text is None:
        return not required and not excluded
    s = text.upper()
    if required:
        for el in required:
            if el.upper() not in s:
                return False
    if excluded:
        for el in excluded:
            if el.upper() in s:
                return False
    return True

def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
    if struct is None:
        return None
    lat = struct.lattice
    return LatticeParameters(
        a=float(lat.a),
        b=float(lat.b),
        c=float(lat.c),
        alpha=float(lat.alpha),
        beta=float(lat.beta),
        gamma=float(lat.gamma),
    )

def _sg_symbol_from_doc(doc: Any) -> str | None:
    """
    Try to extract the space group symbol from a document returned by materials.search.
    In MP documents, symmetry usually carries the spacegroup info (field paths may vary
    between versions) [4]
    """
    try:
        sym = getattr(doc, "symmetry", None)
        if sym and getattr(sym, "spacegroup", None):
            sg = sym.spacegroup
            # Common fields: symbol or international
            return getattr(sg, "symbol", None) or getattr(sg, "international", None)
    except Exception:
        pass
    return None

async def _to_thread(fn, *args, **kwargs):
    return await asyncio.to_thread(fn, *args, **kwargs)

def _parse_elements_from_formula(formula: str) -> list[str]:
    """Parse element symbols out of a chemical formula (simplistic version)."""
    if not formula:
        return []
    elems = re.findall(r"[A-Z][a-z]?", formula)
    seen, ordered = set(), []
    for e in elems:
        if e not in seen:
            seen.add(e)
            ordered.append(e)
    return ordered

@mcp.tool()
async def search_cifs(
    query: SearchQuery,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """
    First filter via summary (with a capped result count), then use materials to fetch
    the structures and symmetry for the current page.
    Fixes:
    - Request only the current page (chunk_size/per_page + num_chunks=1) to avoid
      downloading the whole database and timing out
    - Never include mp-api document objects or Structure in the return value; extras
      holds only lightweight, serializable fields
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr
    # Derive the element set from the formula automatically (if the user did not provide one explicitly).
    include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")
    # Require at least one constraint to avoid a full-database scan.
    if not include_elements and not query.exclude_elements and query.bandgap_min is None \
            and query.bandgap_max is None and query.e_above_hull_max is None:
        raise ValueError("Provide at least one filter, e.g. include_elements, bandgap, or an energy_above_hull cap, to avoid timeouts.")
    # 1) summary pre-filter: fetch only the current page of documents (light fields).
    summary_kwargs: dict[str, Any] = {}
    if include_elements:
        summary_kwargs["elements"] = include_elements  # MP supports searching by element set
    if query.exclude_elements:
        summary_kwargs["exclude_elements"] = query.exclude_elements
    if query.bandgap_min is not None or query.bandgap_max is not None:
        summary_kwargs["band_gap"] = (
            float(query.bandgap_min) if query.bandgap_min is not None else None,
            float(query.bandgap_max) if query.bandgap_max is not None else None,
        )
    if query.e_above_hull_max is not None:
        summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))
    fields_summary = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull",
        "ordering",
    ]
    # Key point: cap the size of the summary result set.
    try:
        summary_docs = await asyncio.to_thread(
            mpr.materials.summary.search,
            fields=fields_summary,
            chunk_size=query.per_page,  # size of each chunk (page)
            num_chunks=1,  # fetch a single chunk => exactly one page of results
            **summary_kwargs,
        )
    except TypeError:
        # If this mp-api version does not support these parameters, ask for an upgrade
        # instead of pulling back hundreds of thousands of rows and timing out.
        raise RuntimeError(
            "The current mp-api version does not support chunk_size/num_chunks; please upgrade mp-api: pip install -U mp_api"
        )
    if not summary_docs:
        await ctx.debug("summary search returned no results")
        return []
    # Local sorting (if requested).
    results_all: list[CIFCandidate] = []
    for sdoc in summary_docs:
        e_hull = getattr(sdoc, "energy_above_hull", None)
        bandgap = getattr(sdoc, "band_gap", None)
        ordering = getattr(sdoc, "ordering", None)
        # Do not put structure/doc into extras for now (avoids serialization errors).
        results_all.append(
            CIFCandidate(
                id=sdoc.material_id,
                source="mp",
                chemical_formula=getattr(sdoc, "formula_pretty", None),
                space_group=None,
                cell_parameters=None,
                bandgap=float(bandgap) if bandgap is not None else None,
                stability_energy=float(e_hull) if e_hull is not None else None,
                magnetic_ordering=ordering,
                has_cif=True,
                extras={
                    "summary": {
                        "material_id": sdoc.material_id,
                        "formula_pretty": getattr(sdoc, "formula_pretty", None),
                        "band_gap": float(bandgap) if bandgap is not None else None,
                        "energy_above_hull": float(e_hull) if e_hull is not None else None,
                        "ordering": ordering,
                    }
                },
            )
        )
    # Local paging: summary already capped the result count, so only sort and truncate defensively here.
    if query.sort_by:
        reverse = query.sort_order == "desc"
        key_map = {
            "bandgap": lambda c: (c.bandgap if c.bandgap is not None else float("inf")),
            "e_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
            "energy_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
        }
        if query.sort_by in key_map:
            results_all.sort(key=key_map[query.sort_by], reverse=reverse)
    page_slice = results_all[: query.per_page]
    page_ids = [c.id for c in page_slice]
    # 2) Fetch structures and symmetry only for the materials on the current page.
    if not page_ids:
        return []
    fields_materials = ["material_id", "structure", "symmetry"]
    materials_docs = await asyncio.to_thread(
        mpr.materials.search,
        material_ids=page_ids,
        fields=fields_materials,
    )
    mat_by_id = {doc.material_id: doc for doc in materials_docs}
    # Merge structure/space group; still return no non-serializable objects.
    merged: list[CIFCandidate] = []
    for c in page_slice:
        mdoc = mat_by_id.get(c.id)
        struct = getattr(mdoc, "structure", None) if mdoc else None
        # Space group symbol.
        sg = None
        try:
            sym = getattr(mdoc, "symmetry", None)
            if sym and getattr(sym, "spacegroup", None):
                sgobj = sym.spacegroup
                sg = getattr(sgobj, "symbol", None) or getattr(sgobj, "international", None)
        except Exception:
            pass
        c.space_group = sg
        c.cell_parameters = _lattice_from_structure(struct)
        c.has_cif = struct is not None
        # Keep extras as lightweight JSON.
        c.extras = {
            **(c.extras or {}),
            "materials": {"has_structure": struct is not None},
        }
        merged.append(c)
    await ctx.debug(f"Returning {len(merged)} items (capped at {query.per_page} per page)")
    return merged

@mcp.tool()
async def get_cif(
    id: str,
    ctx: Context[ServerSession, AppContext],
) -> CIFText:
    """
    Fetch the structure via materials.search and export it as CIF text.
    - fields=["structure"] for speed [4]
    - The structure is converted to CIF via pymatgen's Structure.to(fmt="cif")
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr
    await ctx.info(f"Fetching structure and exporting CIF: {id}")
    docs = await _to_thread(mpr.materials.search, material_ids=[id], fields=["material_id", "structure"])  # [4]
    if not docs:
        raise ValueError(f"Material not found: {id}")
    doc = docs[0]
    struct: Structure | None = getattr(doc, "structure", None)
    if struct is None:
        raise ValueError(f"Material {id} has no structure data; cannot export a CIF")
    cif_text = struct.to(fmt="cif")  # emits CIF text directly
    return CIFText(id=id, source="mp", text=cif_text, metadata={"from": "materials.search"})

@mcp.tool()
async def filter_cifs(
    candidates: list[CIFCandidate],
    criteria: FilterCriteria,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        if vmax is not None and v > vmax:
            return False
        return True

    filtered: list[CIFCandidate] = []
    for c in candidates:
        # Approximate element-set matching (based on the formula string).
        if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
            continue
        # Formula substring check.
        if criteria.formula:
            if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper():
                continue
        # Space group.
        if criteria.space_group:
            spg = (c.space_group or "").lower()
            if criteria.space_group.lower() not in spg:
                continue
        # Lattice ranges.
        cp = c.cell_parameters or LatticeParameters()
        if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max):
            continue
        if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max):
            continue
        if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max):
            continue
        # Band gap and stability.
        if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max):
            continue
        if criteria.e_above_hull_max is not None:
            if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max:
                continue
        # Magnetism.
        if criteria.magnetic_ordering:
            mo = ""  # not populated yet; kept for compatibility
            if criteria.magnetic_ordering.upper() not in mo.upper():
                continue
        filtered.append(c)
    await ctx.debug(f"After filtering: {len(filtered)} / original {len(candidates)}")
    return filtered

@mcp.tool()
async def download_cif_bulk(
    ids: list[str],
    concurrency: int = 4,
    ctx: Context[ServerSession, AppContext] | None = None,
) -> BulkDownloadResult:
    semaphore = asyncio.Semaphore(max(1, concurrency))
    items: list[BulkDownloadItem] = []

    async def worker(mid: str, idx: int) -> None:
        try:
            async with semaphore:
                # Reuse the current ctx.
                assert ctx is not None
                cif = await get_cif(id=mid, ctx=ctx)
                items.append(BulkDownloadItem(id=mid, success=True, text=cif.text))
        except Exception as e:
            items.append(BulkDownloadItem(id=mid, success=False, error=str(e)))
        if ctx is not None:
            progress = (idx + 1) / len(ids) if len(ids) > 0 else 1.0
            await ctx.report_progress(progress=progress, total=1.0, message=f"{idx + 1}/{len(ids)}")  # [1]

    tasks = [worker(mid, i) for i, mid in enumerate(ids)]
    await asyncio.gather(*tasks)
    succeeded = sum(1 for it in items if it.success)
    failed = len(items) - succeeded
    return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=failed)

# ========= Resource: cif://mp/{id} =========
@mcp.resource("cif://mp/{id}")
async def cif_resource(id: str) -> str:
    """
    URI-based CIF content access. Resource functions cannot access the Context directly,
    so this simply opens a fresh MPRester session, reads the structure, and converts it
    to CIF (resource-path use only).
    """
    # A one-off session inside the resource (does not reuse the lifespan); fine for occasional reads [1]
    mpr = MPRester(DEFAULT_MP_API_KEY)  # [3]
    try:
        docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"])  # [4]
        if not docs or getattr(docs[0], "structure", None) is None:
            raise ValueError(f"Material {id} has no structure data")
        struct: Structure = docs[0].structure
        return struct.to(fmt="cif")
    finally:
        pass

def create_materials_mcp() -> FastMCP:
    """Factory function for mounting under Starlette."""
    return mcp

if __name__ == "__main__":
    # Standalone mode (not mounted); uses the Streamable HTTP transport [1]
    mcp.run(transport="streamable-http")

172
mcp/paper_search_mcp.py Normal file

@@ -0,0 +1,172 @@
# paper_search_mcp.py (重构版)
import os
import asyncio
from dataclasses import dataclass
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
# 假设您的 search.py 提供了 PaperSearcher 类
from SearchPaperByEmbedding.search import PaperSearcher
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# ========= 1. 配置与常量 (修正版) =========
# 获取当前脚本文件所在的目录的绝对路径
_CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# 基于项目根目录构造资源文件的绝对路径
DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR,"SearchPaperByEmbedding")
PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json")
# 打印路径用于调试,确保它是正确的
print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}")
MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local")
# ========= 2. 数据模型 =========
class Paper(BaseModel):
"""单篇论文的详细信息"""
title: str
authors: list[str]
abstract: str
pdf_url: str | None = None
similarity_score: float = Field(description="与查询的相似度分数,越高越相关")
class SearchQuery(BaseModel):
"""论文搜索的查询结构"""
title: str
abstract: str | None = Field(None, description="可选的摘要,以提供更丰富的上下文")
# ========= 3. Lifespan (resource loading and management) =========
@dataclass
class PaperSearchContext:
    """The initialized PaperSearcher instance shared for the server's lifetime."""
    searcher: PaperSearcher
@asynccontextmanager
async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]:
    """
    FastMCP lifespan: initialize the PaperSearcher and compute embeddings at server startup.
    This ensures the most expensive work runs exactly once.
    """
    print(f"Initializing the paper search engine with the '{MODEL_TYPE}' model...")
    # Run the synchronous initialization in a worker thread
    def _initialize_searcher():
        # 1. Initialize
        searcher = PaperSearcher(
            PAPERS_JSON_PATH,
            model_type=MODEL_TYPE
        )
        # 2. Compute/load embeddings
        print("Computing or loading paper embeddings...")
        searcher.compute_embeddings()
        print("Embeddings loaded.")
        return searcher
    searcher = await asyncio.to_thread(_initialize_searcher)
    print("Paper search engine is ready!")
    try:
        # Yield the prepared shared resource context to MCP
        yield PaperSearchContext(searcher=searcher)
    finally:
        print("Paper search engine shut down.")
# ========= 4. MCP server instance =========
mcp = FastMCP(
    name="Paper Search",
    instructions="Tool for retrieving ICLR 2026 papers by embedding similarity.",
    lifespan=paper_search_lifespan,
    streamable_http_path="/",
)
# ========= 5. Tool implementation =========
@mcp.tool()
async def search_papers(
    queries: list[SearchQuery],
    top_k: int = 3,
    ctx: Context[ServerSession, PaperSearchContext] | None = None,
) -> list[Paper]:
    """
    Search for papers by semantic similarity, given one or more queries (each with a
    title and an optional abstract). Returns the top_k most relevant results.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
    if not queries:
        return []
    app_ctx = ctx.request_context.lifespan_context
    searcher = app_ctx.searcher
    # Preprocess queries so that at least one of title/abstract is present
    examples_to_search = []
    for q in queries:
        query_dict = q.model_dump(exclude_none=True)
        if "title" not in query_dict and "abstract" in query_dict:
            query_dict["title"] = query_dict["abstract"]
        elif "abstract" not in query_dict and "title" in query_dict:
            query_dict["abstract"] = query_dict["title"]
        examples_to_search.append(query_dict)
    if not examples_to_search:
        return []
    query_display = queries[0].title or queries[0].abstract or "Empty Query"
    await ctx.info(f"Searching for the top {top_k} papers with query '{query_display}'...")
    # Call the underlying search method
    results = await asyncio.to_thread(
        searcher.search,
        examples=examples_to_search,
        top_k=top_k
    )
    # --- Key fix: parse according to the structure actually returned by search.py ---
    papers = []
    for res in results:
        paper_data = res.get('paper', {})  # nested dict with the paper's details
        similarity = res.get('similarity', 0.0)  # similarity score
        papers.append(
            Paper(
                title=paper_data.get("title", "N/A"),
                authors=paper_data.get("authors", []),
                abstract=paper_data.get("abstract", ""),
                # Note: in the raw data, the PDF URL is stored under the 'forum_url' key
                pdf_url=paper_data.get("forum_url"),
                similarity_score=similarity  # use the correct similarity score
            )
        )
    await ctx.debug(f"Found {len(papers)} relevant papers.")
    return papers
# ========= 6. Factory function (for Starlette integration) =========
def create_paper_search_mcp() -> FastMCP:
    """Factory function for mounting under Starlette."""
    return mcp
if __name__ == "__main__":
    # When run directly, start a standalone MCP server
    mcp.run(transport="streamable-http")

464
mcp/softBV.py Normal file

@@ -0,0 +1,464 @@
# softbv_mcp.py
#
# Activate the softBV environment on a remote server and run calculations
# (two dedicated tools: --md and --gen-cube).
# - Lifespan: establish the SSH connection and inject it into the context
# - Environment activation: source /cluster/home/koko125/script/softBV.sh
# - Working directory: /cluster/home/koko125/sandbox
# - Executable: /cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x
#
# Dependencies:
#   pip install "mcp[cli]" asyncssh pydantic
#
# Usage (see your existing main.py for the Starlette mount; just import create_softbv_mcp):
#   from softbv_mcp import create_softbv_mcp
#   softbv_mcp = create_softbv_mcp()
#   Mount("/softbv", app=softbv_mcp.streamable_http_app())
#
# Connection settings can be overridden via environment variables:
#   REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH, SOFTBV_PROFILE, SOFTBV_BIN, DEFAULT_WORKDIR
import os
import posixpath
import asyncio
from dataclasses import dataclass
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import asyncssh
from pydantic import BaseModel, Field
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
from lifespan import REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH
def shell_quote(arg: str) -> str:
    """Safely pass a string as a single shell argument (POSIX)."""
    return "'" + str(arg).replace("'", "'\"'\"'") + "'"
# If these constants and paths are already defined elsewhere, keep reusing them; adjust the config source as needed
# Fixed softBV environment info (overridable via environment variables)
SOFTBV_PROFILE = os.getenv("SOFTBV_PROFILE", "/cluster/home/koko125/script/softBV.sh")
SOFTBV_BIN = os.getenv("SOFTBV_BIN", "/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x")
DEFAULT_WORKDIR = os.getenv("DEFAULT_WORKDIR", "/cluster/home/koko125/sandbox")
@dataclass
class SoftBVContext:
ssh_connection: asyncssh.SSHClientConnection
workdir: str
profile: str
bin_path: str
@asynccontextmanager
async def softbv_lifespan(_server) -> AsyncIterator[SoftBVContext]:
    """
    FastMCP lifespan: establish the SSH connection and inject the softBV context.
    - No extra DNS resolution or custom async steps, avoiding environment overrides of socket.SOCK_STREAM
    - Only handles connection and cleanup; tools access this context via ctx.request_context.lifespan_context [1]
    """
    # Allow overriding connection info via environment variables
    host = os.getenv("REMOTE_HOST", REMOTE_HOST)
    user = os.getenv("REMOTE_USER", REMOTE_USER)
    key_path = os.getenv("PRIVATE_KEY_PATH", PRIVATE_KEY_PATH)
    conn: asyncssh.SSHClientConnection | None = None
    try:
        conn = await asyncssh.connect(
            host,
            username=user,
            client_keys=[key_path],
            known_hosts=None,  # remove this argument to enable host key verification
            connect_timeout=15,  # avoid hanging indefinitely
        )
        yield SoftBVContext(
            ssh_connection=conn,
            workdir=DEFAULT_WORKDIR,
            profile=SOFTBV_PROFILE,
            bin_path=SOFTBV_BIN,
        )
    finally:
        if conn:
            conn.close()
            await conn.wait_closed()
async def run_in_softbv_env(
    conn: asyncssh.SSHClientConnection,
    profile_path: str,
    cmd: str,
    cwd: str | None = None,
    check: bool = True,
) -> asyncssh.SSHCompletedProcess:
    """
    Activate the softBV environment in a remote bash session and run cmd.
    If cwd is given, cd into it first.
    """
    parts = []
    if cwd:
        parts.append(f"cd {shell_quote(cwd)}")
    parts.append(f"source {shell_quote(profile_path)}")
    parts.append(cmd)
    composite = "; ".join(parts)
    full = f"bash -lc {shell_quote(composite)}"
    return await conn.run(full, check=check)
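# Illustration (a sketch): run_in_softbv_env(conn, profile, "softBV.x --version", cwd="/work")
# sends one remote command roughly equivalent to
#   bash -lc "cd /work; source <profile>; softBV.x --version"
# with every path routed through shell_quote, so spaces and quotes survive composition.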
# ===== MCP server =====
mcp = FastMCP(
    name="softBV Tools",
    instructions="Tool set for activating the softBV environment on a remote server and running related calculations.",
    lifespan=softbv_lifespan,
    streamable_http_path="/",
)
# ===== Helper: list a directory to detect newly generated files =====
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
    async with conn.start_sftp_client() as sftp:
        try:
            return await sftp.listdir(path)
        except Exception:
            return []
# ===== Structured input model: --md =====
class SoftBVMDArgs(BaseModel):
    input_cif: str = Field(description="Remote CIF file path (relative or absolute), the input for --md")
    # Positional arguments in the order given by the program's help text (None means omit and use the default)
    type: str | None = Field(default=None, description="Conducting ion type (e.g. 'Li')")
    os: int | None = Field(default=None, description="Conducting ion oxidation state (integer)")
    sf: float | None = Field(default=None, description="Screening factor (non-positive uses the default)")
    temperature: float | None = Field(default=None, description="Temperature in K (non-positive defaults to 300)")
    t_end: float | None = Field(default=None, description="Production time in ps (non-positive defaults to 10.0)")
    t_equil: float | None = Field(default=None, description="Equilibration time in ps (non-positive defaults to 2.0)")
    dt: float | None = Field(default=None, description="Timestep in ps (non-positive defaults to 0.001)")
    t_log: float | None = Field(default=None, description="Sampling interval in ps (non-positive: every 100 steps)")
    cwd: str | None = Field(default=None, description="Remote working directory (defaults to the lifespan workdir)")
def _build_md_cmd(bin_path: str, args: SoftBVMDArgs, workdir: str) -> str:
    input_abs = args.input_cif
    if not input_abs.startswith("/"):
        input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
    parts: list[str] = [shell_quote(bin_path), shell_quote("--md"), shell_quote(input_abs)]
    # Caution: these are positional arguments, so a None value is simply skipped;
    # to set a later argument you must also supply every argument before it.
    for val in [args.type, args.os, args.sf, args.temperature, args.t_end, args.t_equil, args.dt, args.t_log]:
        if val is not None:
            parts.append(shell_quote(str(val)))
    return " ".join(parts)
# ===== Structured input model: --gen-cube =====
class SoftBVGenCubeArgs(BaseModel):
    input_cif: str = Field(description="Remote CIF file path (relative or absolute), the input for --gen-cube")
    type: str | None = Field(default=None, description="Conducting ion type (e.g. 'Li')")
    os: int | None = Field(default=None, description="Conducting ion oxidation state (integer)")
    sf: float | None = Field(default=None, description="Screening factor (non-positive uses the default)")
    resolution: float | None = Field(default=None, description="Voxel resolution (default about 0.1)")
    ignore_conducting_ion: bool = Field(default=False, description="flag:ignore_conducting_ion")
    periodic: bool = Field(default=True, description="flag:periodic (default True)")
    output_name: str | None = Field(default=None, description="Output file name prefix (optional)")
    cwd: str | None = Field(default=None, description="Remote working directory (defaults to the lifespan workdir)")
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str) -> str:
input_abs = args.input_cif
if not input_abs.startswith("/"):
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
for val in [args.type, args.os, args.sf, args.resolution]:
if val is not None:
parts.append(shell_quote(str(val)))
if args.ignore_conducting_ion:
parts.append(shell_quote("--flag:ignore_conducting_ion"))
if args.periodic:
parts.append(shell_quote("--flag:periodic"))
if args.output_name:
parts.append(shell_quote(args.output_name))
return " ".join(parts)
# ===== Tool: environment info check (fixed version, avoids timeouts) =====
@mcp.tool()
async def softbv_info(ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Quick self-check:
    - Use SFTP to verify that the working directory, activation script, and binary exist / are executable (without running softBV.x)
    - After activating the environment, print only a marker and the current working directory, avoiding long output or blocking
    """
    app = ctx.request_context.lifespan_context
    conn = app.ssh_connection
    # 1) Fast SFTP checks of file/directory status (won't block for long)
    def stat_path_safe(path: str) -> dict[str, Any]:
        # default record; the SFTP checks below fill it in
        return {"exists": False, "is_exec": False, "size": None}
    workdir_info = stat_path_safe(app.workdir)
    profile_info = stat_path_safe(app.profile)
    bin_info = stat_path_safe(app.bin_path)
try:
async with conn.start_sftp_client() as sftp:
# workdir
try:
attrs = await sftp.stat(app.workdir)
workdir_info["exists"] = True
workdir_info["size"] = int(attrs.size or 0)
except Exception:
pass
# profile
try:
attrs = await sftp.stat(app.profile)
profile_info["exists"] = True
profile_info["size"] = int(attrs.size or 0)
perms = int(attrs.permissions or 0)
profile_info["is_exec"] = bool(perms & 0o111)
except Exception:
pass
# bin
try:
attrs = await sftp.stat(app.bin_path)
bin_info["exists"] = True
bin_info["size"] = int(attrs.size or 0)
perms = int(attrs.permissions or 0)
bin_info["is_exec"] = bool(perms & 0o111)
except Exception:
pass
except Exception as e:
await ctx.warning(f"SFTP 检查失败: {e}")
# 2) 激活环境并做极简命令(避免 softBV.x --help 的长输出)
# 仅返回当前用户、PWD 与二进制可执行判断;不实际运行 softBV.x
cmd = "echo __SOFTBV_READY__ && echo $USER && pwd && (test -x " + shell_quote(app.bin_path) + " && echo __BIN_OK__ || echo __BIN_NOT_EXEC__)"
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=app.workdir, check=False)
    # Parse the output lines
lines = proc.stdout.splitlines() if proc.stdout else []
ready = "__SOFTBV_READY__" in lines
user = None
pwd = None
bin_ok = "__BIN_OK__" in lines
    # Locate user/pwd (the two lines following the ready marker)
if ready:
idx = lines.index("__SOFTBV_READY__")
if len(lines) > idx + 1:
user = lines[idx + 1].strip()
if len(lines) > idx + 2:
pwd = lines[idx + 2].strip()
result = {
"host": os.getenv("REMOTE_HOST", REMOTE_HOST),
"user": os.getenv("REMOTE_USER", REMOTE_USER),
"workdir": app.workdir,
"profile": app.profile,
"bin_path": app.bin_path,
"sftp_check": {
"workdir": workdir_info,
"profile": profile_info,
"bin": bin_info,
},
"activate_ready": ready,
"pwd": pwd,
"bin_is_executable": bin_ok or bin_info["is_exec"],
"exit_status": proc.exit_status,
"stderr_head": "\n".join(proc.stderr.splitlines()[:10]) if proc.stderr else "",
}
    # Friendly log messages
    if not ready:
        await ctx.warning("softBV environment not ready (possibly a bad source script path or insufficient permissions)")
    if not result["bin_is_executable"]:
        await ctx.warning("softBV binary missing or not executable; check bin_path and permissions (chmod +x)")
    if proc.exit_status != 0 and not proc.stderr:
        await ctx.debug("Command exited non-zero with no stderr (a sub-test may have returned non-zero)")
return result
# ===== Tool: --md =====
@mcp.tool()
async def softbv_md(req: SoftBVMDArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run softBV.x --md and return a structured result:
    - cmd/cwd/exit_status/stdout/stderr
    - new_files: files created during the run, to help locate the output
    """
app = ctx.request_context.lifespan_context
conn = app.ssh_connection
workdir = req.cwd or app.workdir
cmd = _build_md_cmd(app.bin_path, req, workdir)
await ctx.info(f"softBV --md 执行: {cmd} (cwd={workdir})")
pre_list = await _listdir(conn, workdir)
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=workdir, check=False)
post_list = await _listdir(conn, workdir)
new_files = sorted(set(post_list) - set(pre_list))
    if proc.exit_status == 0:
        await ctx.debug(f"--md succeeded, {len(new_files)} new file(s)")
    else:
        await ctx.warning(f"--md exited with non-zero status: {proc.exit_status}")
return {
"cmd": cmd,
"cwd": workdir,
"exit_status": proc.exit_status,
"stdout": proc.stdout,
"stderr": proc.stderr,
"new_files": new_files,
}
async def run_in_softbv_env_stream(
conn: asyncssh.SSHClientConnection,
profile_path: str,
cmd: str,
cwd: str | None = None,
) -> asyncssh.SSHClientProcess:
parts = []
if cwd:
parts.append(f"cd {shell_quote(cwd)}")
parts.append(f"source {shell_quote(profile_path)} >/dev/null 2>&1 || true")
parts.append(cmd)
composite = "; ".join(parts)
full = f"bash -lc {shell_quote(composite)}"
    # Do not block; return the process handle
proc = await conn.create_process(full)
return proc
# Directory polling (detecting newly generated files) reuses the _listdir helper defined above.
# Poll the log file size (serves as a heartbeat / rough progress signal)
async def _stat_size(conn: asyncssh.SSHClientConnection, path: str) -> int | None:
async with conn.start_sftp_client() as sftp:
try:
attrs = await sftp.stat(path)
return int(attrs.size or 0)
except Exception:
return None
# Build the gen-cube command (same argument-joining logic as above). This version adds an
# optional log_path and, being defined later, supersedes the earlier _build_gen_cube_cmd.
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str, log_path: str | None = None) -> str:
input_abs = args.input_cif
if not input_abs.startswith("/"):
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
for val in [args.type, args.os, args.sf, args.resolution]:
if val is not None:
parts.append(shell_quote(str(val)))
if args.ignore_conducting_ion:
parts.append(shell_quote("--flag:ignore_conducting_ion"))
if args.periodic:
parts.append(shell_quote("--flag:periodic"))
if args.output_name:
parts.append(shell_quote(args.output_name))
cmd = " ".join(parts)
    # Redirect output to the log file so it can be polled
if log_path:
cmd = f"{cmd} > {shell_quote(log_path)} 2>&1"
return cmd
# Fixed version: long-running softbv_gen_cube (with heartbeats and a timeout guard)
@mcp.tool()
async def softbv_gen_cube(req: SoftBVGenCubeArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run softBV.x --gen-cube with heartbeats for long jobs, so clients that enforce a
    <25 min limit do not force a timeout.
    - Every 10 s, report progress (a heartbeat) with elapsed time / log size / new file count
    - On completion, return stdout_head (a slice of the log file), stderr_head (if any),
      exit_status, and the newly created files
    """
app = ctx.request_context.lifespan_context
conn = app.ssh_connection
workdir = req.cwd or app.workdir
    # Record the directory contents up front, for a set difference afterwards
    before = await _listdir(conn, workdir)
    # Remote log file path (timestamped name)
    import time
    log_name = f"softbv_gencube_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)
    # Start the long job without blocking this coroutine
    cmd = _build_gen_cube_cmd(app.bin_path, req, workdir, log_path=log_path)
    await ctx.info(f"Starting --gen-cube: {cmd}")
    proc = await run_in_softbv_env_stream(conn, app.profile, cmd=cmd, cwd=workdir)
    # Heartbeat loop: runs until the process exits
    start_ts = time.time()
    heartbeat_sec = 10  # send a heartbeat every 10 seconds
    max_guard_min = 60  # safety ceiling (the server never terminates the job; raise this if the client allows longer runs)
    try:
        while True:
            # Has the process exited?
            if proc.exit_status is not None:
                break
            # Collect status: elapsed time, log size, count of new files
            elapsed = time.time() - start_ts
            log_size = await _stat_size(conn, log_path)
            now_files = await _listdir(conn, workdir)
            new_files_count = len(set(now_files) - set(before))
            # The true percentage is unknown, so report a heartbeat with the elapsed time
            await ctx.report_progress(
                progress=min(elapsed / (25 * 60), 0.99),  # approximate scale with 25 min as the nominal ceiling
                total=1.0,
                message=f"gen-cube running: elapsed {int(elapsed)}s, log {log_size or 0}B, new files {new_files_count}",
            )
            # Steady heartbeats are enough to keep the client from timing out. [1]
            await asyncio.sleep(heartbeat_sec)
            # Soft guard: past max_guard_min we only warn and never interrupt (the remote job decides)
            if elapsed > max_guard_min * 60:
                await ctx.warning("The job has exceeded the guard ceiling but is still running (not interrupted). Raise the ceiling if more time is needed.")
                # no break: keep waiting and let the remote job finish
    finally:
        # Wait for the process to actually finish (returns immediately if it already has)
        await proc.wait()
    # Collect the results after completion
    exit_status = proc.exit_status
    after = await _listdir(conn, workdir)
    new_files = sorted(set(after) - set(before))
    # Read head/tail slices of the log to help locate the output
async with conn.start_sftp_client() as sftp:
head = ""
tail = ""
try:
async with sftp.open(log_path, "rb") as f:
content = await f.read()
text = content.decode("utf-8", errors="replace")
lines = text.splitlines()
head = "\n".join(lines[:40])
tail = "\n".join(lines[-40:])
except Exception:
pass
    # Structured result
    result = {
        "cmd": cmd,
        "cwd": workdir,
        "exit_status": exit_status,
        "log_file": log_path,
        "stdout_head": head,  # a log slice instead of one-shot stdout, avoiding huge payloads
        "stderr_head": "",  # stdout and stderr are merged into the log file, so this stays empty
"new_files": new_files,
"elapsed_sec": int(time.time() - start_ts),
}
    if exit_status == 0:
        await ctx.info(f"gen-cube finished in {result['elapsed_sec']}s with {len(new_files)} new file(s)")
    else:
        await ctx.warning(f"gen-cube exited with status {exit_status}; see the log at {log_path}")
    return result
def create_softbv_mcp() -> FastMCP:
    """Factory function for external (Starlette) imports."""
    return mcp
if __name__ == "__main__":
mcp.run(transport="streamable-http")

1659
mcp/softBV_remake.py Normal file

File diff suppressed because it is too large

407
mcp/system_tools.py Normal file

@@ -0,0 +1,407 @@
# system_tools.py
import posixpath
import stat as pystat
from typing import Any
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
import asyncssh
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# Added at the top of system_tools.py
def shell_quote(arg: str) -> str:
    """
    Safely pass a string as a single shell argument:
    - Wrap the value in single quotes
    - Replace each inner single quote ' with the '"'"' sequence
    Intended for composing commands for a remote Linux shell
    """
    return "'" + arg.replace("'", "'\"'\"'") + "'"
# Reuse the shared configuration and data classes
from lifespan import (
SharedAppContext,
REMOTE_HOST,
REMOTE_USER,
PRIVATE_KEY_PATH,
INITIAL_WORKING_DIRECTORY,
)
# —— 1) Define the FastMCP lifespan: open the SSH connection at startup, close it on shutdown ——
@asynccontextmanager
async def system_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    FastMCP lifespan: create and inject the shared SSH connection and sandbox root.
    Note: this is MCP's own lifespan; tools access it via ctx.request_context.lifespan_context.
    """
    conn: asyncssh.SSHClientConnection | None = None
    try:
        # Establish the SSH connection
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        # Inject the type-safe shared context into the MCP lifespan
        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        # Close the connection
        if conn:
            conn.close()
            await conn.wait_closed()
def create_system_mcp() -> FastMCP:
    """Create an MCP instance containing the system operation tools."""
    system_mcp = FastMCP(
        name="System Tools",
        instructions="Tool set for basic file and directory operations on a remote server.",
        streamable_http_path="/",
        lifespan=system_lifespan,  # key: hand the lifespan to FastMCP [1]
    )
    def _safe_join(sandbox_root: str, relative_path: str) -> str:
        """
        Map a user-supplied relative path to a normalized absolute path inside the sandbox root.
        - Uses POSIX semantics throughout (remote Linux)
        - Absolute paths starting with '/' are not allowed
        - '..' escapes are rejected, guaranteeing the result stays inside the sandbox
        """
        rel = (relative_path or ".").strip()
        # Forbid absolute paths by converting them to relative
        if rel.startswith("/"):
            rel = rel.lstrip("/")
        # Normalize the joined path
        combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
        # Normalize the trailing slash to avoid boundary misjudgments
        root_norm = sandbox_root.rstrip("/")
        # Ensure the result is still inside the sandbox
        if combined != root_norm and not combined.startswith(root_norm + "/"):
            raise ValueError("Path escape: only paths inside the sandbox directory are allowed")
        # Additionally reject any '..' component left in the path
        parts = [p for p in combined.split("/") if p]
        if ".." in parts:
            raise ValueError("Illegal path: '..' traversal is not allowed")
        return combined
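    # Illustration (a sketch, with sandbox_root="/cluster/home/koko125/sandbox"):
    #   _safe_join(root, "data/in.cif")  -> "/cluster/home/koko125/sandbox/data/in.cif"
    #   _safe_join(root, "/etc/passwd")  -> remapped inside the sandbox as ".../sandbox/etc/passwd"
    #   _safe_join(root, "../secret")    -> raises ValueError (path escape)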
    # —— Refactored tools: consistently use the typed ctx and app_ctx to access shared resources ——
    @system_mcp.tool()
    async def list_files(ctx: Context[ServerSession, SharedAppContext], path: str = ".") -> list[dict[str, Any]]:
        """
        List files and subdirectories in the remote sandbox directory (structured output).
        Returns:
            list[dict]: [{name, path, is_dir, size, permissions, mtime}]
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, path)
            await ctx.info(f"Listing directory: {target}")
            items: list[dict[str, Any]] = []
            async with conn.start_sftp_client() as sftp:
                names = await sftp.listdir(target)  # list[str]
                for name in names:
                    item_abs = posixpath.join(target, name)
                    attrs = None
                    try:
                        attrs = await sftp.stat(item_abs)
                    except Exception:
                        pass
                    perms = int(attrs.permissions or 0) if attrs else 0
                    is_dir = bool(pystat.S_ISDIR(perms))
                    size = int(attrs.size or 0) if attrs and attrs.size is not None else 0
                    mtime = int(attrs.mtime or 0) if attrs and attrs.mtime is not None else 0
                    items.append({
                        "name": name,
                        "path": posixpath.join(path.rstrip("/"), name) if path != "/" else name,
                        "is_dir": is_dir,
                        "size": size,
                        "permissions": perms,
                        "mtime": mtime,
                    })
            await ctx.debug(f"Number of entries: {len(items)}")
            return items
        except FileNotFoundError:
            msg = f"Directory does not exist or is not accessible: {path}"
            await ctx.error(msg)
            return [{"error": msg}]
        except Exception as e:
            msg = f"list_files failed: {e}"
            await ctx.error(msg)
            return [{"error": msg}]
    @system_mcp.tool()
    async def read_file(ctx: Context[ServerSession, SharedAppContext], file_path: str, encoding: str = "utf-8") -> str:
        """
        Read the contents of a file inside the remote sandbox.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, file_path)
            await ctx.info(f"Reading file: {target}")
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target, "rb") as f:
                    data = await f.read()
            try:
                content = data.decode(encoding)
            except Exception:
                content = data.decode(encoding, errors="replace")
            return content
        except FileNotFoundError:
            msg = f"Read failed: file '{file_path}' does not exist"
            await ctx.error(msg)
            return msg
        except Exception as e:
            msg = f"read_file failed: {e}"
            await ctx.error(msg)
            return msg
    @system_mcp.tool()
    async def write_file(
        ctx: Context[ServerSession, SharedAppContext],
        file_path: str,
        content: str,
        encoding: str = "utf-8",
        create_parents: bool = True,
    ) -> dict[str, Any]:
        """
        Write a file inside the remote sandbox (creates parent directories by default).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context  # typed SharedAppContext
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, file_path)
            await ctx.info(f"Writing file: {target}")
            if create_parents:
                parent = posixpath.dirname(target)
                if parent and parent != sandbox_root:
                    await conn.run(f"mkdir -p {shell_quote(parent)}", check=True)
            data = content.encode(encoding)
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target, "wb") as f:
                    await f.write(data)
            await ctx.debug(f"Write complete: {len(data)} bytes")
            return {"path": file_path, "bytes_written": len(data)}
        except Exception as e:
            msg = f"write_file failed: {e}"
            await ctx.error(msg)
            return {"error": msg}
    @system_mcp.tool()
    async def make_dir(ctx: Context[ServerSession, SharedAppContext], path: str, parents: bool = True) -> str:
        """
        Create a directory inside the remote sandbox.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, path)
            if parents:
                # One-liner: mkdir -p
                await conn.run(f"mkdir -p {shell_quote(target)}", check=True)
            else:
                async with conn.start_sftp_client() as sftp:
                    await sftp.mkdir(target)
            await ctx.info(f"Directory created: {target}")
            return f"Directory created: {path}"
        except Exception as e:
            await ctx.error(f"Failed to create directory: {e}")
            return f"Error: {e}"
    @system_mcp.tool()
    async def delete_file(ctx: Context[ServerSession, SharedAppContext], file_path: str) -> str:
        """
        Delete a file (not a directory).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, file_path)
            async with conn.start_sftp_client() as sftp:
                await sftp.remove(target)
            await ctx.info(f"File deleted: {target}")
            return f"File deleted: {file_path}"
        except FileNotFoundError:
            msg = f"Delete failed: file '{file_path}' does not exist"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Failed to delete file: {e}")
            return f"Error: {e}"
    @system_mcp.tool()
    async def delete_dir(
        ctx: Context[ServerSession, SharedAppContext],
        dir_path: str,
        recursive: bool = False,
    ) -> str:
        """
        Delete a directory.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context  # typed SharedAppContext
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, dir_path)
            if recursive:
                await conn.run(f"rm -rf {shell_quote(target)}", check=True)
            else:
                async with conn.start_sftp_client() as sftp:
                    await sftp.rmdir(target)
            await ctx.info(f"Directory deleted: {target}")
            return f"Directory deleted: {dir_path}"
        except FileNotFoundError:
            msg = f"Delete failed: directory '{dir_path}' does not exist or is not empty"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Failed to delete directory: {e}")
            return f"Error: {e}"
    @system_mcp.tool()
    async def move_path(ctx: Context[ServerSession, SharedAppContext], src: str, dst: str,
                        overwrite: bool = True) -> str:
        """
        Move or rename a path inside the remote sandbox.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            src_abs = _safe_join(sandbox_root, src)
            dst_abs = _safe_join(sandbox_root, dst)
            flag = "-f" if overwrite else ""
            # One-liner: mv
            cmd = f"mv {flag} {shell_quote(src_abs)} {shell_quote(dst_abs)}".strip()
            await conn.run(cmd, check=True)
            await ctx.info(f"Moved: {src_abs} -> {dst_abs}")
            return f"Moved: {src} -> {dst}"
        except FileNotFoundError:
            msg = f"Move failed: source '{src}' does not exist"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Move failed: {e}")
            return f"Error: {e}"
    @system_mcp.tool()
    async def copy_path(
        ctx: Context[ServerSession, SharedAppContext],
        src: str,
        dst: str,
        recursive: bool = True,
        overwrite: bool = True,
    ) -> str:
        """
        Copy a file or directory inside the remote sandbox.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            src_abs = _safe_join(sandbox_root, src)
            dst_abs = _safe_join(sandbox_root, dst)
            flags = []
            if recursive:
                flags.append("-r")
            if overwrite:
                flags.append("-f")
            # One-liner: cp
            cmd = " ".join(["cp"] + flags + [shell_quote(src_abs), shell_quote(dst_abs)])
            await conn.run(cmd, check=True)
            await ctx.info(f"Copied: {src_abs} -> {dst_abs}")
            return f"Copied: {src} -> {dst}"
        except FileNotFoundError:
            msg = f"Copy failed: source '{src}' does not exist"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Copy failed: {e}")
            return f"Error: {e}"
    @system_mcp.tool()
    async def exists(ctx: Context[ServerSession, SharedAppContext], path: str) -> bool:
        """
        Check whether a path (file or directory) exists.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, path)
            async with conn.start_sftp_client() as sftp:
                await sftp.stat(target)
            return True
        except FileNotFoundError:
            return False
        except Exception as e:
            await ctx.error(f"exists check failed: {e}")
            return False
    @system_mcp.tool()
    async def stat_path(ctx: Context[ServerSession, SharedAppContext], path: str) -> dict:
        """
        Inspect the attributes of a remote path (structured output).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target = _safe_join(sandbox_root, path)
            async with conn.start_sftp_client() as sftp:
                attrs = await sftp.stat(target)
            perms = attrs.permissions or 0
            return {
                "path": target,
                "size": int(attrs.size or 0),
                "is_dir": bool(pystat.S_ISDIR(perms)),
                "permissions": perms,  # octal permission bits, e.g. 0o755
                "mtime": int(attrs.mtime or 0),  # seconds
            }
        except FileNotFoundError:
            msg = f"Path does not exist: {path}"
            await ctx.error(msg)
            return {"error": msg}
        except Exception as e:
            await ctx.error(f"stat failed: {e}")
            return {"error": str(e)}
    return system_mcp
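# Mounting sketch (assumed Starlette layout, mirroring the softBV module's header example):
#   from starlette.routing import Mount
#   system_mcp = create_system_mcp()
#   Mount("/system", app=system_mcp.streamable_http_app())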

19
mcp/test_tools.py Normal file

@@ -0,0 +1,19 @@
# test_tools.py
from mcp.server.fastmcp import FastMCP
def create_test_mcp() -> FastMCP:
    """Create an MCP instance containing only the simplest tool, for testing connectivity."""
    test_mcp = FastMCP(
        name="Test Tools",
        instructions="For checking that the server connection is working.",
        streamable_http_path="/"
    )
    @test_mcp.tool()
    def ping() -> str:
        """A trivial tool that confirms the server responds."""
        return "pong"
    return test_mcp

185
mcp/topological_analysis_tools.py Normal file

@@ -0,0 +1,185 @@
# topological_analysis_tools.py
import posixpath
import re
from typing import Any, Literal
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
import asyncssh
from pydantic import BaseModel, Field
# Import the core components from the MCP SDK [1]
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
# Import the shared configuration and data classes from lifespan.py [3]
# Note: system_lifespan no longer needs to be imported
from lifespan import (
SharedAppContext,
REMOTE_HOST,
REMOTE_USER,
PRIVATE_KEY_PATH,
INITIAL_WORKING_DIRECTORY,
)
# ==============================================================================
# 1. Helper functions (duplicated here to keep this module self-contained)
# ==============================================================================
def shell_quote(arg: str) -> str:
    """Safely escape a string as a single shell argument [5]."""
    return "'" + arg.replace("'", "'\"'\"'") + "'"
def _safe_join(sandbox_root: str, relative_path: str) -> str:
    """Safely join a user-supplied relative path onto the sandbox root [5]."""
    rel = (relative_path or ".").strip()
    if rel.startswith("/"):
        rel = rel.lstrip("/")
    combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
    root_norm = sandbox_root.rstrip("/")
    if combined != root_norm and not combined.startswith(root_norm + "/"):
        raise ValueError("Path escape: only paths inside the sandbox directory are allowed")
    if ".." in combined.split("/"):
        raise ValueError("Illegal path: '..' traversal is not allowed")
    return combined
# ==============================================================================
# 2. Standalone lifespan manager
# ==============================================================================
@asynccontextmanager
async def analysis_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    Manage the SSH connection lifecycle independently for the topological analysis tools.
    """
    conn: asyncssh.SSHClientConnection | None = None
    print("Analysis tools: establishing a dedicated SSH connection...")
    try:
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        print(f"Analysis tools: dedicated SSH connection to {REMOTE_HOST} succeeded!")
        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        if conn:
            conn.close()
            await conn.wait_closed()
        print("Analysis tools: dedicated SSH connection closed.")
# ==============================================================================
# 3. Data models and parsing functions
# ==============================================================================
VALID_CONFIGS = Literal["O.yaml", "S.yaml", "Cl.yaml", "Br.yaml"]
class SubmitTopologicalAnalysisParams(BaseModel):
    cif_file_path: str = Field(description="Relative path of the CIF file on the remote server.")
    config_files: list[VALID_CONFIGS] = Field(description="One or more YAML configuration files to use for the analysis.")
    output_file_path: str = Field(description="Relative path of the output file where results are saved.")
class AnalysisResult(BaseModel):
    percolation_diameter_a: float | None = Field(None, description="Percolation diameter, in Å.")
    connectivity_distance: float | None = Field(None, description="Connectivity distance ('the minium of d' in the log), in Å.")
    max_node_length_a: float | None = Field(None, description="Maximum node length, in Å.")
    min_cation_distance_a: float | None = Field(None, description="Minimum distance to cations, in Å.")
    total_time_s: float | None = Field(None, description="Total computation time, in seconds.")
    long_node_count: int | None = Field(None, description="Long node number.")
    short_node_count: int | None = Field(None, description="Short node number.")
    warnings: list[str] = Field(default_factory=list, description="Warnings from parsing, or information about missing key values.")
def _parse_analysis_output(log_content: str) -> AnalysisResult:
    def find_float(pattern: str) -> float | None:
        match = re.search(pattern, log_content)
        return float(match.group(1)) if match else None
    def find_int(pattern: str) -> int | None:
        match = re.search(pattern, log_content)
        return int(match.group(1)) if match else None
    data = {
        "percolation_diameter_a": find_float(r"Percolation diameter \(A\):\s*([\d\.]+)"),
        "max_node_length_a": find_float(r"Maximum node length detected:\s*([\d\.]+)\s*A"),
        "min_cation_distance_a": find_float(r"minimum distance to cations is\s*([\d\.]+)\s*A"),
        "total_time_s": find_float(r"Total used time:\s*([\d\.]+)"),
        "long_node_count": find_int(r"Long node number:\s*(\d+)"),
        "short_node_count": find_int(r"Short node number:\s*(\d+)"),
    }
    # Note: "the minium of d" is the literal (misspelled) label printed by the analysis tool
    conn_dist_match = re.search(r"the minium of d\s*([\d\.]+)", log_content, re.MULTILINE)
    data["connectivity_distance"] = float(conn_dist_match.group(1)) if conn_dist_match else None
    warnings = [k + " not found" for k, v in data.items() if v is None and k not in ["min_cation_distance_a", "long_node_count", "short_node_count"]]
    return AnalysisResult(**data, warnings=warnings)
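# Illustration (a sketch over made-up log lines that match the regexes above):
#   _parse_analysis_output("Percolation diameter (A): 1.84\nTotal used time: 12.5")
# returns AnalysisResult(percolation_diameter_a=1.84, total_time_s=12.5, ...), with the
# other tracked keys reported as "... not found" in `warnings`.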
# ==============================================================================
# 4. MCP server factory function
# ==============================================================================
def create_topological_analysis_mcp() -> FastMCP:
    """
    Factory function for mounting under Starlette.
    It now creates and uses its own standalone lifespan manager.
    """
    analysis_mcp = FastMCP(
        name="Topological Analysis Tools",
        instructions="Tool set for submitting and analyzing topological calculations on a remote server.",
        streamable_http_path="/",
        lifespan=analysis_lifespan,  # key: use the standalone lifespan defined in this file
    )
    @analysis_mcp.tool()
    async def submit_topological_analysis(
        params: SubmitTopologicalAnalysisParams,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> dict[str, Any]:
        """[Step 1/2] Submit a topological analysis job asynchronously."""
        # ... (tool implementation identical to the previous version)
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            home_dir = f"/cluster/home/{REMOTE_USER}"
            tool_dir = f"{home_dir}/tool/Topological_analyse"
            cif_abs_path = _safe_join(sandbox_root, params.cif_file_path)
            output_abs_path = _safe_join(sandbox_root, params.output_file_path)
            config_args = " ".join([shell_quote(f"{tool_dir}/{cfg}") for cfg in params.config_files])
            config_flags = f"-i {config_args}"
            # Note: the shell_quote'd paths are spliced into an already single-quoted `sh -c`
            # body; this works because these paths contain no spaces, but the quoting is fragile.
            command = f"""
            nohup sh -c '
            source {shell_quote(f"{home_dir}/.bashrc")} && \\
            conda activate {shell_quote(f"{home_dir}/anaconda3/envs/zeo")} && \\
            python {shell_quote(f"{tool_dir}/analyze_voronoi_nodes.py")} \\
            {shell_quote(cif_abs_path)} \\
            {config_flags}
            ' > {shell_quote(output_abs_path)} 2>&1 &
            """.strip()
            await ctx.info("Submitting the background topological analysis job...")
            await conn.run(command, check=True)
            return {"success": True, "message": "Topological analysis job submitted successfully.", "output_file": params.output_file_path}
        except Exception as e:
            msg = f"Failed to submit the background job: {e}"
            await ctx.error(msg)
            return {"success": False, "error": msg}
    @analysis_mcp.tool()
    async def analyze_topological_results(
        output_file_path: str,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> AnalysisResult:
        """[Step 2/2] Read and parse the output file of a topological analysis job."""
        # ... (tool implementation identical to the previous version)
        try:
            await ctx.info(f"Analyzing result file: {output_file_path}")
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target_path = _safe_join(sandbox_root, output_file_path)
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target_path, "r") as f:
                    log_content = await f.read()
            if not log_content:
                raise ValueError("The result file is empty.")
            analysis_data = _parse_analysis_output(log_content)
            return analysis_data
        except FileNotFoundError:
            msg = f"Analysis failed: result file '{output_file_path}' does not exist."
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])
        except Exception as e:
            msg = f"Error while analyzing results: {e}"
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])
    return analysis_mcp

10
mcp/uvicorn/main.py Normal file

@@ -0,0 +1,10 @@
# main.py
from fastapi import FastAPI
app = FastAPI()
@app.get("/")
async def read_root():
return {"message": "Hello, World!"}