Compare commits
13 Commits
f625154aee
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 95d719cc1e | |||
| 1f8667ae51 | |||
| c0b2ec5983 | |||
| b9ba79d7a8 | |||
| bd4bd3a645 | |||
| c20d752faa | |||
| 4e21954471 | |||
| fc7716507a | |||
| b11bf2417b | |||
| 9b0f835575 | |||
| b19352382b | |||
| 28639f9cbf | |||
| 41a6038e50 |
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -3,5 +3,5 @@
|
|||||||
<component name="Black">
|
<component name="Black">
|
||||||
<option name="sdkName" value="Python 3.12" />
|
<option name="sdkName" value="Python 3.12" />
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (test1)" project-jdk-type="Python SDK" />
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11" project-jdk-type="Python SDK" />
|
||||||
</project>
|
</project>
|
||||||
2
.idea/solidstate-tools.iml
generated
2
.idea/solidstate-tools.iml
generated
@@ -2,7 +2,7 @@
|
|||||||
<module type="PYTHON_MODULE" version="4">
|
<module type="PYTHON_MODULE" version="4">
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$" />
|
<content url="file://$MODULE_DIR$" />
|
||||||
<orderEntry type="jdk" jdkName="Python 3.12 (test1)" jdkType="Python SDK" />
|
<orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
</component>
|
</component>
|
||||||
</module>
|
</module>
|
||||||
151
contrast learning/copy.py
Normal file
151
contrast learning/copy.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_element_column_index(cif_lines: list) -> int:
|
||||||
|
"""
|
||||||
|
在CIF文件内容中查找 _atom_site_type_symbol 所在的列索引。
|
||||||
|
|
||||||
|
:param cif_lines: 从CIF文件读取的行列表。
|
||||||
|
:return: 元素符号列的索引(从0开始),如果未找到则返回-1。
|
||||||
|
"""
|
||||||
|
in_loop_header = False
|
||||||
|
column_index = -1
|
||||||
|
current_column = 0
|
||||||
|
|
||||||
|
for line in cif_lines:
|
||||||
|
line_stripped = line.strip()
|
||||||
|
if not line_stripped:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line_stripped.startswith('loop_'):
|
||||||
|
in_loop_header = True
|
||||||
|
column_index = -1
|
||||||
|
current_column = 0
|
||||||
|
continue
|
||||||
|
|
||||||
|
if in_loop_header:
|
||||||
|
if line_stripped.startswith('_'):
|
||||||
|
if line_stripped.startswith('_atom_site_type_symbol'):
|
||||||
|
column_index = current_column
|
||||||
|
current_column += 1
|
||||||
|
else:
|
||||||
|
# loop_ 头部定义结束,开始数据行
|
||||||
|
return column_index
|
||||||
|
|
||||||
|
return -1 # 如果文件中没有找到 loop_ 或 _atom_site_type_symbol
|
||||||
|
|
||||||
|
|
||||||
|
def copy_cif_with_O_or_S_robust(source_dir: str, target_dir: str, dry_run: bool = False):
|
||||||
|
"""
|
||||||
|
从源文件夹中筛选出内容包含'O'或'S'元素的CIF文件,并复制到目标文件夹。
|
||||||
|
(鲁棒版:能正确解析CIF中的元素符号列)
|
||||||
|
|
||||||
|
:param source_dir: 源文件夹路径,包含CIF文件。
|
||||||
|
:param target_dir: 目标文件夹路径,用于存放筛选出的文件。
|
||||||
|
:param dry_run: 如果为True,则只打印将要复制的文件,而不实际执行复制操作。
|
||||||
|
"""
|
||||||
|
# 1. 路径处理和验证
|
||||||
|
source_path = Path(source_dir)
|
||||||
|
target_path = Path(target_dir)
|
||||||
|
|
||||||
|
if not source_path.is_dir():
|
||||||
|
print(f"错误:源文件夹 '{source_dir}' 不存在或不是一个文件夹。")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not dry_run and not target_path.exists():
|
||||||
|
target_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"目标文件夹 '{target_dir}' 已创建。")
|
||||||
|
|
||||||
|
print(f"源文件夹: {source_path}")
|
||||||
|
print(f"目标文件夹: {target_path}")
|
||||||
|
if dry_run:
|
||||||
|
print("\n--- *** 模拟运行模式 (Dry Run) *** ---")
|
||||||
|
print("--- 不会执行任何实际的文件复制操作 ---")
|
||||||
|
|
||||||
|
# 2. 开始遍历和筛选
|
||||||
|
print("\n开始扫描源文件夹中的CIF文件...")
|
||||||
|
copied_count = 0
|
||||||
|
checked_files = 0
|
||||||
|
error_files = 0
|
||||||
|
|
||||||
|
# 使用 rglob('*.cif') 可以遍历所有子文件夹,如果只想遍历当前文件夹用 glob
|
||||||
|
for file_path in source_path.glob('*.cif'):
|
||||||
|
if file_path.is_file():
|
||||||
|
checked_files += 1
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
# 步骤 A: 找到元素符号在哪一列
|
||||||
|
element_col_idx = find_element_column_index(lines)
|
||||||
|
|
||||||
|
if element_col_idx == -1:
|
||||||
|
# 在某些CIF文件中,可能没有loop块,而是简单的 key-value 格式
|
||||||
|
# 为了兼容这种情况,我们保留一个简化的检查
|
||||||
|
found_simple = any(
|
||||||
|
line.strip().startswith(('_chemical_formula_sum', '_chemical_formula_structural')) and (
|
||||||
|
' O' in line or ' S' in line) for line in lines)
|
||||||
|
if not found_simple:
|
||||||
|
continue # 如果两种方法都找不到,跳过此文件
|
||||||
|
|
||||||
|
# 步骤 B: 检查该列是否有 'O' 或 'S'
|
||||||
|
found = False
|
||||||
|
for line in lines:
|
||||||
|
line_stripped = line.strip()
|
||||||
|
# 忽略空行、注释行和定义行
|
||||||
|
if not line_stripped or line_stripped.startswith(('#', '_', 'loop_')):
|
||||||
|
continue
|
||||||
|
|
||||||
|
parts = line_stripped.split()
|
||||||
|
# 确保行中有足够的列
|
||||||
|
if len(parts) > element_col_idx:
|
||||||
|
# 元素符号可能带有电荷,如 O2-,所以用 startswith
|
||||||
|
atom_symbol = parts[element_col_idx].strip()
|
||||||
|
if atom_symbol == 'O' or atom_symbol == 'S':
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# 兼容性检查:如果通过了 found_simple 的检查,也标记为找到
|
||||||
|
if found_simple:
|
||||||
|
found = True
|
||||||
|
|
||||||
|
if found:
|
||||||
|
target_file_path = target_path / file_path.name
|
||||||
|
print(f"找到匹配: '{file_path.name}' (含有 O 或 S 元素)")
|
||||||
|
|
||||||
|
if not dry_run:
|
||||||
|
shutil.copy2(file_path, target_file_path)
|
||||||
|
# print(f" -> 已复制到 {target_file_path}") # 可以取消注释以获得更详细的输出
|
||||||
|
|
||||||
|
copied_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_files += 1
|
||||||
|
print(f"!! 处理文件 '{file_path.name}' 时发生错误: {e}")
|
||||||
|
|
||||||
|
# 3. 打印最终报告
|
||||||
|
print("\n--- 操作总结 ---")
|
||||||
|
print(f"共检查了 {checked_files} 个.cif文件。")
|
||||||
|
if error_files > 0:
|
||||||
|
print(f"处理过程中有 {error_files} 个文件发生错误。")
|
||||||
|
if dry_run:
|
||||||
|
print(f"模拟运行结束:如果实际运行,将会有 {copied_count} 个文件被复制。")
|
||||||
|
else:
|
||||||
|
print(f"成功复制了 {copied_count} 个文件到目标文件夹。")
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# !! 重要:请将下面的路径修改为您自己电脑上的实际路径
|
||||||
|
source_folder = "D:/download/2025-10/data_all/input/input"
|
||||||
|
target_folder = "D:/download/2025-10/data_all/output"
|
||||||
|
|
||||||
|
# --- 第一次运行:使用模拟模式 (Dry Run) ---
|
||||||
|
print("================ 第一次运行: 模拟模式 ================")
|
||||||
|
copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=True)
|
||||||
|
|
||||||
|
print("\n\n=======================================================")
|
||||||
|
input("检查上面的模拟运行结果。如果符合预期,按回车键继续执行实际复制操作...")
|
||||||
|
print("=======================================================")
|
||||||
|
|
||||||
|
# --- 第二次运行:实际执行复制 ---
|
||||||
|
print("\n================ 第二次运行: 实际复制模式 ================")
|
||||||
|
copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=False)
|
||||||
111
contrast learning/delete.py
Normal file
111
contrast learning/delete.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def delete_duplicates_from_second_folder(source_dir: str, target_dir: str, dry_run: bool = False):
|
||||||
|
"""
|
||||||
|
删除第二个文件夹中与第一个文件夹内项目同名的文件或文件夹。
|
||||||
|
|
||||||
|
:param source_dir: 第一个文件夹(源)的路径。
|
||||||
|
:param target_dir: 第二个文件夹(目标)的路径,将从此文件夹中删除内容。
|
||||||
|
:param dry_run: 如果为True,则只打印将要删除的内容,而不实际执行删除操作。
|
||||||
|
"""
|
||||||
|
# 1. 将字符串路径转换为Path对象,方便操作
|
||||||
|
source_path = Path(source_dir)
|
||||||
|
target_path = Path(target_dir)
|
||||||
|
|
||||||
|
# 2. 验证路径是否存在且为文件夹
|
||||||
|
if not source_path.is_dir():
|
||||||
|
print(f"错误:源文件夹 '{source_dir}' 不存在或不是一个文件夹。")
|
||||||
|
return
|
||||||
|
if not target_path.is_dir():
|
||||||
|
print(f"错误:目标文件夹 '{target_dir}' 不存在或不是一个文件夹。")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"源文件夹: {source_path}")
|
||||||
|
print(f"目标文件夹: {target_path}")
|
||||||
|
if dry_run:
|
||||||
|
print("\n--- *** 模拟运行模式 (Dry Run) *** ---")
|
||||||
|
print("--- 不会执行任何实际的删除操作 ---")
|
||||||
|
|
||||||
|
# 3. 获取源文件夹中所有项目(文件和子文件夹)的名称
|
||||||
|
# p.name 会返回路径的最后一部分,即文件名或文件夹名
|
||||||
|
source_item_names = {p.name for p in source_path.iterdir()}
|
||||||
|
|
||||||
|
if not source_item_names:
|
||||||
|
print("\n源文件夹为空,无需执行任何操作。")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\n在源文件夹中找到 {len(source_item_names)} 个项目。")
|
||||||
|
print("开始检查并删除目标文件夹中的同名项目...")
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
# 4. 遍历源文件夹中的项目名称
|
||||||
|
for item_name in source_item_names:
|
||||||
|
# 构建目标文件夹中可能存在的同名项目的完整路径
|
||||||
|
item_to_delete = target_path / item_name
|
||||||
|
|
||||||
|
# 5. 检查该项目是否存在于目标文件夹中
|
||||||
|
if item_to_delete.exists():
|
||||||
|
try:
|
||||||
|
if item_to_delete.is_file():
|
||||||
|
# 如果是文件,直接删除
|
||||||
|
print(f"准备删除文件: {item_to_delete}")
|
||||||
|
if not dry_run:
|
||||||
|
item_to_delete.unlink()
|
||||||
|
print(" -> 已删除。")
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
elif item_to_delete.is_dir():
|
||||||
|
# 如果是文件夹,使用 shutil.rmtree 删除整个文件夹及其内容
|
||||||
|
print(f"准备删除文件夹及其所有内容: {item_to_delete}")
|
||||||
|
if not dry_run:
|
||||||
|
shutil.rmtree(item_to_delete)
|
||||||
|
print(" -> 已删除。")
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"!! 删除 '{item_to_delete}' 时发生错误: {e}")
|
||||||
|
|
||||||
|
if deleted_count == 0:
|
||||||
|
print("\n操作完成:在目标文件夹中没有找到需要删除的同名项目。")
|
||||||
|
else:
|
||||||
|
if dry_run:
|
||||||
|
print(f"\n模拟运行结束:如果实际运行,将会有 {deleted_count} 个项目被删除。")
|
||||||
|
else:
|
||||||
|
print(f"\n操作完成:总共删除了 {deleted_count} 个项目。")
|
||||||
|
|
||||||
|
|
||||||
|
# --- 使用示例 ---
|
||||||
|
|
||||||
|
# 在运行前,请创建以下文件夹和文件结构进行测试:
|
||||||
|
# /your/path/folder1/
|
||||||
|
# ├── file_a.txt
|
||||||
|
# ├── file_b.log
|
||||||
|
# └── subfolder_x/
|
||||||
|
# └── test.txt
|
||||||
|
|
||||||
|
# /your/path/folder2/
|
||||||
|
# ├── file_a.txt (将被删除)
|
||||||
|
# ├── file_c.md
|
||||||
|
# └── subfolder_x/ (将被删除)
|
||||||
|
# └── another.txt
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# !! 重要:请将下面的路径修改为您自己电脑上的实际路径
|
||||||
|
folder1_path = "D:/download/2025-10/after_step5/after_step5/S" # 源文件夹
|
||||||
|
folder2_path = "D:/download/2025-10/input/input" # 目标文件夹
|
||||||
|
|
||||||
|
# --- 第一次运行:使用模拟模式 (Dry Run),非常推荐!---
|
||||||
|
# 这会告诉你脚本将要做什么,但不会真的删除任何东西。
|
||||||
|
print("================ 第一次运行: 模拟模式 ================")
|
||||||
|
delete_duplicates_from_second_folder(folder1_path, folder2_path, dry_run=True)
|
||||||
|
|
||||||
|
print("\n\n=======================================================")
|
||||||
|
input("检查上面的模拟运行结果。如果符合预期,按回车键继续执行实际删除操作...")
|
||||||
|
print("=======================================================")
|
||||||
|
|
||||||
|
# --- 第二次运行:实际执行删除 ---
|
||||||
|
# 确认模拟运行结果无误后,再将 dry_run 设置为 False 或移除该参数。
|
||||||
|
print("\n================ 第二次运行: 实际删除模式 ================")
|
||||||
|
delete_duplicates_from_second_folder(folder1_path, folder2_path, dry_run=False)
|
||||||
50
dpgen/data/P3ma/model1.cif
Normal file
50
dpgen/data/P3ma/model1.cif
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
#==============================================================================
|
||||||
|
# CRYSTAL DATA
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
|
||||||
|
# Source: Image provided by user
|
||||||
|
#==============================================================================
|
||||||
|
|
||||||
|
data_Li3YCl6_annealed
|
||||||
|
_audit_creation_method 'Generated from published data table'
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# CELL PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_cell_length_a 11.2001(2)
|
||||||
|
_cell_length_b 11.2001(2)
|
||||||
|
_cell_length_c 6.0441(2)
|
||||||
|
_cell_angle_alpha 90.0
|
||||||
|
_cell_angle_beta 90.0
|
||||||
|
_cell_angle_gamma 120.0
|
||||||
|
_cell_volume 656.39
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# SYMMETRY
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_symmetry_space_group_name_H-M 'P -3 m 1'
|
||||||
|
_symmetry_Int_Tables_number 164
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
loop_
|
||||||
|
_atom_site_label
|
||||||
|
_atom_site_type_symbol
|
||||||
|
_atom_site_Wyckoff_symbol
|
||||||
|
_atom_site_fract_x
|
||||||
|
_atom_site_fract_y
|
||||||
|
_atom_site_fract_z
|
||||||
|
_atom_site_occupancy
|
||||||
|
_atom_site_B_iso_or_equiv
|
||||||
|
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
|
||||||
|
Y2 Y 2d 0.3333 0.6666 -0.065(3) 1.0000 1.10(3)
|
||||||
|
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
|
||||||
|
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
|
||||||
|
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
|
||||||
|
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
|
||||||
|
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
|
||||||
|
|
||||||
|
#==============================================================================
|
||||||
|
# END OF DATA
|
||||||
|
#==============================================================================
|
||||||
50
dpgen/data/P3ma/model2.cif
Normal file
50
dpgen/data/P3ma/model2.cif
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
#==============================================================================
|
||||||
|
# CRYSTAL DATA
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
|
||||||
|
# Source: Image provided by user
|
||||||
|
#==============================================================================
|
||||||
|
|
||||||
|
data_Li3YCl6_annealed
|
||||||
|
_audit_creation_method 'Generated from published data table'
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# CELL PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_cell_length_a 11.2001(2)
|
||||||
|
_cell_length_b 11.2001(2)
|
||||||
|
_cell_length_c 6.0441(2)
|
||||||
|
_cell_angle_alpha 90.0
|
||||||
|
_cell_angle_beta 90.0
|
||||||
|
_cell_angle_gamma 120.0
|
||||||
|
_cell_volume 656.39
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# SYMMETRY
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_symmetry_space_group_name_H-M 'P -3 m 1'
|
||||||
|
_symmetry_Int_Tables_number 164
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
loop_
|
||||||
|
_atom_site_label
|
||||||
|
_atom_site_type_symbol
|
||||||
|
_atom_site_Wyckoff_symbol
|
||||||
|
_atom_site_fract_x
|
||||||
|
_atom_site_fract_y
|
||||||
|
_atom_site_fract_z
|
||||||
|
_atom_site_occupancy
|
||||||
|
_atom_site_B_iso_or_equiv
|
||||||
|
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
|
||||||
|
Y2 Y 2d 0.3333 0.6666 0.488(1) 1.0000 1.10(3)
|
||||||
|
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
|
||||||
|
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
|
||||||
|
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
|
||||||
|
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
|
||||||
|
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
|
||||||
|
|
||||||
|
#==============================================================================
|
||||||
|
# END OF DATA
|
||||||
|
#==============================================================================
|
||||||
51
dpgen/data/P3ma/model3.cif
Normal file
51
dpgen/data/P3ma/model3.cif
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#==============================================================================
|
||||||
|
# CRYSTAL DATA
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
|
||||||
|
# Source: Image provided by user
|
||||||
|
#==============================================================================
|
||||||
|
|
||||||
|
data_Li3YCl6_annealed
|
||||||
|
_audit_creation_method 'Generated from published data table'
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# CELL PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_cell_length_a 11.2001(2)
|
||||||
|
_cell_length_b 11.2001(2)
|
||||||
|
_cell_length_c 6.0441(2)
|
||||||
|
_cell_angle_alpha 90.0
|
||||||
|
_cell_angle_beta 90.0
|
||||||
|
_cell_angle_gamma 120.0
|
||||||
|
_cell_volume 656.39
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# SYMMETRY
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_symmetry_space_group_name_H-M 'P -3 m 1'
|
||||||
|
_symmetry_Int_Tables_number 164
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
loop_
|
||||||
|
_atom_site_label
|
||||||
|
_atom_site_type_symbol
|
||||||
|
_atom_site_Wyckoff_symbol
|
||||||
|
_atom_site_fract_x
|
||||||
|
_atom_site_fract_y
|
||||||
|
_atom_site_fract_z
|
||||||
|
_atom_site_occupancy
|
||||||
|
_atom_site_B_iso_or_equiv
|
||||||
|
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
|
||||||
|
Y2 Y 2d 0.3333 0.6666 0.488(1) 0.7500 1.10(3)
|
||||||
|
Y3 Y 2d 0.3333 0.6666 -0.065(3) 0.2500 1.10(3)
|
||||||
|
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
|
||||||
|
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
|
||||||
|
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
|
||||||
|
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
|
||||||
|
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
|
||||||
|
|
||||||
|
#==============================================================================
|
||||||
|
# END OF DATA
|
||||||
|
#==============================================================================
|
||||||
60
dpgen/data/P3ma/model4.cif
Normal file
60
dpgen/data/P3ma/model4.cif
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#==============================================================================
|
||||||
|
# CRYSTAL DATA
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
|
||||||
|
# Source: Image provided by user
|
||||||
|
#==============================================================================
|
||||||
|
|
||||||
|
data_Li3YCl6_annealed
|
||||||
|
_audit_creation_method 'Generated from published data table'
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# CELL PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_cell_length_a 11.2001(2)
|
||||||
|
_cell_length_b 11.2001(2)
|
||||||
|
_cell_length_c 6.0441(2)
|
||||||
|
_cell_angle_alpha 90.0
|
||||||
|
_cell_angle_beta 90.0
|
||||||
|
_cell_angle_gamma 120.0
|
||||||
|
_cell_volume 656.39
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# SYMMETRY
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_symmetry_space_group_name_H-M 'P -3 m 1'
|
||||||
|
_symmetry_Int_Tables_number 164
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
loop_
|
||||||
|
_atom_site_label
|
||||||
|
_atom_site_type_symbol
|
||||||
|
_atom_site_fract_x
|
||||||
|
_atom_site_fract_y
|
||||||
|
_atom_site_fract_z
|
||||||
|
_atom_site_occupancy
|
||||||
|
|
||||||
|
# --- Y位点 (基于Model 3, Occ: M1=1, M2=0.75, M3=0.25) ---
|
||||||
|
Y1_main Y 0.0000 0.0000 0.0000 0.9444 # 1a位
|
||||||
|
Li_on_Y1 Li 0.0000 0.0000 0.0000 0.0556
|
||||||
|
Y2_main Y 0.3333 0.6666 0.4880 0.7083 # 2d位
|
||||||
|
Li_on_Y2 Li 0.3333 0.6666 0.4880 0.0417
|
||||||
|
Y3_main Y 0.3333 0.6666 -0.0650 0.2361 # 2d位
|
||||||
|
Li_on_Y3 Li 0.3333 0.6666 -0.0650 0.0139
|
||||||
|
|
||||||
|
# --- Li位点 (基于标准结构, Occ: 6g=1, 6h=0.5) ---
|
||||||
|
Li1_main Li 0.3397 0.3397 0.0000 0.9815 # 6g位
|
||||||
|
Y_on_Li1 Y 0.3397 0.3397 0.0000 0.0185
|
||||||
|
Li2_main Li 0.3397 0.3397 0.5000 0.4907 # 6h位
|
||||||
|
Y_on_Li2 Y 0.3397 0.3397 0.5000 0.0093
|
||||||
|
|
||||||
|
# --- Cl位点 (保持不变) ---
|
||||||
|
Cl1 Cl 0.1131 -0.1131 0.7717 1.0000
|
||||||
|
Cl2 Cl 0.2182 -0.2182 0.2606 1.0000
|
||||||
|
Cl3 Cl 0.4436 -0.4436 0.7604 1.0000
|
||||||
|
|
||||||
|
#==============================================================================
|
||||||
|
# END OF DATA
|
||||||
|
#==============================================================================
|
||||||
51
dpgen/data/P3ma/origin.cif
Normal file
51
dpgen/data/P3ma/origin.cif
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#==============================================================================
|
||||||
|
# CRYSTAL DATA
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
|
||||||
|
# Source: Image provided by user
|
||||||
|
#==============================================================================
|
||||||
|
|
||||||
|
data_Li3YCl6_annealed
|
||||||
|
_audit_creation_method 'Generated from published data table'
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# CELL PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_cell_length_a 11.2001(2)
|
||||||
|
_cell_length_b 11.2001(2)
|
||||||
|
_cell_length_c 6.0441(2)
|
||||||
|
_cell_angle_alpha 90.0
|
||||||
|
_cell_angle_beta 90.0
|
||||||
|
_cell_angle_gamma 120.0
|
||||||
|
_cell_volume 656.39
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# SYMMETRY
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
_symmetry_space_group_name_H-M 'P -3 m 1'
|
||||||
|
_symmetry_Int_Tables_number 164
|
||||||
|
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
loop_
|
||||||
|
_atom_site_label
|
||||||
|
_atom_site_type_symbol
|
||||||
|
_atom_site_Wyckoff_symbol
|
||||||
|
_atom_site_fract_x
|
||||||
|
_atom_site_fract_y
|
||||||
|
_atom_site_fract_z
|
||||||
|
_atom_site_occupancy
|
||||||
|
_atom_site_B_iso_or_equiv
|
||||||
|
Y1 Y 1a 0.0000 0.0000 0.0000 1.0000 1.10(3)
|
||||||
|
Y2 Y 2d 0.3333 0.6666 0.488(1) 0.8269 1.10(3)
|
||||||
|
Y3 Y 2d 0.3333 0.6666 -0.065(3) 0.1730 1.10(3)
|
||||||
|
Cl1 Cl 6i 0.1131(7) -0.1131(7) 0.7717(8) 1.0000 1.99(3)
|
||||||
|
Cl2 Cl 6i 0.2182(7) -0.2182(7) 0.2606(8) 1.0000 1.99(3)
|
||||||
|
Cl3 Cl 6i 0.4436(4) -0.4436(4) 0.7604(8) 1.0000 1.99(3)
|
||||||
|
Li1 Li 6g 0.3397 0.3397 0.0000 1.0000 5.00
|
||||||
|
Li2 Li 6h 0.3397 0.3397 0.5000 0.5000 5.00
|
||||||
|
|
||||||
|
#==============================================================================
|
||||||
|
# END OF DATA
|
||||||
|
#==============================================================================
|
||||||
53
dpgen/data/Pnma/origin_backup.cif
Normal file
53
dpgen/data/Pnma/origin_backup.cif
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#------------------------------------------------------------------------------
|
||||||
|
# CIF (Crystallographic Information File) for Li3YCl6
|
||||||
|
# Data source: Table S1 from the provided image.
|
||||||
|
# Rietveld refinement result of the neutron diffraction pattern for the 450 °C-annealed sample.
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
data_Li3YCl6
|
||||||
|
|
||||||
|
_chemical_name_systematic 'Lithium Yttrium Chloride'
|
||||||
|
_chemical_formula_sum 'Li3 Y1 Cl6'
|
||||||
|
_chemical_formula_structural 'Li3YCl6'
|
||||||
|
|
||||||
|
_symmetry_space_group_name_H-M 'P n m a'
|
||||||
|
_symmetry_Int_Tables_number 62
|
||||||
|
_symmetry_cell_setting orthorhombic
|
||||||
|
|
||||||
|
loop_
|
||||||
|
_symmetry_equiv_pos_as_xyz
|
||||||
|
'x, y, z'
|
||||||
|
'-x+1/2, y+1/2, -z+1/2'
|
||||||
|
'-x, -y, -z'
|
||||||
|
'x+1/2, -y+1/2, z+1/2'
|
||||||
|
'-x, y+1/2, -z'
|
||||||
|
'x-1/2, -y-1/2, z-1/2'
|
||||||
|
'x, -y, z'
|
||||||
|
'-x-1/2, y-1/2, -z-1/2'
|
||||||
|
|
||||||
|
_cell_length_a 12.92765(13)
|
||||||
|
_cell_length_b 11.19444(10)
|
||||||
|
_cell_length_c 6.04000(12)
|
||||||
|
_cell_angle_alpha 90.0
|
||||||
|
_cell_angle_beta 90.0
|
||||||
|
_cell_angle_gamma 90.0
|
||||||
|
_cell_volume 874.15
|
||||||
|
_cell_formula_units_Z 4
|
||||||
|
|
||||||
|
loop_
|
||||||
|
_atom_site_label
|
||||||
|
_atom_site_type_symbol
|
||||||
|
_atom_site_fract_x
|
||||||
|
_atom_site_fract_y
|
||||||
|
_atom_site_fract_z
|
||||||
|
_atom_site_occupancy
|
||||||
|
_atom_site_Wyckoff_symbol
|
||||||
|
_atom_site_U_iso_or_equiv
|
||||||
|
Li1 Li 0.11730(7) 0.09640(7) 0.04860(10) 0.750(13) 8d 4.579(2)
|
||||||
|
Li2 Li 0.13270(9) 0.07900(10) 0.48600(2) 0.750(19) 8d 9.554(4)
|
||||||
|
Cl1 Cl 0.21726(7) 0.58920(7) 0.26362(11) 1.0 8d 0.797(17)
|
||||||
|
Cl2 Cl 0.45948(8) 0.08259(8) 0.23831(13) 1.0 8d 1.548(2)
|
||||||
|
Cl3 Cl 0.04505(10) 0.25000 0.74110(2) 1.0 4c 1.848(3)
|
||||||
|
Cl4 Cl 0.20205(9) 0.25000 0.24970(2) 1.0 4c 0.561(2)
|
||||||
|
Y1 Y 0.37529(10) 0.25000 0.01870(3) 1.0 4c 1.121(17)
|
||||||
|
#------------------------------------------------------------------------------
|
||||||
151
dpgen/plus.py
Normal file
151
dpgen/plus.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
from pymatgen.core import Structure
|
||||||
|
from pymatgen.io.vasp import Poscar
|
||||||
|
|
||||||
|
def _is_close_frac(z, target, tol=2e-2):
|
||||||
|
t = target % 1.0
|
||||||
|
return min(abs(z - t), abs(z - (t + 1)), abs(z - (t - 1))) < tol
|
||||||
|
|
||||||
|
def make_model3_poscar_from_cif(cif_path: str,
|
||||||
|
out_poscar: str = "POSCAR_model3_supercell",
|
||||||
|
seed: int = 42,
|
||||||
|
tol: float = 2e-2):
|
||||||
|
"""
|
||||||
|
将 model3.cif 扩胞为 [[3,0,0],[2,4,0],[0,0,6]] 的2160原子超胞,并把部分占据位点(Y2=0.75, Y3=0.25, Li2=0.5)
|
||||||
|
显式有序化后写出 POSCAR。
|
||||||
|
"""
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
# 1) 读取 CIF
|
||||||
|
s = Structure.from_file(cif_path)
|
||||||
|
|
||||||
|
# 2) 扩胞(a_s=3a0, b_s=2a0+4b0, c_s=6c0)[1]
|
||||||
|
T = [[3, 0, 0],
|
||||||
|
[2, 4, 0],
|
||||||
|
[0, 0, 6]]
|
||||||
|
s.make_supercell(T)
|
||||||
|
|
||||||
|
# 3) 识别三类需取整的位点:Y2、Y3、Li2
|
||||||
|
y2_idx: List[int] = []
|
||||||
|
y3_idx: List[int] = []
|
||||||
|
li2_idx: List[int] = []
|
||||||
|
|
||||||
|
for i, site in enumerate(s.sites):
|
||||||
|
# 兼容不同版本pymatgen
|
||||||
|
try:
|
||||||
|
el = site.species.elements[0].symbol
|
||||||
|
except Exception:
|
||||||
|
ss = site.species_string
|
||||||
|
el = "Li" if ss.startswith("Li") else ("Y" if ss.startswith("Y") else ("Cl" if ss.startswith("Cl") else ss))
|
||||||
|
z = site.frac_coords[2]
|
||||||
|
if el == "Y":
|
||||||
|
if _is_close_frac(z, 0.488, tol):
|
||||||
|
y2_idx.append(i)
|
||||||
|
elif _is_close_frac(z, -0.065, tol) or _is_close_frac(z, 0.935, tol):
|
||||||
|
y3_idx.append(i)
|
||||||
|
elif el == "Li":
|
||||||
|
if _is_close_frac(z, 0.5, tol):
|
||||||
|
li2_idx.append(i)
|
||||||
|
|
||||||
|
def choose_keep(idxs, frac_keep):
|
||||||
|
n = len(idxs)
|
||||||
|
k = int(round(n * frac_keep))
|
||||||
|
if k < 0: k = 0
|
||||||
|
if k > n: k = n
|
||||||
|
keep = set(random.sample(idxs, k)) if 0 < k < n else set(idxs if k == n else [])
|
||||||
|
drop = [i for i in idxs if i not in keep]
|
||||||
|
return keep, drop
|
||||||
|
|
||||||
|
keep_y2, drop_y2 = choose_keep(y2_idx, 0.75)
|
||||||
|
keep_y3, drop_y3 = choose_keep(y3_idx, 0.25)
|
||||||
|
keep_li2, drop_li2 = choose_keep(li2_idx, 0.50)
|
||||||
|
|
||||||
|
# 4) 保留者占据设为1,其余删除
|
||||||
|
for i in keep_y2 | keep_y3:
|
||||||
|
s.replace(i, "Y")
|
||||||
|
for i in keep_li2:
|
||||||
|
s.replace(i, "Li")
|
||||||
|
to_remove = sorted(drop_y2 + drop_y3 + drop_li2, reverse=True)
|
||||||
|
for i in to_remove:
|
||||||
|
s.remove_sites([i])
|
||||||
|
|
||||||
|
# 5) 最终清理:消除任何残留的部分占据(防止 POSCAR 写出报错)
|
||||||
|
# 若有 site.is_ordered==False,则取该站位的“主要元素”替换为占据=1
|
||||||
|
for i, site in enumerate(s.sites):
|
||||||
|
if not site.is_ordered:
|
||||||
|
d = site.species.as_dict() # {'Li': 0.5} 或 {'Li':0.5,'Y':0.5}
|
||||||
|
elem = max(d.items(), key=lambda kv: kv[1])[0]
|
||||||
|
s.replace(i, elem)
|
||||||
|
|
||||||
|
# 6) 排序并写出 POSCAR
|
||||||
|
order = {"Li": 0, "Y": 1, "Cl": 2}
|
||||||
|
s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
|
||||||
|
Poscar(s).write_file(out_poscar)
|
||||||
|
|
||||||
|
# 报告
|
||||||
|
comp = {k: int(v) for k, v in s.composition.as_dict().items()}
|
||||||
|
print(f"写出 {out_poscar};总原子数 = {len(s)}")
|
||||||
|
print(f"Y2识别={len(y2_idx)},Y3识别={len(y3_idx)},Li2识别={len(li2_idx)};组成={comp}")
|
||||||
|
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
from pymatgen.core import Structure
|
||||||
|
from pymatgen.io.vasp import Poscar
|
||||||
|
|
||||||
|
def make_pnma_poscar_from_cif(cif_path: str,
|
||||||
|
out_poscar: str = "POSCAR_pnma_supercell",
|
||||||
|
seed: int = 42,
|
||||||
|
supercell=(3,3,6),
|
||||||
|
tol: float = 1e-6):
|
||||||
|
"""
|
||||||
|
读取 Pnma 的 CIF(如 origin.cif),扩胞到 2160 原子,并把部分占据的 Li 位点(0.75)显式取整后写出 POSCAR。
|
||||||
|
默认超胞尺度为(3,3,6),体积放大因子=54,40原子/原胞×54=2160 [1][3]。
|
||||||
|
"""
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
s = Structure.from_file(cif_path)
|
||||||
|
|
||||||
|
# 扩胞;Pnma原胞已是正交,直接用对角放缩
|
||||||
|
s.make_supercell(supercell)
|
||||||
|
|
||||||
|
# 找出所有“部分占据的 Li”位点
|
||||||
|
partial_li_idx: List[int] = []
|
||||||
|
for i, site in enumerate(s.sites):
|
||||||
|
if not site.is_ordered:
|
||||||
|
d = site.species.as_dict() # 例如 {'Li': 0.75}
|
||||||
|
# 只处理主要元素是Li且占据<1的位点
|
||||||
|
m_elem, m_occ = max(d.items(), key=lambda kv: kv[1])
|
||||||
|
if m_elem == "Li" and m_occ < 1 - tol:
|
||||||
|
partial_li_idx.append(i)
|
||||||
|
|
||||||
|
# 以占据0.75进行随机取整:保留75%,其余删除为“空位”
|
||||||
|
n = len(partial_li_idx)
|
||||||
|
k = int(round(n * 0.75))
|
||||||
|
keep = set(random.sample(partial_li_idx, k)) if 0 < k < n else set(partial_li_idx if k == n else [])
|
||||||
|
drop = sorted([i for i in partial_li_idx if i not in keep], reverse=True)
|
||||||
|
|
||||||
|
# 保留者设为占据=1;删除其余
|
||||||
|
for i in keep:
|
||||||
|
s.replace(i, "Li")
|
||||||
|
for i in drop:
|
||||||
|
s.remove_sites([i])
|
||||||
|
|
||||||
|
# 兜底:若仍有部分占据,强制取主要元素
|
||||||
|
for i, site in enumerate(s.sites):
|
||||||
|
if not site.is_ordered:
|
||||||
|
d = site.species.as_dict()
|
||||||
|
elem = max(d.items(), key=lambda kv: kv[1])[0]
|
||||||
|
s.replace(i, elem)
|
||||||
|
|
||||||
|
# 排序并写POSCAR
|
||||||
|
order = {"Li": 0, "Y": 1, "Cl": 2}
|
||||||
|
s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
|
||||||
|
Poscar(s).write_file(out_poscar)
|
||||||
|
|
||||||
|
comp = {k: int(v) for k, v in s.composition.as_dict().items()}
|
||||||
|
print(f"写出 {out_poscar};总原子数 = {len(s)};组成 = {comp}")
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
# make_model3_poscar_from_cif("data/P3ma/model3.cif","data/P3ma/supercell_model4.poscar")
|
||||||
|
make_pnma_poscar_from_cif("data/Pnma/origin.cif","data/Pnma/supercell_pnma.poscar",seed=42)
|
||||||
0
dpgen/supercell_make_p3ma.py
Normal file
0
dpgen/supercell_make_p3ma.py
Normal file
76
dpgen/transport.py
Normal file
76
dpgen/transport.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
from typing import List, Dict, Tuple
|
||||||
|
|
||||||
|
def add_li_y_antisite_average(
|
||||||
|
x: float,
|
||||||
|
# Y 位点:[(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
|
||||||
|
y_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
|
||||||
|
# Li 位点:[(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
|
||||||
|
li_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
|
||||||
|
# 可选:各位点的 B 因子(字典),不需要可留空
|
||||||
|
b_iso: Dict[str, float] = None,
|
||||||
|
# 保留小数位
|
||||||
|
ndigits: int = 4
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
把结构改写成含 x (相对于 Y 总数) 的 Li/Y 反位缺陷的“平均占据”CIF 块。
|
||||||
|
返回字符串,可直接粘贴到 CIF 的 atom_site loop 中。
|
||||||
|
"""
|
||||||
|
# 默认使用表 S9(P-3̅m1)的坐标与占据;multiplicity 依据该空间群
|
||||||
|
if y_sites is None:
|
||||||
|
y_sites = [
|
||||||
|
("Y1", "1a", (0.0000, 0.0000, 0.0000), 1.0000, 1),
|
||||||
|
("Y2", "2d", (0.3333, 0.6666, 0.4880), 0.8269, 2),
|
||||||
|
("Y3", "2d", (0.3333, 0.6666,-0.0650), 0.1730, 2),
|
||||||
|
]
|
||||||
|
if li_sites is None:
|
||||||
|
li_sites = [
|
||||||
|
("Li1", "6g", (0.3397, 0.3397, 0.0000), 1.0000, 6),
|
||||||
|
("Li2", "6h", (0.3397, 0.3397, 0.5000), 0.5000, 6),
|
||||||
|
]
|
||||||
|
if b_iso is None:
|
||||||
|
# 可按需要改
|
||||||
|
b_iso = {"Y": 1.10, "Li": 5.00, "Cl": 1.99}
|
||||||
|
|
||||||
|
# 1) 总 Y 与总 Li(每原胞的“原子数贡献”):
|
||||||
|
y_total = sum(m * occ for _,_,_,occ,m in y_sites) # 期望为 3
|
||||||
|
li_total = sum(m * occ for _,_,_,occ,m in li_sites) # 期望为 9
|
||||||
|
|
||||||
|
# 2) 需要交换的“Y 数量”(相对于每原胞)
|
||||||
|
y_exchanged = x * y_total # 例如 x=0.0556 → ~0.1668 个 Y/原胞
|
||||||
|
|
||||||
|
# 3) 折算到 Li 子格的分配比例(保证电中性和计量)
|
||||||
|
li_fraction = y_exchanged / li_total # 每个 Li 位点应引入的 Y 的比例
|
||||||
|
|
||||||
|
# 4) 生成 CIF loop 文本(同一坐标写两行,分别为主元素与反位元素)
|
||||||
|
lines = []
|
||||||
|
header = [
|
||||||
|
"loop_",
|
||||||
|
" _atom_site_label",
|
||||||
|
" _atom_site_type_symbol",
|
||||||
|
" _atom_site_fract_x",
|
||||||
|
" _atom_site_fract_y",
|
||||||
|
" _atom_site_fract_z",
|
||||||
|
" _atom_site_occupancy",
|
||||||
|
" _atom_site_B_iso_or_equiv",
|
||||||
|
]
|
||||||
|
lines.extend(header)
|
||||||
|
|
||||||
|
# 4a) 处理 Y 位点:Y → (1-x)*occ;Li_on_Y → x*occ
|
||||||
|
for label, wyck, (fx,fy,fz), occ, mult in y_sites:
|
||||||
|
y_main_occ = round(occ * (1.0 - x), ndigits)
|
||||||
|
li_on_y_occ = round(occ * x, ndigits)
|
||||||
|
lines.append(f" {label}_main Y {fx:.4f} {fy:.4f} {fz:.4f} {y_main_occ:.{ndigits}f} {b_iso.get('Y',1.0)}")
|
||||||
|
lines.append(f" Li_on_{label} Li {fx:.4f} {fy:.4f} {fz:.4f} {li_on_y_occ:.{ndigits}f} {b_iso.get('Li',5.0)}")
|
||||||
|
|
||||||
|
# 4b) 处理 Li 位点:Li → (1-li_fraction)*occ;Y_on_Li → li_fraction*occ
|
||||||
|
for label, wyck, (fx,fy,fz), occ, mult in li_sites:
|
||||||
|
li_main_occ = round(occ * (1.0 - li_fraction), ndigits)
|
||||||
|
y_on_li_occ = round(occ * li_fraction, ndigits)
|
||||||
|
lines.append(f" {label}_main Li {fx:.4f} {fy:.4f} {fz:.4f} {li_main_occ:.{ndigits}f} {b_iso.get('Li',5.0)}")
|
||||||
|
lines.append(f" Y_on_{label} Y {fx:.4f} {fy:.4f} {fz:.4f} {y_on_li_occ:.{ndigits}f} {b_iso.get('Y',1.0)}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
# === 示例 ===
|
||||||
|
# x=0.0556(与论文使用的 5.56% 一致,定义为相对于 Y 总数的比例):
|
||||||
|
print(add_li_y_antisite_average(0.0556))
|
||||||
21
mcp/SearchPaperByEmbedding/LICENSE
Normal file
21
mcp/SearchPaperByEmbedding/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 gyj155
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
96
mcp/SearchPaperByEmbedding/README.md
Normal file
96
mcp/SearchPaperByEmbedding/README.md
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# Paper Semantic Search
|
||||||
|
|
||||||
|
Find similar papers using semantic search. Supports both local models (free) and OpenAI API (better quality).
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Request for papers from OpenReview (e.g., ICLR2026 submissions)
|
||||||
|
- Semantic search with example papers or text queries
|
||||||
|
- Support embedding caching
|
||||||
|
- Embed model support: Open-source (e.g., all-MiniLM-L6-v2) or OpenAI
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1. Prepare Papers
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crawl import crawl_papers
|
||||||
|
|
||||||
|
crawl_papers(
|
||||||
|
venue_id="ICLR.cc/2026/Conference/Submission",
|
||||||
|
output_file="iclr2026_papers.json"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Search Papers
|
||||||
|
|
||||||
|
```python
|
||||||
|
from search import PaperSearcher
|
||||||
|
|
||||||
|
# Local model (free)
|
||||||
|
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
|
||||||
|
|
||||||
|
# OpenAI model (better, requires API key)
|
||||||
|
# export OPENAI_API_KEY='your-key'
|
||||||
|
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
|
||||||
|
|
||||||
|
searcher.compute_embeddings()
|
||||||
|
|
||||||
|
# Search with example papers that you are interested in
|
||||||
|
examples = [
|
||||||
|
{
|
||||||
|
"title": "Your paper title",
|
||||||
|
"abstract": "Your paper abstract..."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
results = searcher.search(examples=examples, top_k=100)
|
||||||
|
|
||||||
|
# Or search with text query
|
||||||
|
results = searcher.search(query="interesting topics", top_k=100)
|
||||||
|
|
||||||
|
searcher.display(results, n=10)
|
||||||
|
searcher.save(results, 'results.json')
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
1. Paper titles and abstracts are converted to embeddings
|
||||||
|
2. Embeddings are cached automatically
|
||||||
|
3. Your query is embedded using the same model
|
||||||
|
4. Cosine similarity finds the most similar papers
|
||||||
|
5. Results are ranked by similarity score
|
||||||
|
|
||||||
|
## Cache
|
||||||
|
|
||||||
|
Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete to recompute.
|
||||||
|
|
||||||
|
## Example Output
|
||||||
|
|
||||||
|
```
|
||||||
|
================================================================================
|
||||||
|
Top 100 Results (showing 10)
|
||||||
|
================================================================================
|
||||||
|
|
||||||
|
1. [0.8456] Paper a
|
||||||
|
#12345 | foundation or frontier models, including LLMs
|
||||||
|
https://openreview.net/forum?id=xxx
|
||||||
|
|
||||||
|
2. [0.8234] Paper b
|
||||||
|
#12346 | applications to robotics, autonomy, planning
|
||||||
|
https://openreview.net/forum?id=yyy
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tips
|
||||||
|
|
||||||
|
- Use 1-5 example papers for best results, or a paragraph of description of your interested topic
|
||||||
|
- Local model is good enough for most cases
|
||||||
|
- OpenAI model for critical search (~$1 for 18k queries)
|
||||||
|
|
||||||
|
If it's useful, please consider giving a star~
|
||||||
66
mcp/SearchPaperByEmbedding/crawl.py
Normal file
66
mcp/SearchPaperByEmbedding/crawl.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
def fetch_submissions(venue_id, offset=0, limit=1000):
|
||||||
|
url = "https://api2.openreview.net/notes"
|
||||||
|
params = {
|
||||||
|
"content.venueid": venue_id,
|
||||||
|
"details": "replyCount,invitation",
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"sort": "number:desc"
|
||||||
|
}
|
||||||
|
headers = {"User-Agent": "Mozilla/5.0"}
|
||||||
|
response = requests.get(url, params=params, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def crawl_papers(venue_id, output_file):
|
||||||
|
all_papers = []
|
||||||
|
offset = 0
|
||||||
|
limit = 1000
|
||||||
|
|
||||||
|
print(f"Fetching papers from {venue_id}...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = fetch_submissions(venue_id, offset, limit)
|
||||||
|
notes = data.get("notes", [])
|
||||||
|
|
||||||
|
if not notes:
|
||||||
|
break
|
||||||
|
|
||||||
|
for note in notes:
|
||||||
|
paper = {
|
||||||
|
"id": note.get("id"),
|
||||||
|
"number": note.get("number"),
|
||||||
|
"title": note.get("content", {}).get("title", {}).get("value", ""),
|
||||||
|
"authors": note.get("content", {}).get("authors", {}).get("value", []),
|
||||||
|
"abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
|
||||||
|
"keywords": note.get("content", {}).get("keywords", {}).get("value", []),
|
||||||
|
"primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
|
||||||
|
"forum_url": f"https://openreview.net/forum?id={note.get('id')}"
|
||||||
|
}
|
||||||
|
all_papers.append(paper)
|
||||||
|
|
||||||
|
print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")
|
||||||
|
|
||||||
|
if len(notes) < limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
offset += limit
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(all_papers, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f"\nTotal: {len(all_papers)} papers")
|
||||||
|
print(f"Saved to {output_file}")
|
||||||
|
return all_papers
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
crawl_papers(
|
||||||
|
venue_id="ICLR.cc/2026/Conference/Submission",
|
||||||
|
output_file="iclr2026_papers.json"
|
||||||
|
)
|
||||||
|
|
||||||
22
mcp/SearchPaperByEmbedding/demo.py
Normal file
22
mcp/SearchPaperByEmbedding/demo.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from search import PaperSearcher
|
||||||
|
|
||||||
|
# Use local model (free)
|
||||||
|
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')
|
||||||
|
|
||||||
|
# Or use OpenAI (better quality)
|
||||||
|
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')
|
||||||
|
|
||||||
|
searcher.compute_embeddings()
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
{
|
||||||
|
"title": "Solid-State battery",
|
||||||
|
"abstract": "Solid-State battery"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
results = searcher.search(examples=examples, top_k=100)
|
||||||
|
|
||||||
|
searcher.display(results, n=10)
|
||||||
|
searcher.save(results, 'results.json')
|
||||||
|
|
||||||
6
mcp/SearchPaperByEmbedding/requirements.txt
Normal file
6
mcp/SearchPaperByEmbedding/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
requests
|
||||||
|
numpy
|
||||||
|
scikit-learn
|
||||||
|
sentence-transformers
|
||||||
|
openai
|
||||||
|
|
||||||
156
mcp/SearchPaperByEmbedding/search.py
Normal file
156
mcp/SearchPaperByEmbedding/search.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
from pathlib import Path
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
|
class PaperSearcher:
|
||||||
|
def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
|
||||||
|
with open(papers_file, 'r', encoding='utf-8') as f:
|
||||||
|
self.papers = json.load(f)
|
||||||
|
|
||||||
|
self.model_type = model_type
|
||||||
|
self.cache_file = self._get_cache_file(papers_file, model_type)
|
||||||
|
self.embeddings = None
|
||||||
|
|
||||||
|
if model_type == "openai":
|
||||||
|
from openai import OpenAI
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=api_key or os.getenv('OPENAI_API_KEY'),
|
||||||
|
base_url=base_url
|
||||||
|
)
|
||||||
|
self.model_name = "text-embedding-3-large"
|
||||||
|
else:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||||
|
self.model_name = "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
self._load_cache()
|
||||||
|
|
||||||
|
def _get_cache_file(self, papers_file, model_type):
|
||||||
|
base_name = Path(papers_file).stem
|
||||||
|
file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
|
||||||
|
cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
|
||||||
|
return str(Path(papers_file).parent / cache_name)
|
||||||
|
|
||||||
|
def _load_cache(self):
|
||||||
|
if os.path.exists(self.cache_file):
|
||||||
|
try:
|
||||||
|
self.embeddings = np.load(self.cache_file)
|
||||||
|
if len(self.embeddings) == len(self.papers):
|
||||||
|
print(f"Loaded cache: {self.embeddings.shape}")
|
||||||
|
return True
|
||||||
|
self.embeddings = None
|
||||||
|
except:
|
||||||
|
self.embeddings = None
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _save_cache(self):
|
||||||
|
np.save(self.cache_file, self.embeddings)
|
||||||
|
print(f"Saved cache: {self.cache_file}")
|
||||||
|
|
||||||
|
def _create_text(self, paper):
|
||||||
|
parts = []
|
||||||
|
if paper.get('title'):
|
||||||
|
parts.append(f"Title: {paper['title']}")
|
||||||
|
if paper.get('abstract'):
|
||||||
|
parts.append(f"Abstract: {paper['abstract']}")
|
||||||
|
if paper.get('keywords'):
|
||||||
|
kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
|
||||||
|
parts.append(f"Keywords: {kw}")
|
||||||
|
return ' '.join(parts)
|
||||||
|
|
||||||
|
def _embed_openai(self, texts):
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
batch_size = 100
|
||||||
|
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i:i + batch_size]
|
||||||
|
response = self.client.embeddings.create(input=batch, model=self.model_name)
|
||||||
|
embeddings.extend([item.embedding for item in response.data])
|
||||||
|
|
||||||
|
return np.array(embeddings)
|
||||||
|
|
||||||
|
def _embed_local(self, texts):
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
return self.model.encode(texts, show_progress_bar=len(texts) > 100)
|
||||||
|
|
||||||
|
def compute_embeddings(self, force=False):
|
||||||
|
if self.embeddings is not None and not force:
|
||||||
|
print("Using cached embeddings")
|
||||||
|
return self.embeddings
|
||||||
|
|
||||||
|
print(f"Computing embeddings ({self.model_name})...")
|
||||||
|
texts = [self._create_text(p) for p in self.papers]
|
||||||
|
|
||||||
|
if self.model_type == "openai":
|
||||||
|
self.embeddings = self._embed_openai(texts)
|
||||||
|
else:
|
||||||
|
self.embeddings = self._embed_local(texts)
|
||||||
|
|
||||||
|
print(f"Computed: {self.embeddings.shape}")
|
||||||
|
self._save_cache()
|
||||||
|
return self.embeddings
|
||||||
|
|
||||||
|
def search(self, examples=None, query=None, top_k=100):
|
||||||
|
if self.embeddings is None:
|
||||||
|
self.compute_embeddings()
|
||||||
|
|
||||||
|
if examples:
|
||||||
|
texts = []
|
||||||
|
for ex in examples:
|
||||||
|
text = f"Title: {ex['title']}"
|
||||||
|
if ex.get('abstract'):
|
||||||
|
text += f" Abstract: {ex['abstract']}"
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
if self.model_type == "openai":
|
||||||
|
embs = self._embed_openai(texts)
|
||||||
|
else:
|
||||||
|
embs = self._embed_local(texts)
|
||||||
|
|
||||||
|
query_emb = np.mean(embs, axis=0).reshape(1, -1)
|
||||||
|
|
||||||
|
elif query:
|
||||||
|
if self.model_type == "openai":
|
||||||
|
query_emb = self._embed_openai(query).reshape(1, -1)
|
||||||
|
else:
|
||||||
|
query_emb = self._embed_local(query).reshape(1, -1)
|
||||||
|
else:
|
||||||
|
raise ValueError("Provide either examples or query")
|
||||||
|
|
||||||
|
similarities = cosine_similarity(query_emb, self.embeddings)[0]
|
||||||
|
top_indices = np.argsort(similarities)[::-1][:top_k]
|
||||||
|
|
||||||
|
return [{
|
||||||
|
'paper': self.papers[idx],
|
||||||
|
'similarity': float(similarities[idx])
|
||||||
|
} for idx in top_indices]
|
||||||
|
|
||||||
|
def display(self, results, n=10):
|
||||||
|
print(f"\n{'='*80}")
|
||||||
|
print(f"Top {len(results)} Results (showing {min(n, len(results))})")
|
||||||
|
print(f"{'='*80}\n")
|
||||||
|
|
||||||
|
for i, result in enumerate(results[:n], 1):
|
||||||
|
paper = result['paper']
|
||||||
|
sim = result['similarity']
|
||||||
|
|
||||||
|
print(f"{i}. [{sim:.4f}] {paper['title']}")
|
||||||
|
print(f" #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
|
||||||
|
print(f" {paper['forum_url']}\n")
|
||||||
|
|
||||||
|
def save(self, results, output_file):
|
||||||
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump({
|
||||||
|
'model': self.model_name,
|
||||||
|
'total': len(results),
|
||||||
|
'results': results
|
||||||
|
}, f, ensure_ascii=False, indent=2)
|
||||||
|
print(f"Saved to {output_file}")
|
||||||
|
|
||||||
50
mcp/lifespan.py
Normal file
50
mcp/lifespan.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import asyncssh
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
# --- 1. SSH 连接参数 ---
|
||||||
|
REMOTE_HOST = '202.121.182.208'
|
||||||
|
REMOTE_USER = 'koko125'
|
||||||
|
# 确保路径正确,Windows 路径建议使用正斜杠 '/'
|
||||||
|
PRIVATE_KEY_PATH = 'D:/tool/tool/id_rsa.txt'
|
||||||
|
INITIAL_WORKING_DIRECTORY = f'/cluster/home/{REMOTE_USER}/sandbox'
|
||||||
|
|
||||||
|
|
||||||
|
# --- 2. 定义共享上下文 ---
|
||||||
|
# 这个数据类将持有所有模块需要共享的资源
|
||||||
|
@dataclass
|
||||||
|
class SharedAppContext:
|
||||||
|
"""在整个服务器生命周期内持有的共享资源"""
|
||||||
|
ssh_connection: asyncssh.SSHClientConnection
|
||||||
|
sandbox_path: str
|
||||||
|
|
||||||
|
|
||||||
|
# --- 3. 创建生命周期管理器 ---
|
||||||
|
# 这是整个架构的核心,负责在服务器启动时连接SSH,在关闭时断开
|
||||||
|
@asynccontextmanager
|
||||||
|
async def shared_lifespan(app: Starlette) -> AsyncIterator[SharedAppContext]:
|
||||||
|
"""
|
||||||
|
为整个应用管理共享资源的生命周期。
|
||||||
|
"""
|
||||||
|
print("主应用启动,正在建立共享 SSH 连接...")
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
# 建立 SSH 连接
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
REMOTE_HOST,
|
||||||
|
username=REMOTE_USER,
|
||||||
|
client_keys=[PRIVATE_KEY_PATH]
|
||||||
|
)
|
||||||
|
print(f"SSH 连接到 {REMOTE_HOST} 成功!")
|
||||||
|
|
||||||
|
# 使用 yield 将创建好的共享资源上下文传递给 Starlette 应用
|
||||||
|
# 服务器会在此处暂停并开始处理请求
|
||||||
|
yield {"shared_context": SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 当服务器关闭时,yield 之后的代码会被执行
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
print("共享 SSH 连接已关闭。")
|
||||||
55
mcp/main.py
Normal file
55
mcp/main.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# main.py
|
||||||
|
# 将 Materials CIF MCP 与 System Tools MCP 一起挂载到 Starlette。
|
||||||
|
# 关键点:
|
||||||
|
# - 在 lifespan 中启动每个 MCP 的 session_manager.run()(参考 SDK README 的 Starlette 挂载示例与 streamable_http_app 用法 [1])
|
||||||
|
# - 通过 Mount 指定各自的子路径(如 /system 与 /materials)
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.routing import Mount
|
||||||
|
|
||||||
|
from system_tools import create_system_mcp
|
||||||
|
from materialproject_mcp import create_materials_mcp
|
||||||
|
from softBV_remake import create_softbv_mcp
|
||||||
|
from paper_search_mcp import create_paper_search_mcp
|
||||||
|
from topological_analysis_models import create_topological_analysis_mcp
|
||||||
|
|
||||||
|
# 创建 MCP 实例
|
||||||
|
system_mcp = create_system_mcp()
|
||||||
|
materials_mcp = create_materials_mcp()
|
||||||
|
softbv_mcp = create_softbv_mcp()
|
||||||
|
paper_search_mcp = create_paper_search_mcp()
|
||||||
|
topological_analysis_mcp = create_topological_analysis_mcp()
|
||||||
|
# 在 Starlette 的 lifespan 中启动 MCP 的 session manager
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def lifespan(app: Starlette):
|
||||||
|
async with contextlib.AsyncExitStack() as stack:
|
||||||
|
await stack.enter_async_context(system_mcp.session_manager.run())
|
||||||
|
await stack.enter_async_context(materials_mcp.session_manager.run())
|
||||||
|
await stack.enter_async_context(softbv_mcp.session_manager.run())
|
||||||
|
await stack.enter_async_context(paper_search_mcp.session_manager.run())
|
||||||
|
await stack.enter_async_context(topological_analysis_mcp.session_manager.run())
|
||||||
|
yield # 服务器运行期间
|
||||||
|
# 退出时自动清理
|
||||||
|
|
||||||
|
# 挂载两个 MCP 的 Streamable HTTP App
|
||||||
|
app = Starlette(
|
||||||
|
lifespan=lifespan,
|
||||||
|
routes=[
|
||||||
|
Mount("/system", app=system_mcp.streamable_http_app()),
|
||||||
|
Mount("/materials", app=materials_mcp.streamable_http_app()),
|
||||||
|
Mount("/softBV", app=softbv_mcp.streamable_http_app()),
|
||||||
|
Mount("/papersearch",app=paper_search_mcp.streamable_http_app()),
|
||||||
|
Mount("/topologicalAnalysis",app=topological_analysis_mcp.streamable_http_app()),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 启动命令(终端执行):
|
||||||
|
# uvicorn main:app --host 0.0.0.0 --port 8000
|
||||||
|
# 访问:
|
||||||
|
# http://localhost:8000/system
|
||||||
|
# http://localhost:8000/materials
|
||||||
|
# http://localhost:8000/softBV
|
||||||
|
# http://localhost:8000/papersearch
|
||||||
|
# http://localhost:8000/topologicalAnalysis
|
||||||
|
# 如果需要浏览器客户端访问(CORS 暴露 Mcp-Session-Id),请参考 README 中的 CORS 配置示例 [1]
|
||||||
513
mcp/materialproject_mcp.py
Normal file
513
mcp/materialproject_mcp.py
Normal file
@@ -0,0 +1,513 @@
|
|||||||
|
# materials_mp_cif_mcp.py
|
||||||
|
#
|
||||||
|
# 说明:
|
||||||
|
# - 这是一个基于 FastMCP 的“Materials Project CIF”服务器,提供:
|
||||||
|
# 1) search_cifs:使用 Materials Project 的 mp-api 进行检索(两步:summary 过滤 + materials 获取结构)
|
||||||
|
# 2) get_cif:按材料 ID 获取结构并导出 CIF 文本(pymatgen)
|
||||||
|
# 3) filter_cifs:在本地对候选进行硬筛选
|
||||||
|
# 4) download_cif_bulk:批量下载 CIF,带并发与进度
|
||||||
|
# 5) 资源:cif://mp/{id} 直接读取 CIF 文本
|
||||||
|
#
|
||||||
|
# - 依赖:pip install "mcp[cli]" mp_api pymatgen
|
||||||
|
# - API Key:
|
||||||
|
# 默认从环境变量 MP_API_KEY 读取;若未设置则回退为你提供的 Key(仅用于演示,不建议在生产中硬编码)[3]
|
||||||
|
# - MCP/Context/Lifespan 的使用方法参考 SDK 文档(工具、结构化输出、进度等)[1]
|
||||||
|
#
|
||||||
|
# 引用:
|
||||||
|
# - MPRester/MP_API_KEY 的用法与推荐的会话管理方式见 Materials Project 文档 [3]
|
||||||
|
# - summary.search 与 materials.search 的查询与 fields 限制示例见文档 [4]
|
||||||
|
# - FastMCP 的工具、资源、Context、结构化输出见 README 示例 [1]
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Literal
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import re
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from mp_api.client import MPRester # [3]
|
||||||
|
from pymatgen.core import Structure # 用于转 CIF
|
||||||
|
from mcp.server.fastmcp import FastMCP, Context # [1]
|
||||||
|
from mcp.server.session import ServerSession # [1]
|
||||||
|
|
||||||
|
# ========= 配置 =========
|
||||||
|
|
||||||
|
# 生产中请通过环境变量 MP_API_KEY 配置;此处为演示而提供回退到你给的 API Key(不建议硬编码)[3]
|
||||||
|
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW")
|
||||||
|
|
||||||
|
# ========= 数据模型 =========
|
||||||
|
|
||||||
|
class LatticeParameters(BaseModel):
|
||||||
|
a: float | None = Field(default=None, description="a 轴长度 (Å)")
|
||||||
|
b: float | None = Field(default=None, description="b 轴长度 (Å)")
|
||||||
|
c: float | None = Field(default=None, description="c 轴长度 (Å)")
|
||||||
|
alpha: float | None = Field(default=None, description="α 角 (度)")
|
||||||
|
beta: float | None = Field(default=None, description="β 角 (度)")
|
||||||
|
gamma: float | None = Field(default=None, description="γ 角 (度)")
|
||||||
|
|
||||||
|
|
||||||
|
class CIFCandidate(BaseModel):
|
||||||
|
id: str = Field(description="Materials Project ID,如 mp-149")
|
||||||
|
source: Literal["mp"] = Field(default="mp", description="来源库")
|
||||||
|
chemical_formula: str | None = Field(default=None, description="化学式(pretty)")
|
||||||
|
space_group: str | None = Field(default=None, description="空间群符号")
|
||||||
|
cell_parameters: LatticeParameters | None = Field(default=None, description="晶格参数")
|
||||||
|
bandgap: float | None = Field(default=None, description="带隙 (eV)")
|
||||||
|
stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
|
||||||
|
magnetic_ordering: str | None = Field(default=None, description="磁性信息")
|
||||||
|
has_cif: bool = Field(default=True, description="是否存在可导出的 CIF(只要有结构即可)")
|
||||||
|
extras: dict[str, Any] | None = Field(default=None, description="原始文档片段(健壮解析)")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchQuery(BaseModel):
|
||||||
|
# 基础化学约束
|
||||||
|
include_elements: list[str] | None = Field(default=None, description="包含元素集合(如 ['Si','O'])") # [4]
|
||||||
|
exclude_elements: list[str] | None = Field(default=None, description="排除元素集合(如 ['Li'])")
|
||||||
|
formula: str | None = Field(default=None, description="化学式片段(本地二次过滤)")
|
||||||
|
|
||||||
|
# 结构与性能约束(summary 可过滤带隙与 E_above_hull)[4]
|
||||||
|
bandgap_min: float | None = None
|
||||||
|
bandgap_max: float | None = None
|
||||||
|
e_above_hull_max: float | None = None
|
||||||
|
|
||||||
|
# 空间群、晶格范围(在本地二次过滤)
|
||||||
|
space_group: str | None = None
|
||||||
|
cell_a_min: float | None = None
|
||||||
|
cell_a_max: float | None = None
|
||||||
|
cell_b_min: float | None = None
|
||||||
|
cell_b_max: float | None = None
|
||||||
|
cell_c_min: float | None = None
|
||||||
|
cell_c_max: float | None = None
|
||||||
|
|
||||||
|
magnetic_ordering: str | None = None # 本地二次过滤
|
||||||
|
|
||||||
|
page: int = Field(default=1, ge=1, description="页码(仅本地分页;MP 端点请自行扩展)")
|
||||||
|
per_page: int = Field(default=50, ge=1, le=200, description="每页数量(本地分页)")
|
||||||
|
sort_by: str | None = Field(default=None, description="排序字段(bandgap/e_above_hull等,本地排序)")
|
||||||
|
sort_order: Literal["asc", "desc"] | None = Field(default=None, description="排序方向")
|
||||||
|
|
||||||
|
|
||||||
|
class CIFText(BaseModel):
|
||||||
|
id: str
|
||||||
|
source: Literal["mp"]
|
||||||
|
text: str = Field(description="CIF 文件文本内容")
|
||||||
|
metadata: dict[str, Any] | None = Field(default=None, description="附加元数据(如结构来源)")
|
||||||
|
|
||||||
|
|
||||||
|
class FilterCriteria(BaseModel):
|
||||||
|
include_elements: list[str] | None = None
|
||||||
|
exclude_elements: list[str] | None = None
|
||||||
|
formula: str | None = None
|
||||||
|
space_group: str | None = None
|
||||||
|
cell_a_min: float | None = None
|
||||||
|
cell_a_max: float | None = None
|
||||||
|
cell_b_min: float | None = None
|
||||||
|
cell_b_max: float | None = None
|
||||||
|
cell_c_min: float | None = None
|
||||||
|
cell_c_max: float | None = None
|
||||||
|
bandgap_min: float | None = None
|
||||||
|
bandgap_max: float | None = None
|
||||||
|
e_above_hull_max: float | None = None
|
||||||
|
magnetic_ordering: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BulkDownloadItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
source: Literal["mp"] = "mp"
|
||||||
|
success: bool
|
||||||
|
text: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class BulkDownloadResult(BaseModel):
|
||||||
|
items: list[BulkDownloadItem]
|
||||||
|
total: int
|
||||||
|
succeeded: int
|
||||||
|
failed: int
|
||||||
|
|
||||||
|
|
||||||
|
# ========= Lifespan =========
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AppContext:
|
||||||
|
mpr: MPRester
|
||||||
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
|
||||||
|
# 推荐使用 MP_API_KEY 环境变量;此处回退到提供的 Key 字符串 [3]
|
||||||
|
api_key = DEFAULT_MP_API_KEY
|
||||||
|
mpr = MPRester(api_key) # 会话管理在 mp-api 内部 [3]
|
||||||
|
try:
|
||||||
|
yield AppContext(mpr=mpr, api_key=api_key)
|
||||||
|
finally:
|
||||||
|
# mp-api 未强制要求手动关闭;若未来提供关闭方法,可在此清理 [3]
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 服务器 =========
|
||||||
|
|
||||||
|
mcp = FastMCP(
|
||||||
|
name="Materials Project CIF",
|
||||||
|
instructions="通过 Materials Project API 检索并导出 CIF 文件。",
|
||||||
|
lifespan=lifespan,
|
||||||
|
streamable_http_path="/", # 便于在 Starlette 下挂载到子路径 [1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 工具实现 =========
|
||||||
|
|
||||||
|
def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
|
||||||
|
if text is None:
|
||||||
|
return not required and not excluded
|
||||||
|
s = text.upper()
|
||||||
|
if required:
|
||||||
|
for el in required:
|
||||||
|
if el.upper() not in s:
|
||||||
|
return False
|
||||||
|
if excluded:
|
||||||
|
for el in excluded:
|
||||||
|
if el.upper() in s:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
|
||||||
|
if struct is None:
|
||||||
|
return None
|
||||||
|
lat = struct.lattice
|
||||||
|
return LatticeParameters(
|
||||||
|
a=float(lat.a),
|
||||||
|
b=float(lat.b),
|
||||||
|
c=float(lat.c),
|
||||||
|
alpha=float(lat.alpha),
|
||||||
|
beta=float(lat.beta),
|
||||||
|
gamma=float(lat.gamma),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sg_symbol_from_doc(doc: Any) -> str | None:
|
||||||
|
"""
|
||||||
|
尝试从 materials.search 返回的文档中提取空间群符号。
|
||||||
|
MP 文档中 symmetry 通常包含 spacegroup 信息(不同版本字段路径可能变化)[4]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sym = getattr(doc, "symmetry", None)
|
||||||
|
if sym and getattr(sym, "spacegroup", None):
|
||||||
|
sg = sym.spacegroup
|
||||||
|
# 常见字段:symbol 或 international
|
||||||
|
return getattr(sg, "symbol", None) or getattr(sg, "international", None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _to_thread(fn, *args, **kwargs):
|
||||||
|
return await asyncio.to_thread(fn, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_elements_from_formula(formula: str) -> list[str]:
|
||||||
|
"""从化学式中解析元素符号(简易版)。"""
|
||||||
|
if not formula:
|
||||||
|
return []
|
||||||
|
elems = re.findall(r"[A-Z][a-z]?", formula)
|
||||||
|
seen, ordered = set(), []
|
||||||
|
for e in elems:
|
||||||
|
if e not in seen:
|
||||||
|
seen.add(e)
|
||||||
|
ordered.append(e)
|
||||||
|
return ordered
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def search_cifs(
|
||||||
|
query: SearchQuery,
|
||||||
|
ctx: Context[ServerSession, AppContext],
|
||||||
|
) -> list[CIFCandidate]:
|
||||||
|
"""
|
||||||
|
通过 summary 做初筛(限制返回数量),再用 materials 拉取当前页的结构与对称性。
|
||||||
|
修复点:
|
||||||
|
- 仅请求当前页(chunk_size/per_page + num_chunks=1),避免全库下载超时
|
||||||
|
- 绝不在返回值中包含 mp-api 文档对象或 Structure,extras 只放可序列化轻量字段
|
||||||
|
"""
|
||||||
|
app = ctx.request_context.lifespan_context
|
||||||
|
mpr = app.mpr
|
||||||
|
|
||||||
|
# 将 formula 自动转为元素集合(若用户未显式提供)
|
||||||
|
include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")
|
||||||
|
|
||||||
|
# 至少要有一个约束,避免全库扫描
|
||||||
|
if not include_elements and not query.exclude_elements and query.bandgap_min is None \
|
||||||
|
and query.bandgap_max is None and query.e_above_hull_max is None:
|
||||||
|
raise ValueError("请至少提供一个过滤条件,例如 include_elements、bandgap 或 energy_above_hull 上限,以避免超时。")
|
||||||
|
|
||||||
|
# 1) summary 初筛:只取“当前页”的文档(轻字段)
|
||||||
|
summary_kwargs: dict[str, Any] = {}
|
||||||
|
if include_elements:
|
||||||
|
summary_kwargs["elements"] = include_elements # MP 支持按元素集合检索
|
||||||
|
if query.exclude_elements:
|
||||||
|
summary_kwargs["exclude_elements"] = query.exclude_elements
|
||||||
|
if query.bandgap_min is not None or query.bandgap_max is not None:
|
||||||
|
summary_kwargs["band_gap"] = (
|
||||||
|
float(query.bandgap_min) if query.bandgap_min is not None else None,
|
||||||
|
float(query.bandgap_max) if query.bandgap_max is not None else None,
|
||||||
|
)
|
||||||
|
if query.e_above_hull_max is not None:
|
||||||
|
summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))
|
||||||
|
|
||||||
|
fields_summary = [
|
||||||
|
"material_id",
|
||||||
|
"formula_pretty",
|
||||||
|
"band_gap",
|
||||||
|
"energy_above_hull",
|
||||||
|
"ordering",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 关键:限制 summary 返回规模
|
||||||
|
try:
|
||||||
|
summary_docs = await asyncio.to_thread(
|
||||||
|
mpr.materials.summary.search,
|
||||||
|
fields=fields_summary,
|
||||||
|
chunk_size=query.per_page, # 每块(页)大小
|
||||||
|
num_chunks=1, # 只取一块 => 只取当前页数量
|
||||||
|
**summary_kwargs,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
# 如果你的 mp-api 版本不支持上述参数,提示升级,避免一次性抓回十几万条导致超时
|
||||||
|
raise RuntimeError(
|
||||||
|
"当前 mp-api 版本不支持 chunk_size/num_chunks,请升级 mp-api(pip install -U mp_api)。"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not summary_docs:
|
||||||
|
await ctx.debug("summary 检索结果为空")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 本地排序(如需)
|
||||||
|
results_all: list[CIFCandidate] = []
|
||||||
|
for sdoc in summary_docs:
|
||||||
|
e_hull = getattr(sdoc, "energy_above_hull", None)
|
||||||
|
bandgap = getattr(sdoc, "band_gap", None)
|
||||||
|
ordering = getattr(sdoc, "ordering", None)
|
||||||
|
|
||||||
|
# 暂不放 structure/doc 到 extras,避免序列化错误
|
||||||
|
results_all.append(
|
||||||
|
CIFCandidate(
|
||||||
|
id=sdoc.material_id,
|
||||||
|
source="mp",
|
||||||
|
chemical_formula=getattr(sdoc, "formula_pretty", None),
|
||||||
|
space_group=None,
|
||||||
|
cell_parameters=None,
|
||||||
|
bandgap=float(bandgap) if bandgap is not None else None,
|
||||||
|
stability_energy=float(e_hull) if e_hull is not None else None,
|
||||||
|
magnetic_ordering=ordering,
|
||||||
|
has_cif=True,
|
||||||
|
extras={
|
||||||
|
"summary": {
|
||||||
|
"material_id": sdoc.material_id,
|
||||||
|
"formula_pretty": getattr(sdoc, "formula_pretty", None),
|
||||||
|
"band_gap": float(bandgap) if bandgap is not None else None,
|
||||||
|
"energy_above_hull": float(e_hull) if e_hull is not None else None,
|
||||||
|
"ordering": ordering,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 本地分页:summary 已限制返回条数,这里只做安全截断和排序
|
||||||
|
if query.sort_by:
|
||||||
|
reverse = query.sort_order == "desc"
|
||||||
|
key_map = {
|
||||||
|
"bandgap": lambda c: (c.bandgap if c.bandgap is not None else float("inf")),
|
||||||
|
"e_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
|
||||||
|
"energy_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
|
||||||
|
}
|
||||||
|
if query.sort_by in key_map:
|
||||||
|
results_all.sort(key=key_map[query.sort_by], reverse=reverse)
|
||||||
|
|
||||||
|
page_slice = results_all[: query.per_page]
|
||||||
|
page_ids = [c.id for c in page_slice]
|
||||||
|
|
||||||
|
# 2) 只为当前页材料获取结构与对称性
|
||||||
|
if not page_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
fields_materials = ["material_id", "structure", "symmetry"]
|
||||||
|
materials_docs = await asyncio.to_thread(
|
||||||
|
mpr.materials.search,
|
||||||
|
material_ids=page_ids,
|
||||||
|
fields=fields_materials,
|
||||||
|
)
|
||||||
|
mat_by_id = {doc.material_id: doc for doc in materials_docs}
|
||||||
|
|
||||||
|
# 合并结构/空间群,仍然不返回不可序列化对象
|
||||||
|
merged: list[CIFCandidate] = []
|
||||||
|
for c in page_slice:
|
||||||
|
mdoc = mat_by_id.get(c.id)
|
||||||
|
struct = getattr(mdoc, "structure", None) if mdoc else None
|
||||||
|
|
||||||
|
# 空间群符号
|
||||||
|
sg = None
|
||||||
|
try:
|
||||||
|
sym = getattr(mdoc, "symmetry", None)
|
||||||
|
if sym and getattr(sym, "spacegroup", None):
|
||||||
|
sgobj = sym.spacegroup
|
||||||
|
sg = getattr(sgobj, "symbol", None) or getattr(sgobj, "international", None)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
c.space_group = sg
|
||||||
|
c.cell_parameters = _lattice_from_structure(struct)
|
||||||
|
c.has_cif = struct is not None
|
||||||
|
# 保持 extras 为轻量 JSON
|
||||||
|
c.extras = {
|
||||||
|
**(c.extras or {}),
|
||||||
|
"materials": {"has_structure": struct is not None},
|
||||||
|
}
|
||||||
|
merged.append(c)
|
||||||
|
|
||||||
|
await ctx.debug(f"返回 {len(merged)} 条(已限制每页 {query.per_page} 条)")
|
||||||
|
return merged
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def get_cif(
|
||||||
|
id: str,
|
||||||
|
ctx: Context[ServerSession, AppContext],
|
||||||
|
) -> CIFText:
|
||||||
|
"""
|
||||||
|
使用 materials.search 获取结构并导出 CIF 文本。
|
||||||
|
- fields=["structure"] 以提升速度 [4]
|
||||||
|
- 结构到 CIF 通过 pymatgen 的 Structure.to(fmt="cif")
|
||||||
|
"""
|
||||||
|
app = ctx.request_context.lifespan_context
|
||||||
|
mpr = app.mpr
|
||||||
|
|
||||||
|
await ctx.info(f"获取结构并导出 CIF:{id}")
|
||||||
|
docs = await _to_thread(mpr.materials.search, material_ids=[id], fields=["material_id", "structure"]) # [4]
|
||||||
|
if not docs:
|
||||||
|
raise ValueError(f"未找到材料:{id}")
|
||||||
|
|
||||||
|
doc = docs[0]
|
||||||
|
struct: Structure | None = getattr(doc, "structure", None)
|
||||||
|
if struct is None:
|
||||||
|
raise ValueError(f"材料 {id} 无结构数据,无法导出 CIF")
|
||||||
|
|
||||||
|
cif_text = struct.to(fmt="cif") # 直接输出 CIF 文本
|
||||||
|
return CIFText(id=id, source="mp", text=cif_text, metadata={"from": "materials.search"})
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def filter_cifs(
|
||||||
|
candidates: list[CIFCandidate],
|
||||||
|
criteria: FilterCriteria,
|
||||||
|
ctx: Context[ServerSession, AppContext],
|
||||||
|
) -> list[CIFCandidate]:
|
||||||
|
def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
|
||||||
|
if v is None:
|
||||||
|
return vmin is None and vmax is None
|
||||||
|
if vmin is not None and v < vmin:
|
||||||
|
return False
|
||||||
|
if vmax is not None and v > vmax:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
filtered: list[CIFCandidate] = []
|
||||||
|
for c in candidates:
|
||||||
|
# 元素集合近似匹配(基于化学式字符串)
|
||||||
|
if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
|
||||||
|
continue
|
||||||
|
# 化学式包含判断
|
||||||
|
if criteria.formula:
|
||||||
|
if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper():
|
||||||
|
continue
|
||||||
|
# 空间群
|
||||||
|
if criteria.space_group:
|
||||||
|
spg = (c.space_group or "").lower()
|
||||||
|
if criteria.space_group.lower() not in spg:
|
||||||
|
continue
|
||||||
|
# 晶格范围
|
||||||
|
cp = c.cell_parameters or LatticeParameters()
|
||||||
|
if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max):
|
||||||
|
continue
|
||||||
|
if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max):
|
||||||
|
continue
|
||||||
|
if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max):
|
||||||
|
continue
|
||||||
|
# 带隙与稳定性
|
||||||
|
if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max):
|
||||||
|
continue
|
||||||
|
if criteria.e_above_hull_max is not None:
|
||||||
|
if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max:
|
||||||
|
continue
|
||||||
|
# 磁性
|
||||||
|
if criteria.magnetic_ordering:
|
||||||
|
mo = "" # 当前未填充,保留兼容
|
||||||
|
if criteria.magnetic_ordering.upper() not in mo.upper():
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered.append(c)
|
||||||
|
|
||||||
|
await ctx.debug(f"筛选后:{len(filtered)} / 原始 {len(candidates)}")
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
async def download_cif_bulk(
|
||||||
|
ids: list[str],
|
||||||
|
concurrency: int = 4,
|
||||||
|
ctx: Context[ServerSession, AppContext] | None = None,
|
||||||
|
) -> BulkDownloadResult:
|
||||||
|
semaphore = asyncio.Semaphore(max(1, concurrency))
|
||||||
|
items: list[BulkDownloadItem] = []
|
||||||
|
|
||||||
|
async def worker(mid: str, idx: int) -> None:
|
||||||
|
nonlocal items
|
||||||
|
try:
|
||||||
|
async with semaphore:
|
||||||
|
# 复用当前 ctx
|
||||||
|
assert ctx is not None
|
||||||
|
cif = await get_cif(id=mid, ctx=ctx)
|
||||||
|
items.append(BulkDownloadItem(id=mid, success=True, text=cif.text))
|
||||||
|
except Exception as e:
|
||||||
|
items.append(BulkDownloadItem(id=mid, success=False, error=str(e)))
|
||||||
|
if ctx is not None:
|
||||||
|
progress = (idx + 1) / len(ids) if len(ids) > 0 else 1.0
|
||||||
|
await ctx.report_progress(progress=progress, total=1.0, message=f"{idx + 1}/{len(ids)}") # [1]
|
||||||
|
|
||||||
|
tasks = [worker(mid, i) for i, mid in enumerate(ids)]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
succeeded = sum(1 for it in items if it.success)
|
||||||
|
failed = len(items) - succeeded
|
||||||
|
return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=failed)
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 资源:cif://mp/{id} =========
|
||||||
|
|
||||||
|
@mcp.resource("cif://mp/{id}")
|
||||||
|
async def cif_resource(id: str) -> str:
|
||||||
|
"""
|
||||||
|
基于 URI 的 CIF 内容读取。资源函数无法直接拿到 Context,
|
||||||
|
这里简单地新的 MPRester 会话读取结构并转 CIF(仅资源路径用途)。
|
||||||
|
"""
|
||||||
|
# 资源中使用一次性会话(不复用 lifespan),适合偶发读取 [1]
|
||||||
|
mpr = MPRester(DEFAULT_MP_API_KEY) # [3]
|
||||||
|
try:
|
||||||
|
docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"]) # [4]
|
||||||
|
if not docs or getattr(docs[0], "structure", None) is None:
|
||||||
|
raise ValueError(f"材料 {id} 无结构数据")
|
||||||
|
struct: Structure = docs[0].structure
|
||||||
|
return struct.to(fmt="cif")
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_materials_mcp() -> FastMCP:
|
||||||
|
"""供 Starlette 挂载的工厂函数。"""
|
||||||
|
return mcp
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 独立运行(非挂载模式),使用 Streamable HTTP 传输 [1]
|
||||||
|
mcp.run(transport="streamable-http")
|
||||||
172
mcp/paper_search_mcp.py
Normal file
172
mcp/paper_search_mcp.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# paper_search_mcp.py (重构版)
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
# 假设您的 search.py 提供了 PaperSearcher 类
|
||||||
|
from SearchPaperByEmbedding.search import PaperSearcher
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import FastMCP, Context
|
||||||
|
from mcp.server.session import ServerSession
|
||||||
|
|
||||||
|
# ========= 1. 配置与常量 (修正版) =========
|
||||||
|
|
||||||
|
# 获取当前脚本文件所在的目录的绝对路径
|
||||||
|
_CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 基于项目根目录构造资源文件的绝对路径
|
||||||
|
DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR,"SearchPaperByEmbedding")
|
||||||
|
PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json")
|
||||||
|
|
||||||
|
# 打印路径用于调试,确保它是正确的
|
||||||
|
print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}")
|
||||||
|
|
||||||
|
MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local")
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 2. 数据模型 =========
|
||||||
|
|
||||||
|
class Paper(BaseModel):
|
||||||
|
"""单篇论文的详细信息"""
|
||||||
|
title: str
|
||||||
|
authors: list[str]
|
||||||
|
abstract: str
|
||||||
|
pdf_url: str | None = None
|
||||||
|
similarity_score: float = Field(description="与查询的相似度分数,越高越相关")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchQuery(BaseModel):
|
||||||
|
"""论文搜索的查询结构"""
|
||||||
|
title: str
|
||||||
|
abstract: str | None = Field(None, description="可选的摘要,以提供更丰富的上下文")
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 3. Lifespan (资源加载与管理) =========
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PaperSearchContext:
|
||||||
|
"""在服务器生命周期内共享的、已初始化的 PaperSearcher 实例"""
|
||||||
|
searcher: PaperSearcher
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]:
|
||||||
|
"""
|
||||||
|
FastMCP 生命周期:在服务器启动时初始化 PaperSearcher 并计算 Embeddings。
|
||||||
|
这确保了最耗时的部分只在启动时执行一次。
|
||||||
|
"""
|
||||||
|
print(f"正在使用 '{MODEL_TYPE}' 模型初始化论文搜索引擎...")
|
||||||
|
|
||||||
|
# 使用 to_thread 运行同步的初始化代码
|
||||||
|
def _initialize_searcher():
|
||||||
|
# 1. 初始化
|
||||||
|
searcher = PaperSearcher(
|
||||||
|
PAPERS_JSON_PATH,
|
||||||
|
model_type=MODEL_TYPE
|
||||||
|
)
|
||||||
|
# 2. 计算/加载 Embeddings
|
||||||
|
print("正在计算或加载论文 Embeddings...")
|
||||||
|
searcher.compute_embeddings()
|
||||||
|
print("Embeddings 加载完成。")
|
||||||
|
return searcher
|
||||||
|
|
||||||
|
searcher = await asyncio.to_thread(_initialize_searcher)
|
||||||
|
print("论文搜索引擎已准备就绪!")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用 yield 将创建好的共享资源上下文传递给 MCP
|
||||||
|
yield PaperSearchContext(searcher=searcher)
|
||||||
|
finally:
|
||||||
|
print("论文搜索引擎服务关闭。")
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 4. MCP 服务器实例 =========
|
||||||
|
|
||||||
|
mcp = FastMCP(
|
||||||
|
name="Paper Search",
|
||||||
|
instructions="通过 embedding 相似度检索 ICLR 2026 论文的工具。",
|
||||||
|
lifespan=paper_search_lifespan,
|
||||||
|
streamable_http_path="/",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 5. 工具实现 =========
|
||||||
|
@mcp.tool()
|
||||||
|
async def search_papers(
|
||||||
|
queries: list[SearchQuery],
|
||||||
|
top_k: int = 3,
|
||||||
|
ctx: Context[ServerSession, PaperSearchContext] | None = None,
|
||||||
|
) -> list[Paper]:
|
||||||
|
"""
|
||||||
|
根据给定的一个或多个查询(包含标题和可选的摘要),使用语义相似度搜索论文。
|
||||||
|
返回最相关的 top_k 个结果。
|
||||||
|
"""
|
||||||
|
if ctx is None:
|
||||||
|
raise ValueError("Context is required for this operation.")
|
||||||
|
if not queries:
|
||||||
|
return []
|
||||||
|
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
searcher = app_ctx.searcher
|
||||||
|
|
||||||
|
# 预处理查询,确保 title 和 abstract 至少有一个存在
|
||||||
|
examples_to_search = []
|
||||||
|
for q in queries:
|
||||||
|
query_dict = q.model_dump(exclude_none=True)
|
||||||
|
if "title" not in query_dict and "abstract" in query_dict:
|
||||||
|
query_dict["title"] = query_dict["abstract"]
|
||||||
|
elif "abstract" not in query_dict and "title" in query_dict:
|
||||||
|
query_dict["abstract"] = query_dict["title"]
|
||||||
|
examples_to_search.append(query_dict)
|
||||||
|
|
||||||
|
if not examples_to_search:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_display = queries[0].title or queries[0].abstract or "Empty Query"
|
||||||
|
await ctx.info(f"正在以查询 '{query_display}' 等内容搜索排名前 {top_k} 的论文...")
|
||||||
|
|
||||||
|
# 调用底层的 search 方法
|
||||||
|
results = await asyncio.to_thread(
|
||||||
|
searcher.search,
|
||||||
|
examples=examples_to_search,
|
||||||
|
top_k=top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- 关键修正:按照 search.py 返回的正确结构进行解析 ---
|
||||||
|
papers = []
|
||||||
|
for res in results:
|
||||||
|
paper_data = res.get('paper', {}) # 获取嵌套的论文信息字典
|
||||||
|
similarity = res.get('similarity', 0.0) # 获取相似度分数
|
||||||
|
|
||||||
|
papers.append(
|
||||||
|
Paper(
|
||||||
|
title=paper_data.get("title", "N/A"),
|
||||||
|
authors=paper_data.get("authors", []),
|
||||||
|
abstract=paper_data.get("abstract", ""),
|
||||||
|
# 注意:原始数据中 pdf url 的键是 'forum_url'
|
||||||
|
pdf_url=paper_data.get("forum_url"),
|
||||||
|
similarity_score=similarity # 使用正确的相似度分数
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await ctx.debug(f"找到 {len(papers)} 篇相关论文。")
|
||||||
|
return papers
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 6. 工厂函数 (用于 Starlette 集成) =========
|
||||||
|
|
||||||
|
def create_paper_search_mcp() -> FastMCP:
|
||||||
|
"""供 Starlette 挂载的工厂函数。"""
|
||||||
|
return mcp
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 如果直接运行此文件,则启动一个独立的 MCP 服务器
|
||||||
|
mcp.run(transport="streamable-http")
|
||||||
464
mcp/softBV.py
Normal file
464
mcp/softBV.py
Normal file
@@ -0,0 +1,464 @@
|
|||||||
|
# softbv_mcp.py
|
||||||
|
#
|
||||||
|
# 在远程服务器上激活 softBV 环境并执行计算(支持 --md 与 --gen-cube 两个专用工具)。
|
||||||
|
# - 生命周期:建立 SSH 连接并注入上下文
|
||||||
|
# - 激活环境:source /cluster/home/koko125/script/softBV.sh
|
||||||
|
# - 工作目录:/cluster/home/koko125/sandbox
|
||||||
|
# - 可执行文件:/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x
|
||||||
|
#
|
||||||
|
# 依赖:
|
||||||
|
# pip install "mcp[cli]" asyncssh pydantic
|
||||||
|
#
|
||||||
|
# 用法(Starlette 挂载示例见你现有 main.py,导入 create_softbv_mcp 即可):
|
||||||
|
# from softbv_mcp import create_softbv_mcp
|
||||||
|
# softbv_mcp = create_softbv_mcp()
|
||||||
|
# Mount("/softbv", app=softbv_mcp.streamable_http_app())
|
||||||
|
#
|
||||||
|
# 可通过环境变量覆盖连接信息:
|
||||||
|
# REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH, SOFTBV_PROFILE, SOFTBV_BIN, DEFAULT_WORKDIR
|
||||||
|
|
||||||
|
import os
|
||||||
|
import posixpath
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from socket import socket
|
||||||
|
from typing import Any
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import asyncssh
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import FastMCP, Context
|
||||||
|
from mcp.server.session import ServerSession
|
||||||
|
|
||||||
|
from lifespan import REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH
|
||||||
|
|
||||||
|
def shell_quote(arg: str) -> str:
|
||||||
|
"""安全地把字符串作为单个 shell 参数(POSIX)。"""
|
||||||
|
return "'" + str(arg).replace("'", "'\"'\"'") + "'"
|
||||||
|
|
||||||
|
# 如果你已经定义了这些常量与路径,可保持复用;也可按需改为你自己的配置源
|
||||||
|
|
||||||
|
|
||||||
|
# 固定的 softBV 环境信息(可用环境变量覆盖)
|
||||||
|
SOFTBV_PROFILE = os.getenv("SOFTBV_PROFILE", "/cluster/home/koko125/script/softBV.sh")
|
||||||
|
SOFTBV_BIN = os.getenv("SOFTBV_BIN", "/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x")
|
||||||
|
DEFAULT_WORKDIR = os.getenv("DEFAULT_WORKDIR", "/cluster/home/koko125/sandbox")
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SoftBVContext:
|
||||||
|
ssh_connection: asyncssh.SSHClientConnection
|
||||||
|
workdir: str
|
||||||
|
profile: str
|
||||||
|
bin_path: str
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def softbv_lifespan(_server) -> AsyncIterator[SoftBVContext]:
|
||||||
|
"""
|
||||||
|
FastMCP 生命周期:建立 SSH 连接并注入 softBV 上下文。
|
||||||
|
- 不再做额外的 DNS 解析或自定义异步步骤,避免 socket.SOCK_STREAM 的环境覆盖问题
|
||||||
|
- 仅负责连接与清理;工具中通过 ctx.request_context.lifespan_context 访问该上下文 [1]
|
||||||
|
"""
|
||||||
|
# 允许用环境变量覆盖连接信息
|
||||||
|
host = os.getenv("REMOTE_HOST", REMOTE_HOST)
|
||||||
|
user = os.getenv("REMOTE_USER", REMOTE_USER)
|
||||||
|
key_path = os.getenv("PRIVATE_KEY_PATH", PRIVATE_KEY_PATH)
|
||||||
|
|
||||||
|
conn: asyncssh.SSHClientConnection | None = None
|
||||||
|
try:
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
host,
|
||||||
|
username=user,
|
||||||
|
client_keys=[key_path],
|
||||||
|
known_hosts=None, # 如需主机指纹校验,可移除此参数
|
||||||
|
connect_timeout=15, # 避免长时间挂起
|
||||||
|
)
|
||||||
|
yield SoftBVContext(
|
||||||
|
ssh_connection=conn,
|
||||||
|
workdir=DEFAULT_WORKDIR,
|
||||||
|
profile=SOFTBV_PROFILE,
|
||||||
|
bin_path=SOFTBV_BIN,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
async def run_in_softbv_env(
|
||||||
|
conn: asyncssh.SSHClientConnection,
|
||||||
|
profile_path: str,
|
||||||
|
cmd: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
check: bool = True,
|
||||||
|
) -> asyncssh.SSHCompletedProcess:
|
||||||
|
"""
|
||||||
|
在远端 bash 会话中激活 softBV 环境并执行 cmd。
|
||||||
|
如提供 cwd,则先 cd 到该目录。
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
if cwd:
|
||||||
|
parts.append(f"cd {shell_quote(cwd)}")
|
||||||
|
parts.append(f"source {shell_quote(profile_path)}")
|
||||||
|
parts.append(cmd)
|
||||||
|
composite = "; ".join(parts)
|
||||||
|
full = f"bash -lc {shell_quote(composite)}"
|
||||||
|
return await conn.run(full, check=check)
|
||||||
|
|
||||||
|
# ===== MCP 服务器 =====
|
||||||
|
mcp = FastMCP(
|
||||||
|
name="softBV Tools",
|
||||||
|
instructions="在远程服务器上激活 softBV 环境并执行相关计算的工具集。",
|
||||||
|
lifespan=softbv_lifespan,
|
||||||
|
streamable_http_path="/",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ===== 辅助:列目录用于识别新生成文件 =====
|
||||||
|
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
try:
|
||||||
|
return await sftp.listdir(path)
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# ===== 结构化输入模型:--md =====
|
||||||
|
class SoftBVMDArgs(BaseModel):
|
||||||
|
input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --md 的输入")
|
||||||
|
# 位置参数(按帮助中的顺序;None 表示不提供,让程序使用默认)
|
||||||
|
type: str | None = Field(default=None, description="conducting ion 类型(例如 'Li')")
|
||||||
|
os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
|
||||||
|
sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
|
||||||
|
temperature: float | None = Field(default=None, description="温度 K(非正值默认 300)")
|
||||||
|
t_end: float | None = Field(default=None, description="生产时间 ps(非正值默认 10.0)")
|
||||||
|
t_equil: float | None = Field(default=None, description="平衡时间 ps(非正值默认 2.0)")
|
||||||
|
dt: float | None = Field(default=None, description="时间步长 ps(非正值默认 0.001)")
|
||||||
|
t_log: float | None = Field(default=None, description="采样间隔 ps(非正值每 100 步)")
|
||||||
|
cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
|
||||||
|
|
||||||
|
def _build_md_cmd(bin_path: str, args: SoftBVMDArgs, workdir: str) -> str:
|
||||||
|
input_abs = args.input_cif
|
||||||
|
if not input_abs.startswith("/"):
|
||||||
|
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||||||
|
parts: list[str] = [shell_quote(bin_path), shell_quote("--md"), shell_quote(input_abs)]
|
||||||
|
for val in [args.type, args.os, args.sf, args.temperature, args.t_end, args.t_equil, args.dt, args.t_log]:
|
||||||
|
if val is not None:
|
||||||
|
parts.append(shell_quote(str(val)))
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
# ===== 结构化输入模型:--gen-cube =====
|
||||||
|
class SoftBVGenCubeArgs(BaseModel):
|
||||||
|
input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --gen-cube 的输入")
|
||||||
|
type: str | None = Field(default=None, description="conducting ion 类型(如 'Li')")
|
||||||
|
os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
|
||||||
|
sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
|
||||||
|
resolution: float | None = Field(default=None, description="体素分辨率(默认约 0.1)")
|
||||||
|
ignore_conducting_ion: bool = Field(default=False, description="flag:ignore_conducting_ion")
|
||||||
|
periodic: bool = Field(default=True, description="flag:periodic(默认 True)")
|
||||||
|
output_name: str | None = Field(default=None, description="输出文件名前缀(可选)")
|
||||||
|
cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
|
||||||
|
|
||||||
|
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str) -> str:
|
||||||
|
input_abs = args.input_cif
|
||||||
|
if not input_abs.startswith("/"):
|
||||||
|
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||||||
|
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
|
||||||
|
for val in [args.type, args.os, args.sf, args.resolution]:
|
||||||
|
if val is not None:
|
||||||
|
parts.append(shell_quote(str(val)))
|
||||||
|
if args.ignore_conducting_ion:
|
||||||
|
parts.append(shell_quote("--flag:ignore_conducting_ion"))
|
||||||
|
if args.periodic:
|
||||||
|
parts.append(shell_quote("--flag:periodic"))
|
||||||
|
if args.output_name:
|
||||||
|
parts.append(shell_quote(args.output_name))
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
# ===== 工具:环境信息检查 =====
|
||||||
|
# 工具:环境信息检查(修复版,避免超时)
|
||||||
|
@mcp.tool()
|
||||||
|
async def softbv_info(ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
快速自检:
|
||||||
|
- SFTP 检查工作目录、激活脚本、二进制是否存在/可执行(无需运行 softBV.x)
|
||||||
|
- 激活环境后仅输出标记与当前工作目录,避免长输出或阻塞
|
||||||
|
"""
|
||||||
|
app = ctx.request_context.lifespan_context
|
||||||
|
conn = app.ssh_connection
|
||||||
|
|
||||||
|
# 1) 通过 SFTP 快速检查文件与目录状态(不会长时间阻塞)
|
||||||
|
def stat_path_safe(path: str) -> dict[str, Any]:
|
||||||
|
return {"exists": False, "is_exec": False, "size": None}
|
||||||
|
|
||||||
|
workdir_info = stat_path_safe(app.workdir)
|
||||||
|
profile_info = stat_path_safe(app.profile)
|
||||||
|
bin_info = stat_path_safe(app.bin_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
# workdir
|
||||||
|
try:
|
||||||
|
attrs = await sftp.stat(app.workdir)
|
||||||
|
workdir_info["exists"] = True
|
||||||
|
workdir_info["size"] = int(attrs.size or 0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# profile
|
||||||
|
try:
|
||||||
|
attrs = await sftp.stat(app.profile)
|
||||||
|
profile_info["exists"] = True
|
||||||
|
profile_info["size"] = int(attrs.size or 0)
|
||||||
|
perms = int(attrs.permissions or 0)
|
||||||
|
profile_info["is_exec"] = bool(perms & 0o111)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# bin
|
||||||
|
try:
|
||||||
|
attrs = await sftp.stat(app.bin_path)
|
||||||
|
bin_info["exists"] = True
|
||||||
|
bin_info["size"] = int(attrs.size or 0)
|
||||||
|
perms = int(attrs.permissions or 0)
|
||||||
|
bin_info["is_exec"] = bool(perms & 0o111)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.warning(f"SFTP 检查失败: {e}")
|
||||||
|
|
||||||
|
# 2) 激活环境并做极简命令(避免 softBV.x --help 的长输出)
|
||||||
|
# 仅返回当前用户、PWD 与二进制可执行判断;不实际运行 softBV.x
|
||||||
|
cmd = "echo __SOFTBV_READY__ && echo $USER && pwd && (test -x " + shell_quote(app.bin_path) + " && echo __BIN_OK__ || echo __BIN_NOT_EXEC__)"
|
||||||
|
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=app.workdir, check=False)
|
||||||
|
|
||||||
|
# 解析输出行
|
||||||
|
lines = proc.stdout.splitlines() if proc.stdout else []
|
||||||
|
ready = "__SOFTBV_READY__" in lines
|
||||||
|
user = None
|
||||||
|
pwd = None
|
||||||
|
bin_ok = "__BIN_OK__" in lines
|
||||||
|
|
||||||
|
# 尝试定位 user/pwd(ready 之后的两行)
|
||||||
|
if ready:
|
||||||
|
idx = lines.index("__SOFTBV_READY__")
|
||||||
|
if len(lines) > idx + 1:
|
||||||
|
user = lines[idx + 1].strip()
|
||||||
|
if len(lines) > idx + 2:
|
||||||
|
pwd = lines[idx + 2].strip()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"host": os.getenv("REMOTE_HOST", REMOTE_HOST),
|
||||||
|
"user": os.getenv("REMOTE_USER", REMOTE_USER),
|
||||||
|
"workdir": app.workdir,
|
||||||
|
"profile": app.profile,
|
||||||
|
"bin_path": app.bin_path,
|
||||||
|
"sftp_check": {
|
||||||
|
"workdir": workdir_info,
|
||||||
|
"profile": profile_info,
|
||||||
|
"bin": bin_info,
|
||||||
|
},
|
||||||
|
"activate_ready": ready,
|
||||||
|
"pwd": pwd,
|
||||||
|
"bin_is_executable": bin_ok or bin_info["is_exec"],
|
||||||
|
"exit_status": proc.exit_status,
|
||||||
|
"stderr_head": "\n".join(proc.stderr.splitlines()[:10]) if proc.stderr else "",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 友好日志
|
||||||
|
if not ready:
|
||||||
|
await ctx.warning("softBV 环境未就绪(可能 source 脚本路径问题或权限不足)")
|
||||||
|
if not result["bin_is_executable"]:
|
||||||
|
await ctx.warning("softBV 二进制不可执行或不存在,请检查 bin_path 与权限(chmod +x)")
|
||||||
|
if proc.exit_status != 0 and not proc.stderr:
|
||||||
|
await ctx.debug("命令非零退出但无 stderr,可能是某些子测试返回非零导致")
|
||||||
|
|
||||||
|
return result
|
||||||
|
# ===== 工具:--md =====
|
||||||
|
@mcp.tool()
|
||||||
|
async def softbv_md(req: SoftBVMDArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 softBV.x --md,返回结构化结果:
|
||||||
|
- cmd/cwd/exit_status/stdout/stderr
|
||||||
|
- new_files:执行后新增文件列表,便于定位输出
|
||||||
|
"""
|
||||||
|
app = ctx.request_context.lifespan_context
|
||||||
|
conn = app.ssh_connection
|
||||||
|
workdir = req.cwd or app.workdir
|
||||||
|
cmd = _build_md_cmd(app.bin_path, req, workdir)
|
||||||
|
|
||||||
|
await ctx.info(f"softBV --md 执行: {cmd} (cwd={workdir})")
|
||||||
|
pre_list = await _listdir(conn, workdir)
|
||||||
|
|
||||||
|
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=workdir, check=False)
|
||||||
|
|
||||||
|
post_list = await _listdir(conn, workdir)
|
||||||
|
new_files = sorted(set(post_list) - set(pre_list))
|
||||||
|
if proc.exit_status == 0:
|
||||||
|
await ctx.debug(f"--md 成功,新文件 {len(new_files)} 个")
|
||||||
|
else:
|
||||||
|
await ctx.warning(f"--md 非零退出: {proc.exit_status}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cmd": cmd,
|
||||||
|
"cwd": workdir,
|
||||||
|
"exit_status": proc.exit_status,
|
||||||
|
"stdout": proc.stdout,
|
||||||
|
"stderr": proc.stderr,
|
||||||
|
"new_files": new_files,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def run_in_softbv_env_stream(
|
||||||
|
conn: asyncssh.SSHClientConnection,
|
||||||
|
profile_path: str,
|
||||||
|
cmd: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
) -> asyncssh.SSHClientProcess:
|
||||||
|
parts = []
|
||||||
|
if cwd:
|
||||||
|
parts.append(f"cd {shell_quote(cwd)}")
|
||||||
|
parts.append(f"source {shell_quote(profile_path)} >/dev/null 2>&1 || true")
|
||||||
|
parts.append(cmd)
|
||||||
|
composite = "; ".join(parts)
|
||||||
|
full = f"bash -lc {shell_quote(composite)}"
|
||||||
|
# 不阻塞,返回进程句柄
|
||||||
|
proc = await conn.create_process(full)
|
||||||
|
return proc
|
||||||
|
|
||||||
|
# 轮询目录,识别新生成文件
|
||||||
|
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
try:
|
||||||
|
return await sftp.listdir(path)
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 轮询日志文件大小(作为心跳/粗略进度依据)
|
||||||
|
async def _stat_size(conn: asyncssh.SSHClientConnection, path: str) -> int | None:
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
try:
|
||||||
|
attrs = await sftp.stat(path)
|
||||||
|
return int(attrs.size or 0)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 构造 gen-cube 命令(保持你之前的参数拼接逻辑)
|
||||||
|
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str, log_path: str | None = None) -> str:
|
||||||
|
input_abs = args.input_cif
|
||||||
|
if not input_abs.startswith("/"):
|
||||||
|
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||||||
|
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
|
||||||
|
for val in [args.type, args.os, args.sf, args.resolution]:
|
||||||
|
if val is not None:
|
||||||
|
parts.append(shell_quote(str(val)))
|
||||||
|
if args.ignore_conducting_ion:
|
||||||
|
parts.append(shell_quote("--flag:ignore_conducting_ion"))
|
||||||
|
if args.periodic:
|
||||||
|
parts.append(shell_quote("--flag:periodic"))
|
||||||
|
if args.output_name:
|
||||||
|
parts.append(shell_quote(args.output_name))
|
||||||
|
cmd = " ".join(parts)
|
||||||
|
# 将输出重定向到日志,便于轮询
|
||||||
|
if log_path:
|
||||||
|
cmd = f"{cmd} > {shell_quote(log_path)} 2>&1"
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
# 修复版:长时运行的 softbv_gen_cube,带心跳与超时保护
|
||||||
|
@mcp.tool()
|
||||||
|
async def softbv_gen_cube(req: SoftBVGenCubeArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 softBV.x --gen-cube,支持长任务心跳,避免在 <25min 内被客户端强制超时。
|
||||||
|
- 每隔 10s 上报一次进度(心跳),包含已用时/日志大小/新增文件数
|
||||||
|
- 结束后返回 stdout_head(来自日志文件片段)、stderr_head(如有)、exit_status、新增文件
|
||||||
|
"""
|
||||||
|
app = ctx.request_context.lifespan_context
|
||||||
|
conn = app.ssh_connection
|
||||||
|
workdir = req.cwd or app.workdir
|
||||||
|
|
||||||
|
# 预先记录目录内容,用于结束后差集
|
||||||
|
before = await _listdir(conn, workdir)
|
||||||
|
|
||||||
|
# 远端日志文件路径(按时间戳命名)
|
||||||
|
import time
|
||||||
|
log_name = f"softbv_gencube_{int(time.time())}.log"
|
||||||
|
log_path = posixpath.join(workdir, log_name)
|
||||||
|
|
||||||
|
# 启动长任务,不阻塞当前协程
|
||||||
|
cmd = _build_gen_cube_cmd(app.bin_path, req, workdir, log_path=log_path)
|
||||||
|
await ctx.info(f"启动 --gen-cube: {cmd}")
|
||||||
|
proc = await run_in_softbv_env_stream(conn, app.profile, cmd=cmd, cwd=workdir)
|
||||||
|
|
||||||
|
# 心跳循环:直到进程退出
|
||||||
|
start_ts = time.time()
|
||||||
|
heartbeat_sec = 10 # 每 10 秒发送一次心跳
|
||||||
|
max_guard_min = 60 # 保险上限(服务端不主动终止;如客户端有限制可调大)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# 进程是否已退出
|
||||||
|
if proc.exit_status is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 采集状态:已用时、日志大小、新增文件数
|
||||||
|
elapsed = time.time() - start_ts
|
||||||
|
log_size = await _stat_size(conn, log_path)
|
||||||
|
now_files = await _listdir(conn, workdir)
|
||||||
|
new_files_count = len(set(now_files) - set(before))
|
||||||
|
|
||||||
|
# 这里无法获知真实百分比,使用“心跳/已用时提示”
|
||||||
|
await ctx.report_progress(
|
||||||
|
progress=min(elapsed / (25 * 60), 0.99), # 以 25min 为目标上限做近似刻度
|
||||||
|
total=1.0,
|
||||||
|
message=f"gen-cube 运行中: 已用时 {int(elapsed)}s, 日志 {log_size or 0}B, 新文件 {new_files_count}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 避免客户端超时:持续心跳即可。[1]
|
||||||
|
await asyncio.sleep(heartbeat_sec)
|
||||||
|
|
||||||
|
# 简易守护:超过 max_guard_min 仍未结束也不强制中断(由远端决定)
|
||||||
|
if elapsed > max_guard_min * 60:
|
||||||
|
await ctx.warning("任务已超过守护上限时间,仍在运行(未强制中断)。如需更长时间,请增大上限。")
|
||||||
|
# 不 break,继续等待,交由远端任务完成
|
||||||
|
finally:
|
||||||
|
# 等待进程真正结束(如果已结束,这里是快速返回)
|
||||||
|
await proc.wait()
|
||||||
|
|
||||||
|
# 结束后采集结果
|
||||||
|
exit_status = proc.exit_status
|
||||||
|
after = await _listdir(conn, workdir)
|
||||||
|
new_files = sorted(set(after) - set(before))
|
||||||
|
|
||||||
|
# 读取日志片段(头/尾),帮助定位输出
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
head = ""
|
||||||
|
tail = ""
|
||||||
|
try:
|
||||||
|
async with sftp.open(log_path, "rb") as f:
|
||||||
|
content = await f.read()
|
||||||
|
text = content.decode("utf-8", errors="replace")
|
||||||
|
lines = text.splitlines()
|
||||||
|
head = "\n".join(lines[:40])
|
||||||
|
tail = "\n".join(lines[-40:])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 输出结构化结果
|
||||||
|
result = {
|
||||||
|
"cmd": cmd,
|
||||||
|
"cwd": workdir,
|
||||||
|
"exit_status": exit_status,
|
||||||
|
"log_file": log_path,
|
||||||
|
"stdout_head": head, # 代替一次性 stdout,避免大输出
|
||||||
|
"stderr_head": "", # 统一日志到文件,stderr_head 可留空
|
||||||
|
"new_files": new_files,
|
||||||
|
"elapsed_sec": int(time.time() - start_ts),
|
||||||
|
}
|
||||||
|
|
||||||
|
if exit_status == 0:
|
||||||
|
await ctx.info(f"gen-cube 完成,用时 {result['elapsed_sec']}s,新文件 {len(new_files)} 个")
|
||||||
|
else:
|
||||||
|
await ctx.warning(f"gen-cube 退出码 {exit_status},请查看日志 {log_path}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
def create_softbv_mcp() -> FastMCP:
|
||||||
|
"""供外部(Starlette)导入的工厂函数。"""
|
||||||
|
return mcp
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mcp.run(transport="streamable-http")
|
||||||
1659
mcp/softBV_remake.py
Normal file
1659
mcp/softBV_remake.py
Normal file
File diff suppressed because it is too large
Load Diff
407
mcp/system_tools.py
Normal file
407
mcp/system_tools.py
Normal file
@@ -0,0 +1,407 @@
|
|||||||
|
# system_tools.py
|
||||||
|
import posixpath
|
||||||
|
import stat as pystat
|
||||||
|
from shlex import shlex
|
||||||
|
from typing import Any
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
import asyncssh
|
||||||
|
from mcp.server.fastmcp import FastMCP, Context
|
||||||
|
from mcp.server.session import ServerSession
|
||||||
|
|
||||||
|
# 在 system_tools.py 顶部添加
|
||||||
|
def shell_quote(arg: str) -> str:
|
||||||
|
"""
|
||||||
|
安全地把字符串作为单个 shell 参数:
|
||||||
|
- 外层用单引号包裹
|
||||||
|
- 内部的单引号 ' 替换为 '\'' 序列
|
||||||
|
适用于远端 Linux shell 命令拼接
|
||||||
|
"""
|
||||||
|
return "'" + arg.replace("'", "'\"'\"'") + "'"
|
||||||
|
|
||||||
|
# 复用你的配置与数据类
|
||||||
|
from lifespan import (
|
||||||
|
SharedAppContext,
|
||||||
|
REMOTE_HOST,
|
||||||
|
REMOTE_USER,
|
||||||
|
PRIVATE_KEY_PATH,
|
||||||
|
INITIAL_WORKING_DIRECTORY,
|
||||||
|
)
|
||||||
|
|
||||||
|
# —— 1) 定义 FastMCP 的生命周期,在启动时建立 SSH 连接,关闭时断开 ——
|
||||||
|
@asynccontextmanager
|
||||||
|
async def system_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
|
||||||
|
"""
|
||||||
|
FastMCP 生命周期:建立并注入共享的 SSH 连接与沙箱根路径。
|
||||||
|
说明:这是 MCP 自己的生命周期,工具里通过 ctx.request_context.lifespan_context 访问。
|
||||||
|
"""
|
||||||
|
conn: asyncssh.SSHClientConnection | None = None
|
||||||
|
try:
|
||||||
|
# 建立 SSH 连接
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
REMOTE_HOST,
|
||||||
|
username=REMOTE_USER,
|
||||||
|
client_keys=[PRIVATE_KEY_PATH],
|
||||||
|
)
|
||||||
|
# 将类型安全的共享上下文注入 MCP 生命周期
|
||||||
|
yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
|
||||||
|
finally:
|
||||||
|
# 关闭连接
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
def create_system_mcp() -> FastMCP:
|
||||||
|
"""创建一个包含系统操作工具的 MCP 实例。"""
|
||||||
|
system_mcp = FastMCP(
|
||||||
|
name="System Tools",
|
||||||
|
instructions="用于在远程服务器上进行基本文件和目录操作的工具集。",
|
||||||
|
streamable_http_path="/",
|
||||||
|
lifespan=system_lifespan, # 关键:把生命周期传给 FastMCP [1]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _safe_join(sandbox_root: str, relative_path: str) -> str:
|
||||||
|
"""
|
||||||
|
将用户提供的相对路径映射到沙箱根目录内的规范化绝对路径。
|
||||||
|
- 统一使用 POSIX 语义(远端 Linux)
|
||||||
|
- 禁止使用以 '/' 开头的绝对路径
|
||||||
|
- 禁止 '..' 越界,确保最终路径仍在沙箱内
|
||||||
|
"""
|
||||||
|
rel = (relative_path or ".").strip()
|
||||||
|
|
||||||
|
# 禁止绝对路径,转为相对
|
||||||
|
if rel.startswith("/"):
|
||||||
|
rel = rel.lstrip("/")
|
||||||
|
|
||||||
|
# 规范化拼接
|
||||||
|
combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
|
||||||
|
|
||||||
|
# 统一尾部斜杠处理,避免边界判断遗漏
|
||||||
|
root_norm = sandbox_root.rstrip("/")
|
||||||
|
|
||||||
|
# 确保仍在沙箱内
|
||||||
|
if combined != root_norm and not combined.startswith(root_norm + "/"):
|
||||||
|
raise ValueError("路径越界:仅允许访问沙箱目录内部")
|
||||||
|
|
||||||
|
# 禁止路径中出现 '..'(进一步加固)
|
||||||
|
parts = [p for p in combined.split("/") if p]
|
||||||
|
if ".." in parts:
|
||||||
|
raise ValueError("非法路径:不允许使用 '..' 跨目录")
|
||||||
|
|
||||||
|
return combined
|
||||||
|
|
||||||
|
# —— 重构后的各工具:统一用类型安全的 ctx 与 app_ctx 访问共享资源 ——
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def list_files(ctx: Context[ServerSession, SharedAppContext], path: str = ".") -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
列出远程沙箱目录中的文件和子目录(结构化输出)。
|
||||||
|
Returns:
|
||||||
|
list[dict]: [{name, path, is_dir, size, permissions, mtime}]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, path)
|
||||||
|
await ctx.info(f"列出目录:{target}")
|
||||||
|
|
||||||
|
items: list[dict[str, Any]] = []
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
names = await sftp.listdir(target) # list[str]
|
||||||
|
for name in names:
|
||||||
|
item_abs = posixpath.join(target, name)
|
||||||
|
attrs = None
|
||||||
|
try:
|
||||||
|
attrs = await sftp.stat(item_abs)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
perms = int(attrs.permissions or 0) if attrs else 0
|
||||||
|
is_dir = bool(pystat.S_ISDIR(perms))
|
||||||
|
size = int(attrs.size or 0) if attrs and attrs.size is not None else 0
|
||||||
|
mtime = int(attrs.mtime or 0) if attrs and attrs.mtime is not None else 0
|
||||||
|
|
||||||
|
items.append({
|
||||||
|
"name": name,
|
||||||
|
"path": posixpath.join(path.rstrip("/"), name) if path != "/" else name,
|
||||||
|
"is_dir": is_dir,
|
||||||
|
"size": size,
|
||||||
|
"permissions": perms,
|
||||||
|
"mtime": mtime,
|
||||||
|
})
|
||||||
|
|
||||||
|
await ctx.debug(f"目录项数量:{len(items)}")
|
||||||
|
return items
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"目录不存在或不可访问:{path}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return [{"error": msg}]
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"list_files 失败:{e}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return [{"error": msg}]
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def read_file(ctx: Context[ServerSession, SharedAppContext], file_path: str, encoding: str = "utf-8") -> str:
|
||||||
|
"""
|
||||||
|
读取远程沙箱内指定文件内容。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, file_path)
|
||||||
|
await ctx.info(f"读取文件:{target}")
|
||||||
|
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
async with sftp.open(target, "rb") as f:
|
||||||
|
data = await f.read()
|
||||||
|
try:
|
||||||
|
content = data.decode(encoding)
|
||||||
|
except Exception:
|
||||||
|
content = data.decode(encoding, errors="replace")
|
||||||
|
return content
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"读取失败:文件不存在 '{file_path}'"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"read_file 失败:{e}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def write_file(
|
||||||
|
ctx: Context[ServerSession, SharedAppContext],
|
||||||
|
file_path: str,
|
||||||
|
content: str,
|
||||||
|
encoding: str = "utf-8",
|
||||||
|
create_parents: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
写入远程沙箱文件(默认按需创建父目录)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context # 类型化的 SharedAppContext
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, file_path)
|
||||||
|
await ctx.info(f"写入文件:{target}")
|
||||||
|
|
||||||
|
if create_parents:
|
||||||
|
parent = posixpath.dirname(target)
|
||||||
|
if parent and parent != sandbox_root:
|
||||||
|
await conn.run(f"mkdir -p {shell_quote(parent)}", check=True)
|
||||||
|
|
||||||
|
data = content.encode(encoding)
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
async with sftp.open(target, "wb") as f:
|
||||||
|
await f.write(data)
|
||||||
|
|
||||||
|
await ctx.debug(f"写入完成:{len(data)} 字节")
|
||||||
|
return {"path": file_path, "bytes_written": len(data)}
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"write_file 失败:{e}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return {"error": msg}
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def make_dir(ctx: Context[ServerSession, SharedAppContext], path: str, parents: bool = True) -> str:
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, path)
|
||||||
|
|
||||||
|
if parents:
|
||||||
|
# 单行命令:mkdir -p
|
||||||
|
await conn.run(f"mkdir -p {shell_quote(target)}", check=True)
|
||||||
|
else:
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
await sftp.mkdir(target)
|
||||||
|
|
||||||
|
await ctx.info(f"目录已创建: {target}")
|
||||||
|
return f"目录已创建: {path}"
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"创建目录失败: {e}")
|
||||||
|
return f"错误: {e}"
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def delete_file(ctx: Context[ServerSession, SharedAppContext], file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
删除文件(非目录)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, file_path)
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
await sftp.remove(target)
|
||||||
|
|
||||||
|
await ctx.info(f"文件已删除: {target}")
|
||||||
|
return f"文件已删除: {file_path}"
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"删除失败:文件不存在 '{file_path}'"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"删除文件失败: {e}")
|
||||||
|
return f"错误: {e}"
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def delete_dir(
|
||||||
|
ctx: Context[ServerSession, SharedAppContext],
|
||||||
|
dir_path: str,
|
||||||
|
recursive: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
删除目录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context # 类型化的 SharedAppContext
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, dir_path)
|
||||||
|
|
||||||
|
if recursive:
|
||||||
|
await conn.run(f"rm -rf {shell_quote(target)}", check=True)
|
||||||
|
else:
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
await sftp.rmdir(target)
|
||||||
|
|
||||||
|
await ctx.info(f"目录已删除: {target}")
|
||||||
|
return f"目录已删除: {dir_path}"
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"删除失败:目录不存在 '{dir_path}' 或非空"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"删除目录失败: {e}")
|
||||||
|
return f"错误: {e}"
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def move_path(ctx: Context[ServerSession, SharedAppContext], src: str, dst: str,
|
||||||
|
overwrite: bool = True) -> str:
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
src_abs = _safe_join(sandbox_root, src)
|
||||||
|
dst_abs = _safe_join(sandbox_root, dst)
|
||||||
|
|
||||||
|
flag = "-f" if overwrite else ""
|
||||||
|
# 单行命令:mv
|
||||||
|
cmd = f"mv {flag} {shell_quote(src_abs)} {shell_quote(dst_abs)}".strip()
|
||||||
|
await conn.run(cmd, check=True)
|
||||||
|
|
||||||
|
await ctx.info(f"已移动: {src_abs} -> {dst_abs}")
|
||||||
|
return f"已移动: {src} -> {dst}"
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"移动失败:源不存在 '{src}'"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"移动失败: {e}")
|
||||||
|
return f"错误: {e}"
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def copy_path(
|
||||||
|
ctx: Context[ServerSession, SharedAppContext],
|
||||||
|
src: str,
|
||||||
|
dst: str,
|
||||||
|
recursive: bool = True,
|
||||||
|
overwrite: bool = True,
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
src_abs = _safe_join(sandbox_root, src)
|
||||||
|
dst_abs = _safe_join(sandbox_root, dst)
|
||||||
|
|
||||||
|
flags = []
|
||||||
|
if recursive:
|
||||||
|
flags.append("-r")
|
||||||
|
if overwrite:
|
||||||
|
flags.append("-f")
|
||||||
|
|
||||||
|
# 单行命令:cp
|
||||||
|
cmd = " ".join(["cp"] + flags + [shell_quote(src_abs), shell_quote(dst_abs)])
|
||||||
|
await conn.run(cmd, check=True)
|
||||||
|
|
||||||
|
await ctx.info(f"已复制: {src_abs} -> {dst_abs}")
|
||||||
|
return f"已复制: {src} -> {dst}"
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"复制失败:源不存在 '{src}'"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"复制失败: {e}")
|
||||||
|
return f"错误: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def exists(ctx: Context[ServerSession, SharedAppContext], path: str) -> bool:
|
||||||
|
"""
|
||||||
|
判断路径(文件/目录)是否存在。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, path)
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
await sftp.stat(target)
|
||||||
|
return True
|
||||||
|
except FileNotFoundError:
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"exists 检查失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@system_mcp.tool()
|
||||||
|
async def stat_path(ctx: Context[ServerSession, SharedAppContext], path: str) -> dict:
|
||||||
|
"""
|
||||||
|
查看远程路径属性(结构化输出)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
|
||||||
|
target = _safe_join(sandbox_root, path)
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
attrs = await sftp.stat(target)
|
||||||
|
|
||||||
|
perms = attrs.permissions or 0
|
||||||
|
return {
|
||||||
|
"path": target,
|
||||||
|
"size": int(attrs.size or 0),
|
||||||
|
"is_dir": bool(pystat.S_ISDIR(perms)),
|
||||||
|
"permissions": perms, # 九进制权限位,例:0o755
|
||||||
|
"mtime": int(attrs.mtime or 0), # 秒
|
||||||
|
}
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"路径不存在: {path}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return {"error": msg}
|
||||||
|
except Exception as e:
|
||||||
|
await ctx.error(f"stat 失败: {e}")
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
return system_mcp
|
||||||
19
mcp/test_tools.py
Normal file
19
mcp/test_tools.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# test_tools.py
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_mcp() -> FastMCP:
|
||||||
|
"""创建一个只包含最简单工具的 MCP 实例,用于测试连接。"""
|
||||||
|
|
||||||
|
test_mcp = FastMCP(
|
||||||
|
name="Test Tools",
|
||||||
|
instructions="用于测试服务器连接是否通畅。",
|
||||||
|
streamable_http_path="/"
|
||||||
|
)
|
||||||
|
|
||||||
|
@test_mcp.tool()
|
||||||
|
def ping() -> str:
|
||||||
|
"""一个简单的工具,用于确认服务器是否响应。"""
|
||||||
|
return "pong"
|
||||||
|
|
||||||
|
return test_mcp
|
||||||
185
mcp/topological_analysis_models.py
Normal file
185
mcp/topological_analysis_models.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# topological_analysis_tools.py
|
||||||
|
|
||||||
|
import posixpath
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
|
import asyncssh
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# 从 MCP SDK 导入核心组件 [1]
|
||||||
|
from mcp.server.fastmcp import FastMCP, Context
|
||||||
|
from mcp.server.session import ServerSession
|
||||||
|
|
||||||
|
# 从 lifespan.py 导入共享的配置和数据类 [3]
|
||||||
|
# 注意: 我们不再需要导入 system_lifespan
|
||||||
|
from lifespan import (
|
||||||
|
SharedAppContext,
|
||||||
|
REMOTE_HOST,
|
||||||
|
REMOTE_USER,
|
||||||
|
PRIVATE_KEY_PATH,
|
||||||
|
INITIAL_WORKING_DIRECTORY,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# 1. 辅助函数 (为实现独立性而复制)
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
def shell_quote(arg: str) -> str:
|
||||||
|
"""安全地将字符串转义为单个 shell 参数 [5]。"""
|
||||||
|
return "'" + arg.replace("'", "'\"'\"'") + "'"
|
||||||
|
|
||||||
|
def _safe_join(sandbox_root: str, relative_path: str) -> str:
|
||||||
|
"""安全地将用户提供的相对路径连接到沙箱根目录 [5]。"""
|
||||||
|
rel = (relative_path or ".").strip()
|
||||||
|
if rel.startswith("/"):
|
||||||
|
rel = rel.lstrip("/")
|
||||||
|
combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
|
||||||
|
root_norm = sandbox_root.rstrip("/")
|
||||||
|
if combined != root_norm and not combined.startswith(root_norm + "/"):
|
||||||
|
raise ValueError("路径越界: 仅允许访问沙箱目录内部")
|
||||||
|
if ".." in combined.split("/"):
|
||||||
|
raise ValueError("非法路径: 不允许使用 '..' 跨目录")
|
||||||
|
return combined
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# 2. 独立的生命周期管理器
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def analysis_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
|
||||||
|
"""
|
||||||
|
为拓扑分析工具独立管理 SSH 连接的生命周期。
|
||||||
|
"""
|
||||||
|
conn: asyncssh.SSHClientConnection | None = None
|
||||||
|
print("分析工具: 正在建立独立的 SSH 连接...")
|
||||||
|
try:
|
||||||
|
conn = await asyncssh.connect(
|
||||||
|
REMOTE_HOST,
|
||||||
|
username=REMOTE_USER,
|
||||||
|
client_keys=[PRIVATE_KEY_PATH],
|
||||||
|
)
|
||||||
|
print(f"分析工具: 独立的 SSH 连接到 {REMOTE_HOST} 成功!")
|
||||||
|
yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
await conn.wait_closed()
|
||||||
|
print("分析工具: 独立的 SSH 连接已关闭。")
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# 3. 数据模型与解析函数 (与之前相同)
|
||||||
|
# ==============================================================================
|
||||||
|
# (为简洁起见,此处省略了 Pydantic 模型和 _parse_analysis_output 函数的定义,
|
||||||
|
# 请将上一版本中的这部分代码粘贴到这里)
|
||||||
|
VALID_CONFIGS = Literal["O.yaml", "S.yaml", "Cl.yaml", "Br.yaml"]
|
||||||
|
class SubmitTopologicalAnalysisParams(BaseModel):
|
||||||
|
cif_file_path: str = Field(description="远程服务器上 CIF 文件的相对路径。")
|
||||||
|
config_files: list[VALID_CONFIGS] = Field(description="选择一个或多个用于分析的 YAML 配置文件。")
|
||||||
|
output_file_path: str = Field(description="用于保存计算结果的输出文件相对路径。")
|
||||||
|
class AnalysisResult(BaseModel):
|
||||||
|
percolation_diameter_a: float | None = Field(None, description="渗透直径 (Percolation diameter), 单位 Å。")
|
||||||
|
connectivity_distance: float | None = Field(None, description="连通距离 (the minium of d), 单位 Å。")
|
||||||
|
max_node_length_a: float | None = Field(None, description="最大节点长度 (Maximum node length), 单位 Å。")
|
||||||
|
min_cation_distance_a: float | None = Field(None, description="到阳离子的最小距离, 单位 Å。")
|
||||||
|
total_time_s: float | None = Field(None, description="总计算耗时, 单位秒。")
|
||||||
|
long_node_count: int | None = Field(None, description="长节点数量 (Long node number)。")
|
||||||
|
short_node_count: int | None = Field(None, description="短节点数量 (Short node number)。")
|
||||||
|
warnings: list[str] = Field(default_factory=list, description="解析过程中遇到的警告或缺失的关键值信息。")
|
||||||
|
def _parse_analysis_output(log_content: str) -> AnalysisResult:
|
||||||
|
def find_float(pattern: str) -> float | None:
|
||||||
|
match = re.search(pattern, log_content)
|
||||||
|
return float(match.group(1)) if match else None
|
||||||
|
def find_int(pattern: str) -> int | None:
|
||||||
|
match = re.search(pattern, log_content)
|
||||||
|
return int(match.group(1)) if match else None
|
||||||
|
data = { "percolation_diameter_a": find_float(r"Percolation diameter \(A\):\s*([\d\.]+)"), "max_node_length_a": find_float(r"Maximum node length detected:\s*([\d\.]+)\s*A"), "min_cation_distance_a": find_float(r"minimum distance to cations is\s*([\d\.]+)\s*A"), "total_time_s": find_float(r"Total used time:\s*([\d\.]+)"), "long_node_count": find_int(r"Long node number:\s*(\d+)"), "short_node_count": find_int(r"Short node number:\s*(\d+)"), }
|
||||||
|
conn_dist_match = re.search(r"the minium of d\s*([\d\.]+)", log_content, re.MULTILINE)
|
||||||
|
data["connectivity_distance"] = float(conn_dist_match.group(1)) if conn_dist_match else None
|
||||||
|
warnings = [k + " 未找到" for k, v in data.items() if v is None and k not in ["min_cation_distance_a", "long_node_count", "short_node_count"]]
|
||||||
|
return AnalysisResult(**data, warnings=warnings)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# 4. MCP 服务器工厂函数
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
def create_topological_analysis_mcp() -> FastMCP:
|
||||||
|
"""
|
||||||
|
供 Starlette 挂载的工厂函数。
|
||||||
|
这个函数现在创建并使用自己独立的 lifespan 管理器。
|
||||||
|
"""
|
||||||
|
analysis_mcp = FastMCP(
|
||||||
|
name="Topological Analysis Tools",
|
||||||
|
instructions="用于在远程服务器上提交和分析拓扑计算任务的工具集。",
|
||||||
|
streamable_http_path="/",
|
||||||
|
lifespan=analysis_lifespan, # 关键: 使用本文件内定义的独立 lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
@analysis_mcp.tool()
|
||||||
|
async def submit_topological_analysis(
|
||||||
|
params: SubmitTopologicalAnalysisParams,
|
||||||
|
ctx: Context[ServerSession, SharedAppContext],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""【步骤1/2】异步提交拓扑分析任务。"""
|
||||||
|
# ... (工具的实现逻辑与之前完全相同)
|
||||||
|
try:
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
home_dir = f"/cluster/home/{REMOTE_USER}"
|
||||||
|
tool_dir = f"{home_dir}/tool/Topological_analyse"
|
||||||
|
cif_abs_path = _safe_join(sandbox_root, params.cif_file_path)
|
||||||
|
output_abs_path = _safe_join(sandbox_root, params.output_file_path)
|
||||||
|
config_args = " ".join([shell_quote(f"{tool_dir}/{cfg}") for cfg in params.config_files])
|
||||||
|
config_flags = f"-i {config_args}"
|
||||||
|
command = f"""
|
||||||
|
nohup sh -c '
|
||||||
|
source {shell_quote(f"{home_dir}/.bashrc")} && \\
|
||||||
|
conda activate {shell_quote(f"{home_dir}/anaconda3/envs/zeo")} && \\
|
||||||
|
python {shell_quote(f"{tool_dir}/analyze_voronoi_nodes.py")} \\
|
||||||
|
{shell_quote(cif_abs_path)} \\
|
||||||
|
{config_flags}
|
||||||
|
' > {shell_quote(output_abs_path)} 2>&1 &
|
||||||
|
""".strip()
|
||||||
|
await ctx.info("提交后台拓扑分析任务...")
|
||||||
|
await conn.run(command, check=True)
|
||||||
|
return {"success": True, "message": "拓扑分析任务已成功提交。", "output_file": params.output_file_path}
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"提交后台任务失败: {e}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return {"success": False, "error": msg}
|
||||||
|
|
||||||
|
@analysis_mcp.tool()
|
||||||
|
async def analyze_topological_results(
|
||||||
|
output_file_path: str,
|
||||||
|
ctx: Context[ServerSession, SharedAppContext],
|
||||||
|
) -> AnalysisResult:
|
||||||
|
"""【步骤2/2】读取并分析拓扑分析任务的输出文件。"""
|
||||||
|
# ... (工具的实现逻辑与之前完全相同)
|
||||||
|
try:
|
||||||
|
await ctx.info(f"开始分析结果文件: {output_file_path}")
|
||||||
|
app_ctx = ctx.request_context.lifespan_context
|
||||||
|
conn = app_ctx.ssh_connection
|
||||||
|
sandbox_root = app_ctx.sandbox_path
|
||||||
|
target_path = _safe_join(sandbox_root, output_file_path)
|
||||||
|
async with conn.start_sftp_client() as sftp:
|
||||||
|
async with sftp.open(target_path, "r") as f:
|
||||||
|
log_content = await f.read()
|
||||||
|
if not log_content:
|
||||||
|
raise ValueError("结果文件为空。")
|
||||||
|
analysis_data = _parse_analysis_output(log_content)
|
||||||
|
return analysis_data
|
||||||
|
except FileNotFoundError:
|
||||||
|
msg = f"分析失败: 结果文件 '{output_file_path}' 不存在。"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return AnalysisResult(warnings=[msg])
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"分析结果时发生错误: {e}"
|
||||||
|
await ctx.error(msg)
|
||||||
|
return AnalysisResult(warnings=[msg])
|
||||||
|
|
||||||
|
return analysis_mcp
|
||||||
10
mcp/uvicorn/main.py
Normal file
10
mcp/uvicorn/main.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
# main.py
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def read_root():
|
||||||
|
return {"message": "Hello, World!"}
|
||||||
Reference in New Issue
Block a user