Compare commits
19 Commits: 76105a631d ... master

SHA1: 95d719cc1e, 1f8667ae51, c0b2ec5983, b9ba79d7a8, bd4bd3a645, c20d752faa, 4e21954471, fc7716507a, b11bf2417b, 9b0f835575, b19352382b, 28639f9cbf, 41a6038e50, f625154aee, f585d76cac, c8629619ee, 5d1a4d04f2, e6141689c1, efcdacffd0
Screen/process_txt.py · 95 lines · Normal file
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-

import os
import re
import sys
import csv


def extract_data_from_folder(folder_path):
    """
    Read every txt file in a folder, extract the target values, and write them to a single CSV file.
    (Regular expressions already corrected.)

    :param folder_path: Path to the folder containing the txt files.
    """
    # Check that the folder exists
    if not os.path.isdir(folder_path):
        print(f"Error: folder '{folder_path}' does not exist.")
        return

    # Regular expressions for the values to extract
    # 1. Matches "Percolation diameter (A): 1.234"
    pattern1 = re.compile(r"Percolation diameter \(A\): ([\d\.]+)")

    # 2. Matches "the minium of d 1.23 #" (corrected per your feedback; "minium" is the literal spelling in the data)
    pattern2 = re.compile(r"the minium of d\s*([\d\.]+)\s*#")

    # 3. Matches "Maximum node length detected: 5.67 A"
    pattern3 = re.compile(r"Maximum node length detected: ([\d\.]+) A")

    # Collect all extracted rows
    all_data = []

    # Iterate over all files in the folder; sorted() keeps the processing order deterministic
    for filename in sorted(os.listdir(folder_path)):
        if filename.endswith(".txt"):
            txt_path = os.path.join(folder_path, filename)

            try:
                with open(txt_path, 'r', encoding='utf-8') as file:
                    content = file.read()

                # Search with the corrected regular expressions
                match1 = pattern1.search(content)
                match2 = pattern2.search(content)
                match3 = pattern3.search(content)

                # Extracted value, or the empty string '' if there was no match
                val1 = match1.group(1) if match1 else ''
                val2 = match2.group(1) if match2 else ''
                val3 = match3.group(1) if match3 else ''

                # File name without the .txt suffix
                base_filename = os.path.splitext(filename)[0]

                # Append this row to the overall list
                all_data.append([base_filename, val1, val2, val3])

            except Exception as e:
                print(f"Error while processing file {filename}: {e}")

    # If no txt files or no data were found, do not create a CSV
    if not all_data:
        print("No .txt files were found in the folder, or no data could be extracted.")
        return

    # Name the CSV file after the folder
    folder_name = os.path.basename(os.path.normpath(folder_path))
    csv_filename = f"{folder_name}.csv"

    # Write the CSV file
    try:
        with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)

            # Write the header row
            headers = ['filename', 'Percolation_diameter_A', 'minium_of_d', 'Max_node_length_A']
            writer.writerow(headers)

            # Write all data rows
            writer.writerows(all_data)

        print(f"Data successfully written to file: {csv_filename}")

    except Exception as e:
        print(f"Error while writing CSV file {csv_filename}: {e}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python your_script_name.py <folder_path>")
        sys.exit(1)

    input_folder = sys.argv[1]
    extract_data_from_folder(input_folder)
contrast learning/copy.py · 151 lines · Normal file
@@ -0,0 +1,151 @@
import shutil
from pathlib import Path


def find_element_column_index(cif_lines: list) -> int:
    """
    Find the column index of _atom_site_type_symbol in the contents of a CIF file.

    :param cif_lines: List of lines read from the CIF file.
    :return: Zero-based index of the element-symbol column, or -1 if not found.
    """
    in_loop_header = False
    column_index = -1
    current_column = 0

    for line in cif_lines:
        line_stripped = line.strip()
        if not line_stripped:
            continue

        if line_stripped.startswith('loop_'):
            in_loop_header = True
            column_index = -1
            current_column = 0
            continue

        if in_loop_header:
            if line_stripped.startswith('_'):
                if line_stripped.startswith('_atom_site_type_symbol'):
                    column_index = current_column
                current_column += 1
            else:
                # End of the loop_ header definitions; data rows begin here
                return column_index

    return -1  # No loop_ or _atom_site_type_symbol found in the file


def copy_cif_with_O_or_S_robust(source_dir: str, target_dir: str, dry_run: bool = False):
    """
    Select the CIF files in the source folder whose contents contain the element
    'O' or 'S' and copy them to the target folder.
    (Robust version: correctly parses the element-symbol column of the CIF.)

    :param source_dir: Source folder path containing the CIF files.
    :param target_dir: Target folder path for the selected files.
    :param dry_run: If True, only print the files that would be copied without actually copying them.
    """
    # 1. Path handling and validation
    source_path = Path(source_dir)
    target_path = Path(target_dir)

    if not source_path.is_dir():
        print(f"Error: source folder '{source_dir}' does not exist or is not a folder.")
        return

    if not dry_run and not target_path.exists():
        target_path.mkdir(parents=True, exist_ok=True)
        print(f"Target folder '{target_dir}' created.")

    print(f"Source folder: {source_path}")
    print(f"Target folder: {target_path}")
    if dry_run:
        print("\n--- *** Dry-run mode *** ---")
        print("--- No files will actually be copied ---")

    # 2. Scan and filter
    print("\nScanning the source folder for CIF files...")
    copied_count = 0
    checked_files = 0
    error_files = 0

    # rglob('*.cif') would also walk subfolders; glob only covers the current folder
    for file_path in source_path.glob('*.cif'):
        if file_path.is_file():
            checked_files += 1
            try:
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    lines = f.readlines()

                # Step A: find which column holds the element symbol
                element_col_idx = find_element_column_index(lines)

                # Initialized up front so it is always defined in step B below
                found_simple = False

                if element_col_idx == -1:
                    # Some CIF files have no loop block, only simple key-value lines.
                    # Keep a simplified check for that case.
                    found_simple = any(
                        line.strip().startswith(('_chemical_formula_sum', '_chemical_formula_structural')) and (
                            ' O' in line or ' S' in line) for line in lines)
                    if not found_simple:
                        continue  # Neither method found anything; skip this file

                # Step B: check whether that column contains 'O' or 'S'
                found = False
                for line in lines:
                    line_stripped = line.strip()
                    # Skip empty lines, comment lines, and definition lines
                    if not line_stripped or line_stripped.startswith(('#', '_', 'loop_')):
                        continue

                    parts = line_stripped.split()
                    # Make sure the row has enough columns
                    if len(parts) > element_col_idx:
                        # Exact match on the bare symbol (charged species such as O2- are not matched here)
                        atom_symbol = parts[element_col_idx].strip()
                        if atom_symbol == 'O' or atom_symbol == 'S':
                            found = True
                            break

                # Compatibility check: if the simple key-value check matched, count it as found
                if found_simple:
                    found = True

                if found:
                    target_file_path = target_path / file_path.name
                    print(f"Match found: '{file_path.name}' (contains O or S)")

                    if not dry_run:
                        shutil.copy2(file_path, target_file_path)
                        # print(f" -> copied to {target_file_path}")  # uncomment for more verbose output

                    copied_count += 1

            except Exception as e:
                error_files += 1
                print(f"!! Error while processing file '{file_path.name}': {e}")

    # 3. Final report
    print("\n--- Summary ---")
    print(f"Checked {checked_files} .cif files in total.")
    if error_files > 0:
        print(f"{error_files} files raised errors during processing.")
    if dry_run:
        print(f"Dry run finished: {copied_count} files would be copied in a real run.")
    else:
        print(f"Successfully copied {copied_count} files to the target folder.")


if __name__ == '__main__':
    # !! Important: change the paths below to the actual paths on your machine
    source_folder = "D:/download/2025-10/data_all/input/input"
    target_folder = "D:/download/2025-10/data_all/output"

    # --- First run: dry-run mode ---
    print("================ First run: dry-run mode ================")
    copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=True)

    print("\n\n=======================================================")
    input("Review the dry-run output above. If it looks right, press Enter to perform the actual copy...")
    print("=======================================================")

    # --- Second run: actually copy ---
    print("\n================ Second run: actual copy mode ================")
    copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=False)
contrast learning/delete.py · 111 lines · Normal file
@@ -0,0 +1,111 @@
import shutil
from pathlib import Path


def delete_duplicates_from_second_folder(source_dir: str, target_dir: str, dry_run: bool = False):
    """
    Delete the files or folders in the second folder whose names match items in the first folder.

    :param source_dir: Path of the first (source) folder.
    :param target_dir: Path of the second (target) folder; items are deleted from this folder.
    :param dry_run: If True, only print what would be deleted without actually deleting anything.
    """
    # 1. Convert the string paths to Path objects for easier handling
    source_path = Path(source_dir)
    target_path = Path(target_dir)

    # 2. Validate that both paths exist and are folders
    if not source_path.is_dir():
        print(f"Error: source folder '{source_dir}' does not exist or is not a folder.")
        return
    if not target_path.is_dir():
        print(f"Error: target folder '{target_dir}' does not exist or is not a folder.")
        return

    print(f"Source folder: {source_path}")
    print(f"Target folder: {target_path}")
    if dry_run:
        print("\n--- *** Dry-run mode *** ---")
        print("--- Nothing will actually be deleted ---")

    # 3. Collect the names of all items (files and subfolders) in the source folder.
    # p.name returns the last path component, i.e. the file or folder name
    source_item_names = {p.name for p in source_path.iterdir()}

    if not source_item_names:
        print("\nThe source folder is empty; nothing to do.")
        return

    print(f"\nFound {len(source_item_names)} items in the source folder.")
    print("Checking the target folder for items with matching names...")

    deleted_count = 0
    # 4. Iterate over the item names from the source folder
    for item_name in source_item_names:
        # Build the full path the matching item would have in the target folder
        item_to_delete = target_path / item_name

        # 5. Check whether that item exists in the target folder
        if item_to_delete.exists():
            try:
                if item_to_delete.is_file():
                    # Files are deleted directly
                    print(f"About to delete file: {item_to_delete}")
                    if not dry_run:
                        item_to_delete.unlink()
                        print(" -> deleted.")
                    deleted_count += 1

                elif item_to_delete.is_dir():
                    # Folders are removed with shutil.rmtree, including all their contents
                    print(f"About to delete folder and all of its contents: {item_to_delete}")
                    if not dry_run:
                        shutil.rmtree(item_to_delete)
                        print(" -> deleted.")
                    deleted_count += 1

            except Exception as e:
                print(f"!! Error while deleting '{item_to_delete}': {e}")

    if deleted_count == 0:
        print("\nDone: no matching items were found in the target folder.")
    else:
        if dry_run:
            print(f"\nDry run finished: {deleted_count} items would be deleted in a real run.")
        else:
            print(f"\nDone: {deleted_count} items were deleted in total.")


# --- Usage example ---

# Before running, create the following folder and file structure for a test:
# /your/path/folder1/
# ├── file_a.txt
# ├── file_b.log
# └── subfolder_x/
#     └── test.txt

# /your/path/folder2/
# ├── file_a.txt        (will be deleted)
# ├── file_c.md
# └── subfolder_x/      (will be deleted)
#     └── another.txt

if __name__ == '__main__':
    # !! Important: change the paths below to the actual paths on your machine
    folder1_path = "D:/download/2025-10/after_step5/after_step5/S"  # source folder
    folder2_path = "D:/download/2025-10/input/input"  # target folder

    # --- First run: dry-run mode, strongly recommended! ---
    # This tells you what the script would do without actually deleting anything.
    print("================ First run: dry-run mode ================")
    delete_duplicates_from_second_folder(folder1_path, folder2_path, dry_run=True)

    print("\n\n=======================================================")
    input("Review the dry-run output above. If it looks right, press Enter to perform the actual deletion...")
    print("=======================================================")

    # --- Second run: actually delete ---
    # Once the dry-run output is confirmed, set dry_run to False or drop the argument.
    print("\n================ Second run: actual delete mode ================")
    delete_duplicates_from_second_folder(folder1_path, folder2_path, dry_run=False)
corner-sharing/0923_CS.py · 92 lines · Normal file
@@ -0,0 +1,92 @@
import os
import csv
from pymatgen.core import Structure
from tqdm import tqdm  # tqdm provides the progress bar; install with `pip install tqdm` if missing

# --- Import your custom analysis functions ---
# They are assumed to live in 'utils/CS_analyse.py'
# and to have been renamed as described:
# from calculate_polyhedra_sharing import calculate_polyhedra_sharing as CS_catulate
# from check_only_corner_sharing import check_only_corner_sharing
# Note: since the functions are kept in the utils folder, the import is:
from utils.CS_analyse import CS_catulate, check_only_corner_sharing


def process_cif_folder(cif_folder_path: str, output_csv_path: str):
    """
    Iterate over all CIF files in the given folder, compute their corner-sharing
    character, and write the results to a CSV file.

    Parameters:
        cif_folder_path (str): Path of the folder containing the CIF files.
        output_csv_path (str): Path of the output CSV file.
    """
    # Check that the input folder exists
    if not os.path.isdir(cif_folder_path):
        print(f"Error: folder '{cif_folder_path}' does not exist.")
        return

    # List that accumulates the results
    results = []

    # Collect all CIF files
    try:
        cif_files = [f for f in os.listdir(cif_folder_path) if f.endswith('.cif')]
        if not cif_files:
            print(f"Warning: no .cif files found in folder '{cif_folder_path}'.")
            return
    except FileNotFoundError:
        print(f"Error: cannot access folder '{cif_folder_path}'.")
        return

    print(f"Processing {len(cif_files)} CIF files...")

    # Iterate over all CIF files with a tqdm progress bar
    for filename in tqdm(cif_files, desc="Processing CIFs"):
        file_path = os.path.join(cif_folder_path, filename)

        try:
            # 1. Load the structure from the CIF file
            struct = Structure.from_file(file_path)

            # 2. Call your CS_catulate function to compute the detailed sharing relations
            #    (default target sp='Li'; the anion list is extended beyond the default ['O'])
            sharing_details = CS_catulate(struct, sp='Li', anion=['O', 'S', 'Cl', 'F', 'Br'])

            # 3. Call check_only_corner_sharing for the final verdict
            is_only_corner = check_only_corner_sharing(sharing_details)

            # 4. Store the file name and the result
            results.append([filename, is_only_corner])

        except Exception as e:
            # If a file fails, report the error and continue with the next one
            print(f"\nError while processing file '{filename}': {e}")
            results.append([filename, 'Error'])  # flag the error in the CSV

    # 5. Write the results to the CSV file
    try:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            # Header row
            writer.writerow(['CIF_File', 'Is_Only_Corner_Sharing'])
            # All result rows
            writer.writerows(results)

        print(f"\nDone! Results saved to '{output_csv_path}'")

    except IOError as e:
        print(f"\nError: cannot write CSV file '{output_csv_path}': {e}")


# --- Main entry point ---
if __name__ == "__main__":
    # ----- Configuration -----
    # Change this to the actual folder that holds your CIF files
    CIF_DIRECTORY = "data/0921"

    # Name of the output CSV file
    OUTPUT_CSV = "corner_sharing_results.csv"
    # -------------------------

    # Start processing
    process_cif_folder(CIF_DIRECTORY, OUTPUT_CSV)
@@ -1,3 +1,5 @@
+from typing import List, Dict
+
 from pymatgen.core.structure import Structure
 from pymatgen.analysis.local_env import VoronoiNN
 import numpy as np
@@ -24,7 +26,132 @@ def special_check_for_3(site, nearest):
     return real_nearest


-def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=None):
+def CS_catulate(
+        struct,
+        sp: str = 'Li',
+        anion: List[str] = ['O'],
+        tol: float = 0,
+        cutoff: float = 3.0,
+        notice: bool = False
+) -> Dict[str, Dict[str, int]]:
+    """
+    Compute the sharing relations (corner, edge, and face sharing) between the
+    polyhedra of the different cation types in a structure.
+
+    The function counts the following three kinds of sharing separately:
+    1. target atom vs target atom (e.g., Li-Li)
+    2. target atom vs other cations (e.g., Li-X)
+    3. other cations vs other cations (e.g., X-Y)
+
+    Parameters:
+        struct (Structure): Input pymatgen structure object.
+        sp (str): Target element symbol, default 'Li'.
+        anion (list): List of anion element symbols, default ['O'].
+        tol (float): Tolerance for VoronoiNN. Usually 0 for Li.
+        cutoff (float): Cutoff distance for VoronoiNN. Usually 3.0 for Li.
+        notice (bool): Whether to print detailed sharing information.
+
+    Returns:
+        dict: A dictionary with the statistics for the three sharing classes.
+              The keys "sp_vs_sp", "sp_vs_other", and "other_vs_other" correspond
+              to the three cases above. Each value is another dictionary counting
+              how often 2 anions (edge), 3 anions (face), etc. are shared.
+              Example: {'sp_vs_sp': {'1': 10, '2': 4}, 'sp_vs_other': ...}
+              Sharing 1 anion is corner sharing, 2 edge sharing, 3 face sharing.
+    """
+    # Initialize the VoronoiNN object
+    voro_nn = VoronoiNN(tol=tol, cutoff=cutoff)
+
+    # 1. Collect the nearest-anion information of all cations, split into two classes
+    target_sites_info = []
+    other_cation_sites_info = []
+
+    for index, site in enumerate(struct.sites):
+        # Skip the anions themselves
+        if site.species.chemical_system in anion:
+            continue
+
+        # Get the anion neighbors of the current site
+        try:
+            # get_nn_info is the more direct call here
+            nn_info = voro_nn.get_nn_info(struct, index)
+            nearest_anions = [
+                nn["site"] for nn in nn_info
+                if nn["site"].species.chemical_system in anion
+            ]
+        except Exception as e:
+            print(f"Warning: Could not get neighbors for site {index} ({site.species_string}): {e}")
+            continue
+
+        if not nearest_anions:
+            continue
+
+        # Assemble the site information
+        site_info = {
+            'index': index,
+            'element': site.species.chemical_system,
+            'nearest_anion_indices': {nn.index for nn in nearest_anions}
+        }
+
+        # Classify by whether this is the target element
+        if site.species.chemical_system == sp:
+            target_sites_info.append(site_info)
+        else:
+            other_cation_sites_info.append(site_info)
+
+    # 2. Initialize the result dictionary
+    # Shared-anion count keys: 1 - corner, 2 - edge, 3 - face
+    results = {
+        "sp_vs_sp": {"1": 0, "2": 0, "3": 0, "4": 0},
+        "sp_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
+        "other_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
+    }
+
+    # 3. Count the sharing relations between the classes
+
+    # 3.1 target atom vs target atom (sp_vs_sp)
+    for i in range(len(target_sites_info)):
+        for j in range(i + 1, len(target_sites_info)):
+            atom_i = target_sites_info[i]
+            atom_j = target_sites_info[j]
+
+            shared_anions = atom_i['nearest_anion_indices'].intersection(atom_j['nearest_anion_indices'])
+            shared_count = len(shared_anions)
+
+            if shared_count > 0 and str(shared_count) in results["sp_vs_sp"]:
+                results["sp_vs_sp"][str(shared_count)] += 1
+                if notice:
+                    print(
+                        f"[Li-Li] Atom {atom_i['index']} and {atom_j['index']} share {shared_count} anions: {shared_anions}")
+
+    # 3.2 target atom vs other cations (sp_vs_other)
+    for atom_sp in target_sites_info:
+        for atom_other in other_cation_sites_info:
+            shared_anions = atom_sp['nearest_anion_indices'].intersection(atom_other['nearest_anion_indices'])
+            shared_count = len(shared_anions)
+
+            if shared_count > 0 and str(shared_count) in results["sp_vs_other"]:
+                results["sp_vs_other"][str(shared_count)] += 1
+                if notice:
+                    print(
+                        f"[Li-Other] Atom {atom_sp['index']} and {atom_other['index']} share {shared_count} anions: {shared_anions}")
+
+    # 3.3 other cations vs other cations (other_vs_other)
+    for i in range(len(other_cation_sites_info)):
+        for j in range(i + 1, len(other_cation_sites_info)):
+            atom_i = other_cation_sites_info[i]
+            atom_j = other_cation_sites_info[j]
+
+            shared_anions = atom_i['nearest_anion_indices'].intersection(atom_j['nearest_anion_indices'])
+            shared_count = len(shared_anions)
+
+            if shared_count > 0 and str(shared_count) in results["other_vs_other"]:
+                results["other_vs_other"][str(shared_count)] += 1
+                if notice:
+                    print(
+                        f"[Other-Other] Atom {atom_i['index']} and {atom_j['index']} share {shared_count} anions: {shared_anions}")
+
+    return results
+
+def CS_catulate_old(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=None):
     """
     Compute the sharing relations between the target element and its nearest anions.

@@ -51,10 +178,10 @@ def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=
     # Iterate over every site in the structure
     for index,site in enumerate(struct.sites):
         # Skip anion sites
-        if site.specie.symbol in anion:
+        if site.species.chemical_system in anion:
             continue
         # Skip Li atoms
-        if site.specie.symbol == sp:
+        if site.species.chemical_system == sp:
             continue
         # Get the Voronoi polyhedron information
         voro_info = voro_nn.get_voronoi_polyhedra(struct, index)
@@ -62,14 +189,14 @@ def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=
         # Find the nearest anion sites
         nearest_anions = [
             nn_info["site"] for nn_info in voro_info.values()
-            if nn_info["site"].specie.symbol in anion
+            if nn_info["site"].species.chemical_system in anion
         ]

         # Skip if no nearest anions were found
         if not nearest_anions:
             print(f"No nearest anions found for {ID} site {index}.")
             continue
-        if site.specie.symbol == 'B' or site.specie.symbol == 'N':
+        if site.species.chemical_system == 'B' or site.species.chemical_system == 'N':
             nearest_anions = special_check_for_3(site,nearest_anions)
         nearest_anions = check_real(nearest_anions)
         # Append the result to the atom_dice list
@@ -110,10 +237,62 @@ def CS_catulate(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=
     return shared_count


-def CS_count(struct, shared_count, sp='Li'):
+def CS_count(struct, sharing_results: Dict[str, Dict[str, int]], sp: str = 'Li') -> float:
+    """
+    Analyze the polyhedra-sharing results and compute the average number of
+    shared anions per target atom.
+
+    This function is the companion of calculate_polyhedra_sharing.
+
+    Parameters:
+        struct (Structure): Input pymatgen structure object, used to count the target atoms.
+        sharing_results (dict): Output of the calculate_polyhedra_sharing function.
+        sp (str): Target element symbol, default 'Li'.
+
+    Returns:
+        float: Average number of shared anions each target atom sp participates in.
+               For example, 2.5 means that on average each Li atom is connected
+               through shared anions to other cations (both Li and other cations)
+               via 2.5 anions.
+    """
+    # 1. Count the target atoms in the structure
+    target_atom_count = 0
+    for site in struct.sites:
+        if site.species.chemical_system == sp:
+            target_atom_count += 1
+
+    # If the structure has no target atoms, return 0 to avoid division by zero
+    if target_atom_count == 0:
+        return 0.0
+
+    # 2. Compute the weighted total of shared anions
+    total_shared_anions = 0
+
+    # Handle sp_vs_sp (e.g. Li-Li) sharing.
+    # Each sharing relation involves two target atoms, so the weight carries a factor of 2.
+    if "sp_vs_sp" in sharing_results:
+        sp_vs_sp_counts = sharing_results["sp_vs_sp"]
+        for num_shared_str, count in sp_vs_sp_counts.items():
+            num_shared = int(num_shared_str)
+            # weight = shared anions * target atoms involved (2) * occurrences
+            total_shared_anions += num_shared * 2 * count
+
+    # Handle sp_vs_other (e.g. Li-X) sharing.
+    # Each sharing relation involves one target atom, so the weight carries a factor of 1.
+    if "sp_vs_other" in sharing_results:
+        sp_vs_other_counts = sharing_results["sp_vs_other"]
+        for num_shared_str, count in sp_vs_other_counts.items():
+            num_shared = int(num_shared_str)
+            # weight = shared anions * target atoms involved (1) * occurrences
+            total_shared_anions += num_shared * 1 * count
+
+    # 3. Compute the average:
+    # average shared anions per target atom = weighted total / number of target atoms
+    average_sharing_per_atom = total_shared_anions / target_atom_count
+
+    return average_sharing_per_atom
+def CS_count_old(struct, shared_count, sp='Li'):
     count = 0
     for site in struct.sites:
-        if site.specie.symbol == sp:
+        if site.species.chemical_system == sp:
             count += 1  # accumulate the number of matching atoms

     CS_count = 0
@@ -128,7 +307,50 @@ def CS_count(struct, shared_count, sp='Li'):

     return CS_count

-structure = Structure.from_file("../data/0921/wjy_001.cif")
-a = CS_catulate(structure,notice=True)
-b = CS_count(structure,a)
-print(f"{a}\n{b}")
+
+def check_only_corner_sharing(sharing_results: Dict[str, Dict[str, int]]) -> int:
+    """
+    Check whether the target atoms (sp) participate only in corner sharing
+    (i.e., share exactly 1 anion).
+
+    This function is the companion of calculate_polyhedra_sharing.
+
+    Parameters:
+        sharing_results (dict): Output of the calculate_polyhedra_sharing function.
+
+    Returns:
+        int:
+            - 1: if the sharing relations of sp contain no edge sharing (2),
+                 face sharing (3), etc., and at least one corner sharing (1) exists.
+            - 0: if sp has any edge/face sharing, or no sharing relations at all.
+    """
+    # Extract the sharing data related to the target atom sp
+    sp_vs_sp_counts = sharing_results.get("sp_vs_sp", {})
+    sp_vs_other_counts = sharing_results.get("sp_vs_other", {})
+
+    # 1. Check for any edge sharing, face sharing, etc. (shared count > 1)
+    # Check the sp-sp sharing
+    for num_shared_str, count in sp_vs_sp_counts.items():
+        if int(num_shared_str) > 1 and count > 0:
+            return 0  # edge/face sharing found, return 0 immediately
+
+    # Check the sp-other sharing
+    for num_shared_str, count in sp_vs_other_counts.items():
+        if int(num_shared_str) > 1 and count > 0:
+            return 0  # edge/face sharing found, return 0 immediately
+
+    # 2. Check that at least one corner sharing exists (shared count == 1).
+    # Reaching this point means there is no edge/face sharing at all.
+    # Now confirm that corner sharing actually exists rather than no sharing at all.
+    corner_share_sp_sp = sp_vs_sp_counts.get("1", 0) > 0
+    corner_share_sp_other = sp_vs_other_counts.get("1", 0) > 0
+
+    if corner_share_sp_sp or corner_share_sp_other:
+        return 1  # only corner sharing confirmed
+    else:
+        return 0  # no sharing relations at all, also return 0
+
+# structure = Structure.from_file("../data/0921/wjy_001.cif")
+# a = CS_catulate(structure,notice=True)
+# b = CS_count(structure,a)
+# print(f"{a}\n{b}")
+# print(check_only_corner_sharing(a))
dpgen/data/P3ma/model1.cif · 50 lines · Normal file
@@ -0,0 +1,50 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================

data_Li3YCl6_annealed
_audit_creation_method     'Generated from published data table'

#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a     11.2001(2)
_cell_length_b     11.2001(2)
_cell_length_c     6.0441(2)
_cell_angle_alpha  90.0
_cell_angle_beta   90.0
_cell_angle_gamma  120.0
_cell_volume       656.39

#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M   'P -3 m 1'
_symmetry_Int_Tables_number      164

#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1   Y   1a   0.0000      0.0000      0.0000     1.0000   1.10(3)
Y2   Y   2d   0.3333      0.6666     -0.065(3)   1.0000   1.10(3)
Cl1  Cl  6i   0.1131(7)  -0.1131(7)   0.7717(8)  1.0000   1.99(3)
Cl2  Cl  6i   0.2182(7)  -0.2182(7)   0.2606(8)  1.0000   1.99(3)
Cl3  Cl  6i   0.4436(4)  -0.4436(4)   0.7604(8)  1.0000   1.99(3)
Li1  Li  6g   0.3397      0.3397      0.0000     1.0000   5.00
Li2  Li  6h   0.3397      0.3397      0.5000     0.5000   5.00

#==============================================================================
# END OF DATA
#==============================================================================
dpgen/data/P3ma/model2.cif · 50 lines · Normal file
@@ -0,0 +1,50 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================

data_Li3YCl6_annealed
_audit_creation_method     'Generated from published data table'

#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a     11.2001(2)
_cell_length_b     11.2001(2)
_cell_length_c     6.0441(2)
_cell_angle_alpha  90.0
_cell_angle_beta   90.0
_cell_angle_gamma  120.0
_cell_volume       656.39

#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M   'P -3 m 1'
_symmetry_Int_Tables_number      164

#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1   Y   1a   0.0000      0.0000      0.0000     1.0000   1.10(3)
Y2   Y   2d   0.3333      0.6666      0.488(1)   1.0000   1.10(3)
Cl1  Cl  6i   0.1131(7)  -0.1131(7)   0.7717(8)  1.0000   1.99(3)
Cl2  Cl  6i   0.2182(7)  -0.2182(7)   0.2606(8)  1.0000   1.99(3)
Cl3  Cl  6i   0.4436(4)  -0.4436(4)   0.7604(8)  1.0000   1.99(3)
Li1  Li  6g   0.3397      0.3397      0.0000     1.0000   5.00
Li2  Li  6h   0.3397      0.3397      0.5000     0.5000   5.00

#==============================================================================
# END OF DATA
#==============================================================================
dpgen/data/P3ma/model3.cif · 51 lines · Normal file
@@ -0,0 +1,51 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================

data_Li3YCl6_annealed
_audit_creation_method     'Generated from published data table'

#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a     11.2001(2)
_cell_length_b     11.2001(2)
_cell_length_c     6.0441(2)
_cell_angle_alpha  90.0
_cell_angle_beta   90.0
_cell_angle_gamma  120.0
_cell_volume       656.39

#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M   'P -3 m 1'
_symmetry_Int_Tables_number      164

#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1   Y   1a   0.0000      0.0000      0.0000     1.0000   1.10(3)
Y2   Y   2d   0.3333      0.6666      0.488(1)   0.7500   1.10(3)
Y3   Y   2d   0.3333      0.6666     -0.065(3)   0.2500   1.10(3)
Cl1  Cl  6i   0.1131(7)  -0.1131(7)   0.7717(8)  1.0000   1.99(3)
Cl2  Cl  6i   0.2182(7)  -0.2182(7)   0.2606(8)  1.0000   1.99(3)
Cl3  Cl  6i   0.4436(4)  -0.4436(4)   0.7604(8)  1.0000   1.99(3)
Li1  Li  6g   0.3397      0.3397      0.0000     1.0000   5.00
Li2  Li  6h   0.3397      0.3397      0.5000     0.5000   5.00

#==============================================================================
# END OF DATA
#==============================================================================
||||
60
dpgen/data/P3ma/model4.cif
Normal file
60
dpgen/data/P3ma/model4.cif
Normal file
@@ -0,0 +1,60 @@
|
||||
#==============================================================================
|
||||
# CRYSTAL DATA
|
||||
#------------------------------------------------------------------------------
|
||||
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
|
||||
# Source: Image provided by user
|
||||
#==============================================================================
|
||||
|
||||
data_Li3YCl6_annealed
|
||||
_audit_creation_method 'Generated from published data table'
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# CELL PARAMETERS
|
||||
#------------------------------------------------------------------------------
|
||||
_cell_length_a 11.2001(2)
|
||||
_cell_length_b 11.2001(2)
|
||||
_cell_length_c 6.0441(2)
|
||||
_cell_angle_alpha 90.0
|
||||
_cell_angle_beta 90.0
|
||||
_cell_angle_gamma 120.0
|
||||
_cell_volume 656.39
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# SYMMETRY
|
||||
#------------------------------------------------------------------------------
|
||||
_symmetry_space_group_name_H-M 'P -3 m 1'
|
||||
_symmetry_Int_Tables_number 164
|
||||
|
||||
#------------------------------------------------------------------------------
|
||||
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
|
||||
#------------------------------------------------------------------------------
|
||||
loop_
|
||||
_atom_site_label
|
||||
_atom_site_type_symbol
|
||||
_atom_site_fract_x
|
||||
_atom_site_fract_y
|
||||
_atom_site_fract_z
|
||||
_atom_site_occupancy
|
||||
|
||||
# --- Y位点 (基于Model 3, Occ: M1=1, M2=0.75, M3=0.25) ---
|
||||
Y1_main Y 0.0000 0.0000 0.0000 0.9444 # 1a位
|
||||
Li_on_Y1 Li 0.0000 0.0000 0.0000 0.0556
|
||||
Y2_main Y 0.3333 0.6666 0.4880 0.7083 # 2d位
|
||||
Li_on_Y2 Li 0.3333 0.6666 0.4880 0.0417
|
||||
Y3_main Y 0.3333 0.6666 -0.0650 0.2361 # 2d位
|
||||
Li_on_Y3 Li 0.3333 0.6666 -0.0650 0.0139
|
||||
|
||||
# --- Li位点 (基于标准结构, Occ: 6g=1, 6h=0.5) ---
|
||||
Li1_main Li 0.3397 0.3397 0.0000 0.9815 # 6g位
|
||||
Y_on_Li1 Y 0.3397 0.3397 0.0000 0.0185
|
||||
Li2_main Li 0.3397 0.3397 0.5000 0.4907 # 6h位
|
||||
Y_on_Li2 Y 0.3397 0.3397 0.5000 0.0093
|
||||
|
||||
# --- Cl位点 (保持不变) ---
|
||||
Cl1 Cl 0.1131 -0.1131 0.7717 1.0000
|
||||
Cl2 Cl 0.2182 -0.2182 0.2606 1.0000
|
||||
Cl3 Cl 0.4436 -0.4436 0.7604 1.0000
|
||||
|
||||
#==============================================================================
|
||||
# END OF DATA
|
||||
#==============================================================================
|
||||
dpgen/data/P3ma/origin.cif · 51 lines · Normal file
@@ -0,0 +1,51 @@
#==============================================================================
# CRYSTAL DATA
#------------------------------------------------------------------------------
# Created from data in Table S9 for ball milled and 1h annealed Li3YCl6
# Source: Image provided by user
#==============================================================================

data_Li3YCl6_annealed
_audit_creation_method     'Generated from published data table'

#------------------------------------------------------------------------------
# CELL PARAMETERS
#------------------------------------------------------------------------------
_cell_length_a     11.2001(2)
_cell_length_b     11.2001(2)
_cell_length_c     6.0441(2)
_cell_angle_alpha  90.0
_cell_angle_beta   90.0
_cell_angle_gamma  120.0
_cell_volume       656.39

#------------------------------------------------------------------------------
# SYMMETRY
#------------------------------------------------------------------------------
_symmetry_space_group_name_H-M   'P -3 m 1'
_symmetry_Int_Tables_number      164

#------------------------------------------------------------------------------
# ATOMIC COORDINATES AND DISPLACEMENT PARAMETERS
#------------------------------------------------------------------------------
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_B_iso_or_equiv
Y1   Y   1a   0.0000      0.0000      0.0000     1.0000   1.10(3)
Y2   Y   2d   0.3333      0.6666      0.488(1)   0.8269   1.10(3)
Y3   Y   2d   0.3333      0.6666     -0.065(3)   0.1730   1.10(3)
Cl1  Cl  6i   0.1131(7)  -0.1131(7)   0.7717(8)  1.0000   1.99(3)
Cl2  Cl  6i   0.2182(7)  -0.2182(7)   0.2606(8)  1.0000   1.99(3)
Cl3  Cl  6i   0.4436(4)  -0.4436(4)   0.7604(8)  1.0000   1.99(3)
Li1  Li  6g   0.3397      0.3397      0.0000     1.0000   5.00
Li2  Li  6h   0.3397      0.3397      0.5000     0.5000   5.00

#==============================================================================
# END OF DATA
#==============================================================================
dpgen/data/Pnma/origin_backup.cif · 53 lines · Normal file
@@ -0,0 +1,53 @@
#------------------------------------------------------------------------------
# CIF (Crystallographic Information File) for Li3YCl6
# Data source: Table S1 from the provided image.
# Rietveld refinement result of the neutron diffraction pattern for the 450 °C-annealed sample.
#------------------------------------------------------------------------------

data_Li3YCl6

_chemical_name_systematic     'Lithium Yttrium Chloride'
_chemical_formula_sum         'Li3 Y1 Cl6'
_chemical_formula_structural  'Li3YCl6'

_symmetry_space_group_name_H-M  'P n m a'
_symmetry_Int_Tables_number     62
_symmetry_cell_setting          orthorhombic

loop_
_symmetry_equiv_pos_as_xyz
'x, y, z'
'-x+1/2, y+1/2, -z+1/2'
'-x, -y, -z'
'x+1/2, -y+1/2, z+1/2'
'-x, y+1/2, -z'
'x-1/2, -y-1/2, z-1/2'
'x, -y, z'
'-x-1/2, y-1/2, -z-1/2'

_cell_length_a     12.92765(13)
_cell_length_b     11.19444(10)
_cell_length_c     6.04000(12)
_cell_angle_alpha  90.0
_cell_angle_beta   90.0
_cell_angle_gamma  90.0
_cell_volume       874.15
_cell_formula_units_Z  4

loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_occupancy
_atom_site_Wyckoff_symbol
_atom_site_U_iso_or_equiv
Li1  Li  0.11730(7)   0.09640(7)   0.04860(10)  0.750(13)  8d  4.579(2)
Li2  Li  0.13270(9)   0.07900(10)  0.48600(2)   0.750(19)  8d  9.554(4)
Cl1  Cl  0.21726(7)   0.58920(7)   0.26362(11)  1.0        8d  0.797(17)
Cl2  Cl  0.45948(8)   0.08259(8)   0.23831(13)  1.0        8d  1.548(2)
Cl3  Cl  0.04505(10)  0.25000      0.74110(2)   1.0        4c  1.848(3)
Cl4  Cl  0.20205(9)   0.25000      0.24970(2)   1.0        4c  0.561(2)
Y1   Y   0.37529(10)  0.25000      0.01870(3)   1.0        4c  1.121(17)
#------------------------------------------------------------------------------
dpgen/plus.py · 151 lines · Normal file
@@ -0,0 +1,151 @@
import random
from typing import List
from pymatgen.core import Structure
from pymatgen.io.vasp import Poscar


def _is_close_frac(z, target, tol=2e-2):
    # Compare fractional coordinates with periodic wrap-around
    t = target % 1.0
    return min(abs(z - t), abs(z - (t + 1)), abs(z - (t - 1))) < tol


def make_model3_poscar_from_cif(cif_path: str,
                                out_poscar: str = "POSCAR_model3_supercell",
                                seed: int = 42,
                                tol: float = 2e-2):
    """
    Expand model3.cif by the supercell matrix [[3,0,0],[2,4,0],[0,0,6]] into a
    2160-atom supercell, explicitly order the partially occupied sites
    (Y2=0.75, Y3=0.25, Li2=0.5), and write out a POSCAR.
    """
    random.seed(seed)

    # 1) Read the CIF
    s = Structure.from_file(cif_path)

    # 2) Build the supercell (a_s=3a0, b_s=2a0+4b0, c_s=6c0) [1]
    T = [[3, 0, 0],
         [2, 4, 0],
         [0, 0, 6]]
    s.make_supercell(T)

    # 3) Identify the three kinds of sites that need rounding: Y2, Y3, Li2
    y2_idx: List[int] = []
    y3_idx: List[int] = []
    li2_idx: List[int] = []

    for i, site in enumerate(s.sites):
        # Stay compatible with different pymatgen versions
        try:
            el = site.species.elements[0].symbol
        except Exception:
            ss = site.species_string
            el = "Li" if ss.startswith("Li") else ("Y" if ss.startswith("Y") else ("Cl" if ss.startswith("Cl") else ss))
        z = site.frac_coords[2]
        if el == "Y":
            if _is_close_frac(z, 0.488, tol):
                y2_idx.append(i)
            elif _is_close_frac(z, -0.065, tol) or _is_close_frac(z, 0.935, tol):
                y3_idx.append(i)
        elif el == "Li":
            if _is_close_frac(z, 0.5, tol):
                li2_idx.append(i)

    def choose_keep(idxs, frac_keep):
        # Randomly keep round(n * frac_keep) of the candidate sites
        n = len(idxs)
        k = int(round(n * frac_keep))
        if k < 0: k = 0
        if k > n: k = n
        keep = set(random.sample(idxs, k)) if 0 < k < n else set(idxs if k == n else [])
        drop = [i for i in idxs if i not in keep]
        return keep, drop

    keep_y2, drop_y2 = choose_keep(y2_idx, 0.75)
    keep_y3, drop_y3 = choose_keep(y3_idx, 0.25)
    keep_li2, drop_li2 = choose_keep(li2_idx, 0.50)

    # 4) Set the kept sites to full occupancy and delete the rest
    for i in keep_y2 | keep_y3:
        s.replace(i, "Y")
    for i in keep_li2:
        s.replace(i, "Li")
    to_remove = sorted(drop_y2 + drop_y3 + drop_li2, reverse=True)
    for i in to_remove:
        s.remove_sites([i])

    # 5) Final cleanup: remove any leftover partial occupancies
    #    (otherwise writing the POSCAR would fail).
    #    If site.is_ordered is False, replace the site with its majority element at occupancy 1.
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()  # e.g. {'Li': 0.5} or {'Li': 0.5, 'Y': 0.5}
            elem = max(d.items(), key=lambda kv: kv[1])[0]
            s.replace(i, elem)

    # 6) Sort and write the POSCAR
    order = {"Li": 0, "Y": 1, "Cl": 2}
    s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
    Poscar(s).write_file(out_poscar)

    # Report
    comp = {k: int(v) for k, v in s.composition.as_dict().items()}
    print(f"Wrote {out_poscar}; total atoms = {len(s)}")
    print(f"Y2 found={len(y2_idx)}, Y3 found={len(y3_idx)}, Li2 found={len(li2_idx)}; composition={comp}")


def make_pnma_poscar_from_cif(cif_path: str,
                              out_poscar: str = "POSCAR_pnma_supercell",
                              seed: int = 42,
                              supercell=(3, 3, 6),
                              tol: float = 1e-6):
    """
    Read a Pnma CIF (e.g. origin.cif), expand it to 2160 atoms, and explicitly
    round the partially occupied Li sites (0.75) before writing out a POSCAR.
    The default supercell is (3,3,6): volume factor 54, 40 atoms/unit cell x 54 = 2160 [1][3].
    """
    random.seed(seed)

    s = Structure.from_file(cif_path)

    # Build the supercell; the Pnma cell is already orthorhombic, so a diagonal scaling suffices
    s.make_supercell(supercell)

    # Find all partially occupied Li sites
    partial_li_idx: List[int] = []
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()  # e.g. {'Li': 0.75}
            # Only handle sites whose majority element is Li with occupancy < 1
            m_elem, m_occ = max(d.items(), key=lambda kv: kv[1])
            if m_elem == "Li" and m_occ < 1 - tol:
                partial_li_idx.append(i)

    # Round at occupancy 0.75: keep 75% at random, delete the rest as vacancies
    n = len(partial_li_idx)
    k = int(round(n * 0.75))
    keep = set(random.sample(partial_li_idx, k)) if 0 < k < n else set(partial_li_idx if k == n else [])
    drop = sorted([i for i in partial_li_idx if i not in keep], reverse=True)

    # Set the kept sites to occupancy 1; delete the rest
    for i in keep:
        s.replace(i, "Li")
    for i in drop:
        s.remove_sites([i])

    # Safety net: if partial occupancies remain, force the majority element
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            d = site.species.as_dict()
            elem = max(d.items(), key=lambda kv: kv[1])[0]
            s.replace(i, elem)

    # Sort and write the POSCAR
    order = {"Li": 0, "Y": 1, "Cl": 2}
    s = s.get_sorted_structure(key=lambda site: order.get(site.species.elements[0].symbol, 99))
    Poscar(s).write_file(out_poscar)

    comp = {k: int(v) for k, v in s.composition.as_dict().items()}
    print(f"Wrote {out_poscar}; total atoms = {len(s)}; composition = {comp}")


if __name__ == "__main__":
    # make_model3_poscar_from_cif("data/P3ma/model3.cif","data/P3ma/supercell_model4.poscar")
    make_pnma_poscar_from_cif("data/Pnma/origin.cif", "data/Pnma/supercell_pnma.poscar", seed=42)
dpgen/supercell_make_p3ma.py · 0 lines · Normal file
dpgen/transport.py · 76 lines · Normal file
@@ -0,0 +1,76 @@
from typing import List, Dict, Tuple


def add_li_y_antisite_average(
    x: float,
    # Y sites: [(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
    y_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
    # Li sites: [(label, wyckoff, (fx,fy,fz), occ, multiplicity)]
    li_sites: List[Tuple[str, str, Tuple[float, float, float], float, int]] = None,
    # Optional: B factors per element (dict); may be left empty
    b_iso: Dict[str, float] = None,
    # Number of decimal places to keep
    ndigits: int = 4
) -> str:
    """
    Rewrite the structure as an "average occupancy" CIF block containing a
    Li/Y antisite defect fraction x (defined relative to the total number of Y).
    Returns a string that can be pasted directly into the atom_site loop of a CIF.
    """
    # Defaults use the coordinates and occupancies of Table S9 (P-3m1);
    # multiplicities follow that space group
    if y_sites is None:
        y_sites = [
            ("Y1", "1a", (0.0000, 0.0000, 0.0000), 1.0000, 1),
            ("Y2", "2d", (0.3333, 0.6666, 0.4880), 0.8269, 2),
            ("Y3", "2d", (0.3333, 0.6666, -0.0650), 0.1730, 2),
        ]
    if li_sites is None:
        li_sites = [
            ("Li1", "6g", (0.3397, 0.3397, 0.0000), 1.0000, 6),
            ("Li2", "6h", (0.3397, 0.3397, 0.5000), 0.5000, 6),
        ]
    if b_iso is None:
        # Adjust as needed
        b_iso = {"Y": 1.10, "Li": 5.00, "Cl": 1.99}

    # 1) Total Y and total Li ("atom count contribution" per unit cell):
    y_total = sum(m * occ for _, _, _, occ, m in y_sites)    # expected to be 3
    li_total = sum(m * occ for _, _, _, occ, m in li_sites)  # expected to be 9

    # 2) Number of Y to exchange (per unit cell)
    y_exchanged = x * y_total  # e.g. x=0.0556 -> ~0.1668 Y per unit cell

    # 3) Fraction to distribute over the Li sublattice (preserving charge neutrality and stoichiometry)
    li_fraction = y_exchanged / li_total  # fraction of Y to introduce on each Li site

    # 4) Generate the CIF loop text (each coordinate is written twice:
    #    one line for the host element and one for the antisite element)
    lines = []
    header = [
        "loop_",
        " _atom_site_label",
        " _atom_site_type_symbol",
        " _atom_site_fract_x",
        " _atom_site_fract_y",
        " _atom_site_fract_z",
        " _atom_site_occupancy",
        " _atom_site_B_iso_or_equiv",
    ]
    lines.extend(header)

    # 4a) Y sites: Y -> (1-x)*occ; Li_on_Y -> x*occ
    for label, wyck, (fx, fy, fz), occ, mult in y_sites:
        y_main_occ = round(occ * (1.0 - x), ndigits)
        li_on_y_occ = round(occ * x, ndigits)
        lines.append(f" {label}_main Y {fx:.4f} {fy:.4f} {fz:.4f} {y_main_occ:.{ndigits}f} {b_iso.get('Y',1.0)}")
        lines.append(f" Li_on_{label} Li {fx:.4f} {fy:.4f} {fz:.4f} {li_on_y_occ:.{ndigits}f} {b_iso.get('Li',5.0)}")

    # 4b) Li sites: Li -> (1-li_fraction)*occ; Y_on_Li -> li_fraction*occ
    for label, wyck, (fx, fy, fz), occ, mult in li_sites:
        li_main_occ = round(occ * (1.0 - li_fraction), ndigits)
        y_on_li_occ = round(occ * li_fraction, ndigits)
        lines.append(f" {label}_main Li {fx:.4f} {fy:.4f} {fz:.4f} {li_main_occ:.{ndigits}f} {b_iso.get('Li',5.0)}")
        lines.append(f" Y_on_{label} Y {fx:.4f} {fy:.4f} {fz:.4f} {y_on_li_occ:.{ndigits}f} {b_iso.get('Y',1.0)}")

    return "\n".join(lines)


# === Example ===
# x=0.0556 (the 5.56% used in the paper, defined relative to the total number of Y):
print(add_li_y_antisite_average(0.0556))
mcp/SearchPaperByEmbedding/LICENSE · 21 lines · Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 gyj155

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
mcp/SearchPaperByEmbedding/README.md · 96 lines · Normal file
@@ -0,0 +1,96 @@
# Paper Semantic Search

Find similar papers using semantic search. Supports both local models (free) and the OpenAI API (better quality).

## Features

- Fetches papers from OpenReview (e.g., ICLR2026 submissions)
- Semantic search with example papers or text queries
- Supports embedding caching
- Embedding model support: open-source (e.g., all-MiniLM-L6-v2) or OpenAI

## Quick Start

```bash
pip install -r requirements.txt
```

### 1. Prepare Papers

```python
from crawl import crawl_papers

crawl_papers(
    venue_id="ICLR.cc/2026/Conference/Submission",
    output_file="iclr2026_papers.json"
)
```

### 2. Search Papers

```python
from search import PaperSearcher

# Local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')

# OpenAI model (better, requires API key)
# export OPENAI_API_KEY='your-key'
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')

searcher.compute_embeddings()

# Search with example papers that you are interested in
examples = [
    {
        "title": "Your paper title",
        "abstract": "Your paper abstract..."
    }
]

results = searcher.search(examples=examples, top_k=100)

# Or search with a text query
results = searcher.search(query="interesting topics", top_k=100)

searcher.display(results, n=10)
searcher.save(results, 'results.json')
```

## How It Works

1. Paper titles and abstracts are converted to embeddings
2. Embeddings are cached automatically
3. Your query is embedded using the same model
4. Cosine similarity finds the most similar papers
5. Results are ranked by similarity score (see the sketch below)
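
In code terms, the ranking step is a single cosine-similarity call followed by a sort. A minimal sketch (this mirrors the logic in `search.py` but is an illustrative rewrite, not a verbatim excerpt; `paper_embs` and `query_emb` are assumed to hold the cached paper embeddings and the embedded query):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def rank_papers(paper_embs: np.ndarray, query_emb: np.ndarray, top_k: int = 100):
    # Similarity of the query against every paper embedding -> shape (n_papers,)
    sims = cosine_similarity(query_emb.reshape(1, -1), paper_embs)[0]
    # Indices of the top_k most similar papers, best first
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [(int(i), float(sims[i])) for i in top_idx]
```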

## Cache

Embeddings are cached as `cache_<filename>_<hash>_<model>.npy`. Delete the file to recompute.
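
The cache name is derived from the paper file's stem, a short MD5 hash of its path, and the model type, as implemented in `_get_cache_file` in `search.py` (the standalone helper name below is hypothetical):

```python
import hashlib
from pathlib import Path

def cache_file_for(papers_file: str, model_type: str) -> str:
    # Same derivation as PaperSearcher._get_cache_file
    base_name = Path(papers_file).stem
    file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
    return str(Path(papers_file).parent / f"cache_{base_name}_{file_hash}_{model_type}.npy")

# cache_file_for("iclr2026_papers.json", "local")
# -> ".../cache_iclr2026_papers_<8-char-hash>_local.npy"
```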

## Example Output

```
================================================================================
Top 100 Results (showing 10)
================================================================================

1. [0.8456] Paper a
   #12345 | foundation or frontier models, including LLMs
   https://openreview.net/forum?id=xxx

2. [0.8234] Paper b
   #12346 | applications to robotics, autonomy, planning
   https://openreview.net/forum?id=yyy
```

## Tips

- Use 1-5 example papers for best results, or a paragraph describing the topic you are interested in
- The local model is good enough for most cases
- Use the OpenAI model for critical searches (~$1 for 18k queries)

If it's useful, please consider giving a star~
mcp/SearchPaperByEmbedding/crawl.py · 66 lines · Normal file
@@ -0,0 +1,66 @@
import requests
import json
import time


def fetch_submissions(venue_id, offset=0, limit=1000):
    # Query the OpenReview v2 API for one page of submissions
    url = "https://api2.openreview.net/notes"
    params = {
        "content.venueid": venue_id,
        "details": "replyCount,invitation",
        "limit": limit,
        "offset": offset,
        "sort": "number:desc"
    }
    headers = {"User-Agent": "Mozilla/5.0"}
    response = requests.get(url, params=params, headers=headers)
    response.raise_for_status()
    return response.json()


def crawl_papers(venue_id, output_file):
    all_papers = []
    offset = 0
    limit = 1000

    print(f"Fetching papers from {venue_id}...")

    # Page through the API until a short (or empty) page signals the end
    while True:
        data = fetch_submissions(venue_id, offset, limit)
        notes = data.get("notes", [])

        if not notes:
            break

        for note in notes:
            # Keep only the fields used downstream by the searcher
            paper = {
                "id": note.get("id"),
                "number": note.get("number"),
                "title": note.get("content", {}).get("title", {}).get("value", ""),
                "authors": note.get("content", {}).get("authors", {}).get("value", []),
                "abstract": note.get("content", {}).get("abstract", {}).get("value", ""),
                "keywords": note.get("content", {}).get("keywords", {}).get("value", []),
                "primary_area": note.get("content", {}).get("primary_area", {}).get("value", ""),
                "forum_url": f"https://openreview.net/forum?id={note.get('id')}"
            }
            all_papers.append(paper)

        print(f"Fetched {len(notes)} papers (total: {len(all_papers)})")

        if len(notes) < limit:
            break

        offset += limit
        time.sleep(0.5)  # be polite to the API

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(all_papers, f, ensure_ascii=False, indent=2)

    print(f"\nTotal: {len(all_papers)} papers")
    print(f"Saved to {output_file}")
    return all_papers


if __name__ == "__main__":
    crawl_papers(
        venue_id="ICLR.cc/2026/Conference/Submission",
        output_file="iclr2026_papers.json"
    )
mcp/SearchPaperByEmbedding/demo.py · 22 lines · Normal file
@@ -0,0 +1,22 @@
from search import PaperSearcher

# Use local model (free)
searcher = PaperSearcher('iclr2026_papers.json', model_type='local')

# Or use OpenAI (better quality)
# searcher = PaperSearcher('iclr2026_papers.json', model_type='openai')

searcher.compute_embeddings()

examples = [
    {
        "title": "Solid-State battery",
        "abstract": "Solid-State battery"
    },
]

results = searcher.search(examples=examples, top_k=100)

searcher.display(results, n=10)
searcher.save(results, 'results.json')
mcp/SearchPaperByEmbedding/requirements.txt · 6 lines · Normal file
@@ -0,0 +1,6 @@
requests
numpy
scikit-learn
sentence-transformers
openai
mcp/SearchPaperByEmbedding/search.py · 156 lines · Normal file
@@ -0,0 +1,156 @@
import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity


class PaperSearcher:
    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)

        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None

        if model_type == "openai":
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"

        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        base_name = Path(papers_file).stem
        file_hash = hashlib.md5(papers_file.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_file).parent / cache_name)

    def _load_cache(self):
        if os.path.exists(self.cache_file):
            try:
                self.embeddings = np.load(self.cache_file)
                if len(self.embeddings) == len(self.papers):
                    print(f"Loaded cache: {self.embeddings.shape}")
                    return True
                self.embeddings = None
            except Exception:
                self.embeddings = None
        return False

    def _save_cache(self):
        np.save(self.cache_file, self.embeddings)
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        batch_size = 100

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend([item.embedding for item in response.data])

        return np.array(embeddings)

    def _embed_local(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def compute_embeddings(self, force=False):
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings

        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]

        if self.model_type == "openai":
            self.embeddings = self._embed_openai(texts)
        else:
            self.embeddings = self._embed_local(texts)

        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        if self.embeddings is None:
            self.compute_embeddings()

        if examples:
            texts = []
            for ex in examples:
                text = f"Title: {ex['title']}"
                if ex.get('abstract'):
                    text += f" Abstract: {ex['abstract']}"
                texts.append(text)

            if self.model_type == "openai":
                embs = self._embed_openai(texts)
            else:
                embs = self._embed_local(texts)

            query_emb = np.mean(embs, axis=0).reshape(1, -1)

        elif query:
            if self.model_type == "openai":
                query_emb = self._embed_openai(query).reshape(1, -1)
            else:
                query_emb = self._embed_local(query).reshape(1, -1)
        else:
            raise ValueError("Provide either examples or query")

        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]

        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")

        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']

            print(f"{i}. [{sim:.4f}] {paper['title']}")
            print(f"   #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f"   {paper['forum_url']}\n")

    def save(self, results, output_file):
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")
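
In example mode the query vector is simply the mean of the example embeddings, ranked by cosine similarity. A tiny self-contained sketch with toy 3-D vectors (not real embeddings) shows why averaging pulls the ranking toward the direction the examples share:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy corpus of three "paper" vectors and two "example" vectors
corpus = np.array([[1.0, 0.0, 0.0],
                   [0.7, 0.7, 0.0],
                   [0.0, 0.0, 1.0]])
examples = np.array([[0.9, 0.1, 0.0],
                     [0.8, 0.3, 0.0]])

query = examples.mean(axis=0).reshape(1, -1)   # same averaging step as PaperSearcher.search
sims = cosine_similarity(query, corpus)[0]
print(np.argsort(sims)[::-1])                   # vectors sharing the examples' direction rank first
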
38
mcp/data/id_rsa.txt
Normal file
@@ -0,0 +1,38 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAYEA57lv3qJ4z66QO6uxFBnd5QFTsj9P70tO7aSEbgjczT0rgg9+48Ik
S/n2m8z4s9C4bk1mTyotJc7p13nveLo0/PAO2Y/6KiSDK4HPMEr8BeWe2RdSBVgfHNls08
2eQo/DhW5pbbybKPDI8YOyhijEOF2fDD5I5bA7QUb2Ue8cOo45aPFkFPl6E2j1u9Xaua4+
oE0syDzUvMhWZJdZqeQ//Qm1+RzB2+n4y41Ym/5YsQrL6zm4RBUrgSlx4DP6sAx1dPq2OX
5XEh6888/QVA55liNukOtOumLjMXhLe5Ut8rur3FyYHmI2jkVpBgAQOErcxvH5UeRCgIh2
vUdeAzOk0STzhKon7nIrTPek/SEaM2Kdn9y4+X4tgANTZWTf5M9ELlqthRiff2fHIf++11
v/ChqblDaPzSZ+y6myemiRLVouQbrj0Kokvqrv/lL5XzpQrAHQ1PWUB1DUXB5B8W2xsTnB
2EZQ7iH4A6VSyzrJb93xTWTjIzytn17PDH0l1JS3AAAFiOMdXKvjHVyrAAAAB3NzaC1yc2
EAAAGBAOe5b96ieM+ukDursRQZ3eUBU7I/T+9LTu2khG4I3M09K4IPfuPCJEv59pvM+LPQ
uG5NZk8qLSXO6dd573i6NPzwDtmP+iokgyuBzzBK/AXlntkXUgVYHxzZbNPNnkKPw4VuaW
28myjwyPGDsoYoxDhdnww+SOWwO0FG9lHvHDqOOWjxZBT5ehNo9bvV2rmuPqBNLMg81LzI
VmSXWankP/0Jtfkcwdvp+MuNWJv+WLEKy+s5uEQVK4EpceAz+rAMdXT6tjl+VxIevPPP0F
QOeZYjbpDrTrpi4zF4S3uVLfK7q9xcmB5iNo5FaQYAEDhK3Mbx+VHkQoCIdr1HXgMzpNEk
84SqJ+5yK0z3pP0hGjNinZ/cuPl+LYADU2Vk3+TPRC5arYUYn39nxyH/vtdb/woam5Q2j8
0mfsupsnpokS1aLkG649CqJL6q7/5S+V86UKwB0NT1lAdQ1FweQfFtsbE5wdhGUO4h+AOl
Uss6yW/d8U1k4yM8rZ9ezwx9JdSUtwAAAAMBAAEAAAGAMI2mZxvb/IgzKI2dGP0ihW11wA
+MDDPXYevq47NvsIF0sFfW2po/SLwjdBnKssK1IkeNfGD1/MoSLVgbWUyK9cTHF8cXP+VO
prsYUqIjlIi8c/hy8zO3sS/NocOfuYquCTNNW/T8/eMV96UErx+znavgO4yBcb8va0oXKq
vTWmGaneaWdd6gOZjwhF8W6XkdHjGNhJdabAP+Ni2QWAy/a6GxQ3VHGXE49E21l1n/83iz
qaH6fimBaBrrBXNev6ycObPIyyXpEbwKi6GbmMbPOiR/DrhTgptpc+TJwBLd4JnX1cqCgX
sDiSOog9bgV2okznrxAINMFnrBXD12CXZfdJCsZeDWCxnVngWGImzXk6TGbfvBbyRTIkF9
qmW1BdydGrMKQoHiphndWPlJfdRl7r2ASoUkjDSK/hXV/6iYBI5ZRmZSqihFMOQUpYxLu5
nz+WecLXZYVfAlIXlESQ3PQJ33/CnDCVqpzjtsQxRYWhA4kVaCMjPnt83LAMDheWlhAAAA
wGw2bgn9Ivw8QWSPckU7+TcmemjAVQjbcBmz9aCJlBxHtZiBXa9oQOjDh8Uw84jbiX/6sb
uzn2xArZOxWPCd2ZWKyZNodyvI6sQqb4D+xHt4aReWoU5wPDaIZpkuyWzDPSZARmy2k95z
Dq995Gl8rW2xkw/f9cTHNf6wvYdvclzKrg1mCdoBUwX1diNI2l7wsww6bDfNYMZcgX82O5
aRaIJUJltQ7CXbIow9G2BqquoEjSg6/9ZZ/B0ZWyW+5uuM+wAAAMEA/Z/HZmIuFbmNKC5m
tjXCaz9x7oTXl0v+4XMx6smQqklx1XqdXe1YSdbWxJZAhfbLmiOmQIncee/+H7m42zLsFs
kgbDtze7+qLi2+MYStd75FypvQ3h+mmYq3ppkBrAiDcJ9UrG7pWUfq+FY6CyOE5ub0mmhm
w/DW/I9so8wEi1VBzi0SqpUO6snx77yZoWJhJvlhbEGBvAS/wFIX9MBBefbf5vGMhUT+pW
xUIRvizKh/gtySXrj6WPBVoak01AatAAAAwQDp5SPKHHRO/53eC+nVSDK2fc2YWEFkSLQn
MCu+pQZv5izoyYPP8FZ4y5qw+16H2f3GbPH8xCDKlokKJlKggDhDV45eWz5UbItDt43okD
uB6v9EP4AtKKUNm+GxwhwyoY/C395fe8EvgsAlXNCAy3Wt6cVmbXW+ZSv5JRV9J0GX+5F2
K+LjNm4r/1BaLyUOf0eGTvMBc3XEBIKk7MsEBVnfxHmBJQ6fpAScimEM/VrZCbJ9OGKAiq
yRuCwKVgZviXMAAAANa29rbzEyNUBoZWFkMQECAwQFBg==
-----END OPENSSH PRIVATE KEY-----
50
mcp/lifespan.py
Normal file
@@ -0,0 +1,50 @@
import asyncssh
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from starlette.applications import Starlette

# --- 1. SSH connection parameters ---
REMOTE_HOST = '202.121.182.208'
REMOTE_USER = 'koko125'
# Make sure the path is correct; on Windows, prefer forward slashes '/'
PRIVATE_KEY_PATH = 'D:/tool/tool/id_rsa.txt'
INITIAL_WORKING_DIRECTORY = f'/cluster/home/{REMOTE_USER}/sandbox'


# --- 2. Shared context definition ---
# This dataclass holds the resources every module needs to share
@dataclass
class SharedAppContext:
    """Shared resources held for the whole server lifetime"""
    ssh_connection: asyncssh.SSHClientConnection
    sandbox_path: str


# --- 3. Lifespan manager ---
# The core of the architecture: connect over SSH at startup, disconnect at shutdown
@asynccontextmanager
async def shared_lifespan(app: Starlette) -> AsyncIterator[dict[str, SharedAppContext]]:
    """
    Manages the lifecycle of shared resources for the whole application.
    """
    print("Main app starting, establishing shared SSH connection...")
    conn = None
    try:
        # Establish the SSH connection
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH]
        )
        print(f"SSH connection to {REMOTE_HOST} succeeded!")

        # yield hands the shared resource context to the Starlette app;
        # the server pauses here and starts handling requests
        yield {"shared_context": SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)}

    finally:
        # Code after yield runs when the server shuts down
        if conn:
            conn.close()
            await conn.wait_closed()
            print("Shared SSH connection closed.")
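
Because shared_lifespan yields a dict, Starlette (0.26+) treats it as lifespan state and exposes it on request.state. A minimal sketch of a handler reaching the shared connection, assuming the module above is importable as lifespan:

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route

async def whoami(request: Request) -> PlainTextResponse:
    # Lifespan state: the dict yielded by shared_lifespan shows up on request.state
    ctx = request.state.shared_context
    result = await ctx.ssh_connection.run('whoami', check=True)
    return PlainTextResponse(result.stdout)

app = Starlette(lifespan=shared_lifespan, routes=[Route("/whoami", whoami)])
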
55
mcp/main.py
Normal file
@@ -0,0 +1,55 @@
# main.py
# Mounts the Materials CIF MCP, the System Tools MCP, and the other MCPs onto Starlette.
# Key points:
# - Start each MCP's session_manager.run() inside the lifespan (see the SDK README's
#   Starlette mounting example and streamable_http_app usage [1])
# - Use Mount to give each MCP its own sub-path (e.g. /system and /materials)

import contextlib
from starlette.applications import Starlette
from starlette.routing import Mount

from system_tools import create_system_mcp
from materialproject_mcp import create_materials_mcp
from softBV_remake import create_softbv_mcp
from paper_search_mcp import create_paper_search_mcp
from topological_analysis_models import create_topological_analysis_mcp

# Create the MCP instances
system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()
softbv_mcp = create_softbv_mcp()
paper_search_mcp = create_paper_search_mcp()
topological_analysis_mcp = create_topological_analysis_mcp()

# Start each MCP's session manager inside the Starlette lifespan
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(system_mcp.session_manager.run())
        await stack.enter_async_context(materials_mcp.session_manager.run())
        await stack.enter_async_context(softbv_mcp.session_manager.run())
        await stack.enter_async_context(paper_search_mcp.session_manager.run())
        await stack.enter_async_context(topological_analysis_mcp.session_manager.run())
        yield  # server is running
    # cleanup happens automatically on exit

# Mount each MCP's Streamable HTTP app
app = Starlette(
    lifespan=lifespan,
    routes=[
        Mount("/system", app=system_mcp.streamable_http_app()),
        Mount("/materials", app=materials_mcp.streamable_http_app()),
        Mount("/softBV", app=softbv_mcp.streamable_http_app()),
        Mount("/papersearch", app=paper_search_mcp.streamable_http_app()),
        Mount("/topologicalAnalysis", app=topological_analysis_mcp.streamable_http_app()),
    ],
)

# To start (from a terminal):
# uvicorn main:app --host 0.0.0.0 --port 8000
# Endpoints:
# http://localhost:8000/system
# http://localhost:8000/materials
# http://localhost:8000/softBV
# http://localhost:8000/papersearch
# http://localhost:8000/topologicalAnalysis
# If browser clients need access (CORS must expose Mcp-Session-Id), see the CORS example in the README [1]
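
For the browser-client case mentioned in the last comment, the usual Starlette setup is CORSMiddleware with Mcp-Session-Id in expose_headers. A sketch, with a placeholder origin and the mounted routes abbreviated:

from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware

middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3000"],  # placeholder; use the real client origin
        allow_methods=["GET", "POST", "DELETE"],
        expose_headers=["Mcp-Session-Id"],  # lets browser clients read the MCP session header
    )
]

app = Starlette(
    lifespan=lifespan,
    routes=[Mount("/system", app=system_mcp.streamable_http_app())],  # ...plus the other mounts as above
    middleware=middleware,
)
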
513
mcp/materialproject_mcp.py
Normal file
@@ -0,0 +1,513 @@
# materials_mp_cif_mcp.py
#
# Notes:
# - A FastMCP-based "Materials Project CIF" server providing:
#   1) search_cifs: queries via the Materials Project mp-api (two steps: summary filter + materials fetch for structures)
#   2) get_cif: fetches a structure by material ID and exports CIF text (pymatgen)
#   3) filter_cifs: hard local filtering of candidates
#   4) download_cif_bulk: bulk CIF download with concurrency and progress
#   5) resource: cif://mp/{id} reads CIF text directly
#
# - Dependencies: pip install "mcp[cli]" mp_api pymatgen
# - API key:
#   Read from the MP_API_KEY environment variable by default; otherwise falls back to the
#   key you provided (demo only; hardcoding is not recommended in production) [3]
# - For MCP/Context/Lifespan usage see the SDK docs (tools, structured output, progress, etc.) [1]
#
# References:
# - MPRester/MP_API_KEY usage and recommended session management: Materials Project docs [3]
# - summary.search and materials.search queries with fields restriction: docs [4]
# - FastMCP tools, resources, Context, structured output: README examples [1]

import os
import asyncio
from dataclasses import dataclass
from typing import Any, Literal
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
import re

from pydantic import BaseModel, Field
from mp_api.client import MPRester  # [3]
from pymatgen.core import Structure  # for CIF export
from mcp.server.fastmcp import FastMCP, Context  # [1]
from mcp.server.session import ServerSession  # [1]

# ========= Configuration =========

# In production, configure via the MP_API_KEY environment variable; the fallback key here
# is provided for demo purposes only (hardcoding is not recommended) [3]
DEFAULT_MP_API_KEY = os.getenv("MP_API_KEY", "X0NqhxkFeupGy7k09HRmLS6PI4VmvTBW")

# ========= Data models =========

class LatticeParameters(BaseModel):
    a: float | None = Field(default=None, description="a axis length (Å)")
    b: float | None = Field(default=None, description="b axis length (Å)")
    c: float | None = Field(default=None, description="c axis length (Å)")
    alpha: float | None = Field(default=None, description="α angle (degrees)")
    beta: float | None = Field(default=None, description="β angle (degrees)")
    gamma: float | None = Field(default=None, description="γ angle (degrees)")


class CIFCandidate(BaseModel):
    id: str = Field(description="Materials Project ID, e.g. mp-149")
    source: Literal["mp"] = Field(default="mp", description="source database")
    chemical_formula: str | None = Field(default=None, description="chemical formula (pretty)")
    space_group: str | None = Field(default=None, description="space group symbol")
    cell_parameters: LatticeParameters | None = Field(default=None, description="lattice parameters")
    bandgap: float | None = Field(default=None, description="band gap (eV)")
    stability_energy: float | None = Field(default=None, description="E_above_hull (eV)")
    magnetic_ordering: str | None = Field(default=None, description="magnetic ordering info")
    has_cif: bool = Field(default=True, description="whether a CIF can be exported (any structure suffices)")
    extras: dict[str, Any] | None = Field(default=None, description="raw document fragment (robust parsing)")


class SearchQuery(BaseModel):
    # Basic chemistry constraints
    include_elements: list[str] | None = Field(default=None, description="required element set (e.g. ['Si','O'])")  # [4]
    exclude_elements: list[str] | None = Field(default=None, description="excluded element set (e.g. ['Li'])")
    formula: str | None = Field(default=None, description="formula fragment (local second-pass filter)")

    # Structure and property constraints (summary can filter band gap and E_above_hull) [4]
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None

    # Space group and lattice ranges (filtered locally in a second pass)
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None

    magnetic_ordering: str | None = None  # local second-pass filter

    page: int = Field(default=1, ge=1, description="page number (local paging only; extend for MP endpoints yourself)")
    per_page: int = Field(default=50, ge=1, le=200, description="items per page (local paging)")
    sort_by: str | None = Field(default=None, description="sort field (bandgap/e_above_hull etc., local sort)")
    sort_order: Literal["asc", "desc"] | None = Field(default=None, description="sort direction")


class CIFText(BaseModel):
    id: str
    source: Literal["mp"]
    text: str = Field(description="CIF file text content")
    metadata: dict[str, Any] | None = Field(default=None, description="extra metadata (e.g. structure source)")


class FilterCriteria(BaseModel):
    include_elements: list[str] | None = None
    exclude_elements: list[str] | None = None
    formula: str | None = None
    space_group: str | None = None
    cell_a_min: float | None = None
    cell_a_max: float | None = None
    cell_b_min: float | None = None
    cell_b_max: float | None = None
    cell_c_min: float | None = None
    cell_c_max: float | None = None
    bandgap_min: float | None = None
    bandgap_max: float | None = None
    e_above_hull_max: float | None = None
    magnetic_ordering: str | None = None


class BulkDownloadItem(BaseModel):
    id: str
    source: Literal["mp"] = "mp"
    success: bool
    text: str | None = None
    error: str | None = None


class BulkDownloadResult(BaseModel):
    items: list[BulkDownloadItem]
    total: int
    succeeded: int
    failed: int


# ========= Lifespan =========

@dataclass
class AppContext:
    mpr: MPRester
    api_key: str


@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    # The MP_API_KEY environment variable is recommended; falls back to the provided key string here [3]
    api_key = DEFAULT_MP_API_KEY
    mpr = MPRester(api_key)  # session management lives inside mp-api [3]
    try:
        yield AppContext(mpr=mpr, api_key=api_key)
    finally:
        # mp-api does not require manual shutdown; clean up here if a close method is added later [3]
        pass


# ========= Server =========

mcp = FastMCP(
    name="Materials Project CIF",
    instructions="Searches the Materials Project API and exports CIF files.",
    lifespan=lifespan,
    streamable_http_path="/",  # makes mounting under a sub-path in Starlette easy [1]
)


# ========= Tool implementations =========

def _match_elements_set(text: str | None, required: list[str] | None, excluded: list[str] | None) -> bool:
    if text is None:
        return not required and not excluded
    s = text.upper()
    if required:
        for el in required:
            if el.upper() not in s:
                return False
    if excluded:
        for el in excluded:
            if el.upper() in s:
                return False
    return True


def _lattice_from_structure(struct: Structure | None) -> LatticeParameters | None:
    if struct is None:
        return None
    lat = struct.lattice
    return LatticeParameters(
        a=float(lat.a),
        b=float(lat.b),
        c=float(lat.c),
        alpha=float(lat.alpha),
        beta=float(lat.beta),
        gamma=float(lat.gamma),
    )


def _sg_symbol_from_doc(doc: Any) -> str | None:
    """
    Tries to extract the space group symbol from a materials.search document.
    In MP documents, symmetry usually carries spacegroup info (field paths may vary across versions) [4]
    """
    try:
        sym = getattr(doc, "symmetry", None)
        if sym and getattr(sym, "spacegroup", None):
            sg = sym.spacegroup
            # Common fields: symbol or international
            return getattr(sg, "symbol", None) or getattr(sg, "international", None)
    except Exception:
        pass
    return None


async def _to_thread(fn, *args, **kwargs):
    return await asyncio.to_thread(fn, *args, **kwargs)


def _parse_elements_from_formula(formula: str) -> list[str]:
    """Parses element symbols out of a chemical formula (simplified)."""
    if not formula:
        return []
    elems = re.findall(r"[A-Z][a-z]?", formula)
    seen, ordered = set(), []
    for e in elems:
        if e not in seen:
            seen.add(e)
            ordered.append(e)
    return ordered


@mcp.tool()
async def search_cifs(
    query: SearchQuery,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    """
    First pass via summary (with a bounded result size), then a materials fetch for the
    current page's structures and symmetry.
    Fixes:
    - Only request the current page (chunk_size/per_page + num_chunks=1), avoiding a
      full-database download that would time out
    - Never include mp-api document objects or Structure in the return value; extras
      holds only lightweight serializable fields
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr

    # Derive the element set from the formula automatically (if the user did not give one explicitly)
    include_elements = query.include_elements or _parse_elements_from_formula(query.formula or "")

    # Require at least one constraint, to avoid a full-database scan
    if not include_elements and not query.exclude_elements and query.bandgap_min is None \
            and query.bandgap_max is None and query.e_above_hull_max is None:
        raise ValueError("Provide at least one filter, e.g. include_elements, a bandgap range, or an energy_above_hull cap, to avoid timeouts.")

    # 1) summary first pass: fetch only the "current page" of documents (lightweight fields)
    summary_kwargs: dict[str, Any] = {}
    if include_elements:
        summary_kwargs["elements"] = include_elements  # MP supports element-set queries
    if query.exclude_elements:
        summary_kwargs["exclude_elements"] = query.exclude_elements
    if query.bandgap_min is not None or query.bandgap_max is not None:
        summary_kwargs["band_gap"] = (
            float(query.bandgap_min) if query.bandgap_min is not None else None,
            float(query.bandgap_max) if query.bandgap_max is not None else None,
        )
    if query.e_above_hull_max is not None:
        summary_kwargs["energy_above_hull"] = (None, float(query.e_above_hull_max))

    fields_summary = [
        "material_id",
        "formula_pretty",
        "band_gap",
        "energy_above_hull",
        "ordering",
    ]

    # Key point: bound the size of the summary result
    try:
        summary_docs = await asyncio.to_thread(
            mpr.materials.summary.search,
            fields=fields_summary,
            chunk_size=query.per_page,  # size of each chunk (page)
            num_chunks=1,               # take one chunk => only the current page's worth
            **summary_kwargs,
        )
    except TypeError:
        # If your mp-api version lacks these parameters, prompt an upgrade rather than
        # pulling back hundreds of thousands of rows and timing out
        raise RuntimeError(
            "This mp-api version does not support chunk_size/num_chunks; please upgrade mp-api (pip install -U mp_api)."
        )

    if not summary_docs:
        await ctx.debug("summary search returned no results")
        return []

    # Local sorting (if requested)
    results_all: list[CIFCandidate] = []
    for sdoc in summary_docs:
        e_hull = getattr(sdoc, "energy_above_hull", None)
        bandgap = getattr(sdoc, "band_gap", None)
        ordering = getattr(sdoc, "ordering", None)

        # Do not stash structure/doc into extras, to avoid serialization errors
        results_all.append(
            CIFCandidate(
                id=sdoc.material_id,
                source="mp",
                chemical_formula=getattr(sdoc, "formula_pretty", None),
                space_group=None,
                cell_parameters=None,
                bandgap=float(bandgap) if bandgap is not None else None,
                stability_energy=float(e_hull) if e_hull is not None else None,
                magnetic_ordering=ordering,
                has_cif=True,
                extras={
                    "summary": {
                        "material_id": sdoc.material_id,
                        "formula_pretty": getattr(sdoc, "formula_pretty", None),
                        "band_gap": float(bandgap) if bandgap is not None else None,
                        "energy_above_hull": float(e_hull) if e_hull is not None else None,
                        "ordering": ordering,
                    }
                },
            )
        )

    # Local paging: summary already bounded the row count; only sort and truncate defensively here
    if query.sort_by:
        reverse = query.sort_order == "desc"
        key_map = {
            "bandgap": lambda c: (c.bandgap if c.bandgap is not None else float("inf")),
            "e_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
            "energy_above_hull": lambda c: (c.stability_energy if c.stability_energy is not None else float("inf")),
        }
        if query.sort_by in key_map:
            results_all.sort(key=key_map[query.sort_by], reverse=reverse)

    page_slice = results_all[: query.per_page]
    page_ids = [c.id for c in page_slice]

    # 2) Fetch structures and symmetry only for the current page's materials
    if not page_ids:
        return []

    fields_materials = ["material_id", "structure", "symmetry"]
    materials_docs = await asyncio.to_thread(
        mpr.materials.search,
        material_ids=page_ids,
        fields=fields_materials,
    )
    mat_by_id = {doc.material_id: doc for doc in materials_docs}

    # Merge in structure/space group, still returning no unserializable objects
    merged: list[CIFCandidate] = []
    for c in page_slice:
        mdoc = mat_by_id.get(c.id)
        struct = getattr(mdoc, "structure", None) if mdoc else None

        # Space group symbol
        sg = None
        try:
            sym = getattr(mdoc, "symmetry", None)
            if sym and getattr(sym, "spacegroup", None):
                sgobj = sym.spacegroup
                sg = getattr(sgobj, "symbol", None) or getattr(sgobj, "international", None)
        except Exception:
            pass

        c.space_group = sg
        c.cell_parameters = _lattice_from_structure(struct)
        c.has_cif = struct is not None
        # Keep extras as lightweight JSON
        c.extras = {
            **(c.extras or {}),
            "materials": {"has_structure": struct is not None},
        }
        merged.append(c)

    await ctx.debug(f"Returning {len(merged)} rows (capped at {query.per_page} per page)")
    return merged


@mcp.tool()
async def get_cif(
    id: str,
    ctx: Context[ServerSession, AppContext],
) -> CIFText:
    """
    Fetches the structure via materials.search and exports CIF text.
    - fields=["structure"] for speed [4]
    - Structure-to-CIF via pymatgen's Structure.to(fmt="cif")
    """
    app = ctx.request_context.lifespan_context
    mpr = app.mpr

    await ctx.info(f"Fetching structure and exporting CIF: {id}")
    docs = await _to_thread(mpr.materials.search, material_ids=[id], fields=["material_id", "structure"])  # [4]
    if not docs:
        raise ValueError(f"Material not found: {id}")

    doc = docs[0]
    struct: Structure | None = getattr(doc, "structure", None)
    if struct is None:
        raise ValueError(f"Material {id} has no structure data; cannot export CIF")

    cif_text = struct.to(fmt="cif")  # emits CIF text directly
    return CIFText(id=id, source="mp", text=cif_text, metadata={"from": "materials.search"})


@mcp.tool()
async def filter_cifs(
    candidates: list[CIFCandidate],
    criteria: FilterCriteria,
    ctx: Context[ServerSession, AppContext],
) -> list[CIFCandidate]:
    def in_range(v: float | None, vmin: float | None, vmax: float | None) -> bool:
        if v is None:
            return vmin is None and vmax is None
        if vmin is not None and v < vmin:
            return False
        if vmax is not None and v > vmax:
            return False
        return True

    filtered: list[CIFCandidate] = []
    for c in candidates:
        # Approximate element-set match (based on the formula string)
        if not _match_elements_set(c.chemical_formula, criteria.include_elements, criteria.exclude_elements):
            continue
        # Formula containment check
        if criteria.formula:
            if not c.chemical_formula or criteria.formula.upper() not in c.chemical_formula.upper():
                continue
        # Space group
        if criteria.space_group:
            spg = (c.space_group or "").lower()
            if criteria.space_group.lower() not in spg:
                continue
        # Lattice ranges
        cp = c.cell_parameters or LatticeParameters()
        if not in_range(cp.a, criteria.cell_a_min, criteria.cell_a_max):
            continue
        if not in_range(cp.b, criteria.cell_b_min, criteria.cell_b_max):
            continue
        if not in_range(cp.c, criteria.cell_c_min, criteria.cell_c_max):
            continue
        # Band gap and stability
        if not in_range(c.bandgap, criteria.bandgap_min, criteria.bandgap_max):
            continue
        if criteria.e_above_hull_max is not None:
            if c.stability_energy is None or c.stability_energy > criteria.e_above_hull_max:
                continue
        # Magnetism
        if criteria.magnetic_ordering:
            mo = ""  # not populated yet; kept for compatibility
            if criteria.magnetic_ordering.upper() not in mo.upper():
                continue

        filtered.append(c)

    await ctx.debug(f"After filtering: {len(filtered)} / original {len(candidates)}")
    return filtered


@mcp.tool()
async def download_cif_bulk(
    ids: list[str],
    concurrency: int = 4,
    ctx: Context[ServerSession, AppContext] | None = None,
) -> BulkDownloadResult:
    semaphore = asyncio.Semaphore(max(1, concurrency))
    items: list[BulkDownloadItem] = []

    async def worker(mid: str, idx: int) -> None:
        nonlocal items
        try:
            async with semaphore:
                # Reuse the current ctx
                assert ctx is not None
                cif = await get_cif(id=mid, ctx=ctx)
                items.append(BulkDownloadItem(id=mid, success=True, text=cif.text))
        except Exception as e:
            items.append(BulkDownloadItem(id=mid, success=False, error=str(e)))
        if ctx is not None:
            progress = (idx + 1) / len(ids) if len(ids) > 0 else 1.0
            await ctx.report_progress(progress=progress, total=1.0, message=f"{idx + 1}/{len(ids)}")  # [1]

    tasks = [worker(mid, i) for i, mid in enumerate(ids)]
    await asyncio.gather(*tasks)

    succeeded = sum(1 for it in items if it.success)
    failed = len(items) - succeeded
    return BulkDownloadResult(items=items, total=len(items), succeeded=succeeded, failed=failed)


# ========= Resource: cif://mp/{id} =========

@mcp.resource("cif://mp/{id}")
async def cif_resource(id: str) -> str:
    """
    URI-based CIF content read. Resource functions cannot reach the Context directly,
    so this simply opens a fresh MPRester session, reads the structure, and converts
    it to CIF (resource-path use only).
    """
    # One-off session inside the resource (does not reuse the lifespan), fine for occasional reads [1]
    mpr = MPRester(DEFAULT_MP_API_KEY)  # [3]
    try:
        docs = mpr.materials.search(material_ids=[id], fields=["material_id", "structure"])  # [4]
        if not docs or getattr(docs[0], "structure", None) is None:
            raise ValueError(f"Material {id} has no structure data")
        struct: Structure = docs[0].structure
        return struct.to(fmt="cif")
    finally:
        pass


def create_materials_mcp() -> FastMCP:
    """Factory function for mounting under Starlette."""
    return mcp


if __name__ == "__main__":
    # Standalone run (not mounted), using the Streamable HTTP transport [1]
    mcp.run(transport="streamable-http")
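
To make the simplified formula parser's behavior concrete: _parse_elements_from_formula scans capital/lowercase pairs and de-duplicates in order. It is purely lexical, so it cannot distinguish e.g. "CO" (carbon plus oxygen) from "Co" (cobalt):

print(_parse_elements_from_formula("LiFePO4"))  # ['Li', 'Fe', 'P', 'O']
print(_parse_elements_from_formula("SiO2"))     # ['Si', 'O']
print(_parse_elements_from_formula("Co"))       # ['Co'] -- read lexically as cobalt
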
172
mcp/paper_search_mcp.py
Normal file
@@ -0,0 +1,172 @@
# paper_search_mcp.py (refactored)

import os
import asyncio
from dataclasses import dataclass
from typing import Any
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from pydantic import BaseModel, Field
# Assumes your search.py provides the PaperSearcher class
from SearchPaperByEmbedding.search import PaperSearcher

from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession

# ========= 1. Configuration and constants (fixed) =========

# Absolute path of the directory containing this script
_CURRENT_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


# Build absolute resource paths from the project root
DATA_DIR = os.path.join(_CURRENT_SCRIPT_DIR, "SearchPaperByEmbedding")
PAPERS_JSON_PATH = os.path.join(DATA_DIR, "iclr2026_papers.json")

# Print the path for debugging, to confirm it is correct
print(f"DEBUG: Trying to load papers from: {PAPERS_JSON_PATH}")

MODEL_TYPE = os.getenv("EMBEDDING_MODEL_TYPE", "local")


# ========= 2. Data models =========

class Paper(BaseModel):
    """Details of a single paper"""
    title: str
    authors: list[str]
    abstract: str
    pdf_url: str | None = None
    similarity_score: float = Field(description="similarity score against the query; higher is more relevant")


class SearchQuery(BaseModel):
    """Query structure for paper search"""
    title: str
    abstract: str | None = Field(None, description="optional abstract, for richer context")


# ========= 3. Lifespan (resource loading and management) =========

@dataclass
class PaperSearchContext:
    """The initialized PaperSearcher instance shared for the server's lifetime"""
    searcher: PaperSearcher


@asynccontextmanager
async def paper_search_lifespan(_server: FastMCP) -> AsyncIterator[PaperSearchContext]:
    """
    FastMCP lifespan: initializes PaperSearcher and computes embeddings at server startup.
    This guarantees the most expensive step runs exactly once, at startup.
    """
    print(f"Initializing the paper search engine with the '{MODEL_TYPE}' model...")

    # Run the synchronous initialization code in a thread
    def _initialize_searcher():
        # 1. Initialize
        searcher = PaperSearcher(
            PAPERS_JSON_PATH,
            model_type=MODEL_TYPE
        )
        # 2. Compute/load the embeddings
        print("Computing or loading paper embeddings...")
        searcher.compute_embeddings()
        print("Embeddings loaded.")
        return searcher

    searcher = await asyncio.to_thread(_initialize_searcher)
    print("Paper search engine ready!")

    try:
        # yield hands the shared resource context to the MCP
        yield PaperSearchContext(searcher=searcher)
    finally:
        print("Paper search engine service shut down.")


# ========= 4. MCP server instance =========

mcp = FastMCP(
    name="Paper Search",
    instructions="Tool for retrieving ICLR 2026 papers by embedding similarity.",
    lifespan=paper_search_lifespan,
    streamable_http_path="/",
)


# ========= 5. Tool implementation =========
@mcp.tool()
async def search_papers(
    queries: list[SearchQuery],
    top_k: int = 3,
    ctx: Context[ServerSession, PaperSearchContext] | None = None,
) -> list[Paper]:
    """
    Searches papers by semantic similarity against one or more queries
    (each a title plus an optional abstract). Returns the top_k most relevant results.
    """
    if ctx is None:
        raise ValueError("Context is required for this operation.")
    if not queries:
        return []

    app_ctx = ctx.request_context.lifespan_context
    searcher = app_ctx.searcher

    # Preprocess the queries so title and abstract are both present
    examples_to_search = []
    for q in queries:
        query_dict = q.model_dump(exclude_none=True)
        if "title" not in query_dict and "abstract" in query_dict:
            query_dict["title"] = query_dict["abstract"]
        elif "abstract" not in query_dict and "title" in query_dict:
            query_dict["abstract"] = query_dict["title"]
        examples_to_search.append(query_dict)

    if not examples_to_search:
        return []

    query_display = queries[0].title or queries[0].abstract or "Empty Query"
    await ctx.info(f"Searching for the top {top_k} papers with query '{query_display}' (and any further queries)...")

    # Call the underlying search method
    results = await asyncio.to_thread(
        searcher.search,
        examples=examples_to_search,
        top_k=top_k
    )

    # --- Key fix: parse according to the structure search.py actually returns ---
    papers = []
    for res in results:
        paper_data = res.get('paper', {})        # the nested paper-info dict
        similarity = res.get('similarity', 0.0)  # the similarity score

        papers.append(
            Paper(
                title=paper_data.get("title", "N/A"),
                authors=paper_data.get("authors", []),
                abstract=paper_data.get("abstract", ""),
                # Note: in the raw data the pdf url lives under the 'forum_url' key
                pdf_url=paper_data.get("forum_url"),
                similarity_score=similarity  # the correct similarity score
            )
        )

    await ctx.debug(f"Found {len(papers)} relevant papers.")
    return papers


# ========= 6. Factory function (for Starlette integration) =========

def create_paper_search_mcp() -> FastMCP:
    """Factory function for mounting under Starlette."""
    return mcp


if __name__ == "__main__":
    # Running this file directly starts a standalone MCP server
    mcp.run(transport="streamable-http")
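
A minimal client-side sketch for calling search_papers over the mounted Streamable HTTP endpoint, following the client pattern in the MCP Python SDK README. The URL assumes the mounts in main.py above; the exact endpoint path and client import location may differ across SDK versions, so treat this as an outline rather than a verified call:

import asyncio
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

async def main():
    # /papersearch is the mount path from main.py; streamable_http_path is "/"
    async with streamablehttp_client("http://localhost:8000/papersearch/") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool(
                "search_papers",
                {"queries": [{"title": "solid-state batteries"}], "top_k": 5},
            )
            print(result)

asyncio.run(main())
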
221
mcp/server.py
Normal file
@@ -0,0 +1,221 @@
import asyncssh
import posixpath
import shlex
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass

from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession

# --- 1. Connection parameters you provided ---
REMOTE_HOST = '202.121.182.208'
REMOTE_USER = 'koko125'
PRIVATE_KEY_PATH = 'D:/tool/tool/id_rsa.txt'  # prefer / in Windows paths
INITIAL_WORKING_DIRECTORY = '/cluster/home/koko125/sandbox'


# --- 2. Lifecycle management (advanced) ---

@dataclass
class AppContext:
    """Context object holding shared resources for the server's lifetime"""
    ssh_connection: asyncssh.SSHClientConnection
    current_path: str


@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
    """Establishes the SSH connection at server startup and closes it at shutdown."""
    print("Server starting, establishing SSH connection...")
    conn = None
    try:
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH]
        )
        print("SSH connection succeeded!")

        # <<< source the bashrc on login >>>
        # Note: bashrc effects (e.g. environment variables) only apply within this one
        # conn.run() call. For non-interactive shells, the better approach is to chain
        # the commands you need into a single invocation; the commands below do exactly
        # that where it matters (e.g. for paths).
        print("Loading .bashrc...")
        # The bashrc only takes effect within an interactive shell invocation
        await conn.run('source /cluster/home/koko125/.bashrc', check=True)
        print(".bashrc loaded.")

        print(f"Initial working directory set to: {INITIAL_WORKING_DIRECTORY}")
        yield AppContext(ssh_connection=conn, current_path=INITIAL_WORKING_DIRECTORY)
    finally:
        if conn:
            conn.close()
            await conn.wait_closed()
            print("SSH connection closed.")


# --- 3. Create the MCP server instance with the lifespan attached ---
mcp = FastMCP("Remote server tools", lifespan=app_lifespan, port=8000)


# --- 4. Retained debugging tool (unsafe) ---

@mcp.tool()
async def execute_remote_command(command: str, ctx: Context) -> str:
    """
    [Debug only] Runs an arbitrary shell command on the remote server and returns its output.
    Warning: this tool is a security risk; do not use it in production.
    """
    await ctx.info(f"About to run debug command on {REMOTE_HOST}: '{command}'")
    try:
        # Open a fresh connection every time, for environment isolation
        async with asyncssh.connect(REMOTE_HOST, username=REMOTE_USER, client_keys=[PRIVATE_KEY_PATH]) as conn:
            # source the bashrc before running the command
            full_command = f'source /cluster/home/koko125/.bashrc && {command}'
            result = await conn.run(full_command, check=True)
            output = result.stdout.strip()
            await ctx.info("Debug command executed successfully.")
            return output
    except asyncssh.ProcessError as e:
        error_message = f"Command failed: {e.stderr.strip()}"
        await ctx.error(error_message)
        return error_message
    except Exception as e:
        error_message = f"Unknown error: {str(e)}"
        await ctx.error(error_message)
        return error_message


# --- 5. New, well-behaved, safe tools ---

@mcp.tool()
async def list_files_in_sandbox(ctx: Context[ServerSession, AppContext], path: str = ".") -> str:
    """
    [Safe] Lists files and directories under the remote server's sandbox path.

    Args:
        path: relative path to inspect; defaults to the current sandbox directory.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection

        # Prevent directory traversal attacks
        if ".." in path.split('/'):
            return "Error: using '..' to reach parent directories is not allowed."

        # Join the path safely (remote paths are POSIX, so use posixpath rather than os.path)
        target_path = posixpath.join(app_ctx.current_path, path)

        await ctx.info(f"Listing sandbox path: {target_path}")

        # source the bashrc so its settings (e.g. aliases) take effect here too;
        # quote the remote path with shlex.quote before shell interpolation
        command_to_run = f'source /cluster/home/koko125/.bashrc && ls -l {shlex.quote(target_path)}'
        result = await conn.run(command_to_run, check=True)
        return result.stdout.strip()
    except Exception as e:
        await ctx.error(f"Running ls failed: {e}")
        return f"Error: {e}"


@mcp.tool()
async def create_directory_in_sandbox(ctx: Context[ServerSession, AppContext], directory_name: str) -> str:
    """
    [Safe] Creates a new directory inside the remote server's sandbox.

    Args:
        directory_name: name of the directory to create.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection

        # Safety check: the name must not contain slashes or "..", to prevent creating
        # deep directory trees or directory traversal
        if '/' in directory_name or ".." in directory_name:
            return "Error: invalid directory name. It must not contain '/' or '..'."

        target_path = posixpath.join(app_ctx.current_path, directory_name)

        await ctx.info(f"Creating remote directory: {target_path}")

        command_to_run = f'source /cluster/home/koko125/.bashrc && mkdir {shlex.quote(target_path)}'
        await conn.run(command_to_run, check=True)

        return f"Directory '{directory_name}' created in the sandbox."
    except asyncssh.ProcessError as e:
        error_message = f"Creating the directory failed: {e.stderr.strip()}"
        await ctx.error(error_message)
        return error_message
    except Exception as e:
        await ctx.error(f"Unknown error while creating the directory: {e}")
        return f"Error: {e}"


@mcp.tool()
async def read_file_in_sandbox(ctx: Context[ServerSession, AppContext], file_path: str) -> str:
    """
    [Safe] Reads the contents of a file inside the remote server's sandbox.

    Args:
        file_path: file path relative to the sandbox directory.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection

        # Safety check: prevent directory traversal
        if ".." in file_path.split('/'):
            return "Error: using '..' to reach parent directories is not allowed."

        target_path = posixpath.join(app_ctx.current_path, file_path)

        await ctx.info(f"Reading remote file: {target_path}")

        async with conn.start_sftp_client() as sftp:
            async with sftp.open(target_path, 'r') as f:
                content = await f.read()
                return content

    except FileNotFoundError:
        error_message = f"Reading the file failed: '{file_path}' does not exist."
        await ctx.error(error_message)
        return error_message
    except Exception as e:
        await ctx.error(f"Unknown error while reading the file: {e}")
        return f"Error: {e}"


@mcp.tool()
async def write_file_in_sandbox(ctx: Context[ServerSession, AppContext], file_path: str, content: str) -> str:
    """
    [Safe] Writes content to a file inside the remote server's sandbox, overwriting it if it exists.

    Args:
        file_path: file path relative to the sandbox directory.
        content: text content to write.
    """
    try:
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        normalized_file_path = file_path.replace("\\", "/")
        # Safety check (on the normalized path, so backslash-separated '..' is caught too)
        if ".." in normalized_file_path.split('/'):
            return "Error: using '..' to reach parent directories is not allowed."

        target_path = posixpath.join(app_ctx.current_path, normalized_file_path)

        await ctx.info(f"Writing remote file: {target_path}")

        async with conn.start_sftp_client() as sftp:
            async with sftp.open(target_path, 'w') as f:
                await f.write(content)

        return f"Content written to sandbox file '{file_path}'."
    except Exception as e:
        await ctx.error(f"Unknown error while writing the file: {e}")
        return f"Error: {e}"


# --- Run the server ---
if __name__ == "__main__":
    mcp.run(transport="streamable-http")
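
The sandbox tools interpolate remote paths into a shell command string, so the paths are quoted with shlex.quote first. A quick illustration of what that quoting prevents:

import shlex

# An attacker-controlled name must not break out of the ls argument
path = "notes; rm -rf ~"
print(f"ls -l {path}")               # ls -l notes; rm -rf ~   <- the second command would run
print(f"ls -l {shlex.quote(path)}")  # ls -l 'notes; rm -rf ~' <- one inert argument
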
464
mcp/softBV.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# softbv_mcp.py
|
||||
#
|
||||
# 在远程服务器上激活 softBV 环境并执行计算(支持 --md 与 --gen-cube 两个专用工具)。
|
||||
# - 生命周期:建立 SSH 连接并注入上下文
|
||||
# - 激活环境:source /cluster/home/koko125/script/softBV.sh
|
||||
# - 工作目录:/cluster/home/koko125/sandbox
|
||||
# - 可执行文件:/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x
|
||||
#
|
||||
# 依赖:
|
||||
# pip install "mcp[cli]" asyncssh pydantic
|
||||
#
|
||||
# 用法(Starlette 挂载示例见你现有 main.py,导入 create_softbv_mcp 即可):
|
||||
# from softbv_mcp import create_softbv_mcp
|
||||
# softbv_mcp = create_softbv_mcp()
|
||||
# Mount("/softbv", app=softbv_mcp.streamable_http_app())
|
||||
#
|
||||
# 可通过环境变量覆盖连接信息:
|
||||
# REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH, SOFTBV_PROFILE, SOFTBV_BIN, DEFAULT_WORKDIR
|
||||
|
||||
import os
|
||||
import posixpath
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from socket import socket
|
||||
from typing import Any
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncssh
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.server.fastmcp import FastMCP, Context
|
||||
from mcp.server.session import ServerSession
|
||||
|
||||
from lifespan import REMOTE_HOST, REMOTE_USER, PRIVATE_KEY_PATH
|
||||
|
||||
def shell_quote(arg: str) -> str:
|
||||
"""安全地把字符串作为单个 shell 参数(POSIX)。"""
|
||||
return "'" + str(arg).replace("'", "'\"'\"'") + "'"
|
||||
|
||||
# 如果你已经定义了这些常量与路径,可保持复用;也可按需改为你自己的配置源
|
||||
|
||||
|
||||
# 固定的 softBV 环境信息(可用环境变量覆盖)
|
||||
SOFTBV_PROFILE = os.getenv("SOFTBV_PROFILE", "/cluster/home/koko125/script/softBV.sh")
|
||||
SOFTBV_BIN = os.getenv("SOFTBV_BIN", "/cluster/home/koko125/tool/softBV-GUI_linux/bin/softBV.x")
|
||||
DEFAULT_WORKDIR = os.getenv("DEFAULT_WORKDIR", "/cluster/home/koko125/sandbox")
|
||||
|
||||
@dataclass
|
||||
class SoftBVContext:
|
||||
ssh_connection: asyncssh.SSHClientConnection
|
||||
workdir: str
|
||||
profile: str
|
||||
bin_path: str
|
||||
|
||||
@asynccontextmanager
|
||||
async def softbv_lifespan(_server) -> AsyncIterator[SoftBVContext]:
|
||||
"""
|
||||
FastMCP 生命周期:建立 SSH 连接并注入 softBV 上下文。
|
||||
- 不再做额外的 DNS 解析或自定义异步步骤,避免 socket.SOCK_STREAM 的环境覆盖问题
|
||||
- 仅负责连接与清理;工具中通过 ctx.request_context.lifespan_context 访问该上下文 [1]
|
||||
"""
|
||||
# 允许用环境变量覆盖连接信息
|
||||
host = os.getenv("REMOTE_HOST", REMOTE_HOST)
|
||||
user = os.getenv("REMOTE_USER", REMOTE_USER)
|
||||
key_path = os.getenv("PRIVATE_KEY_PATH", PRIVATE_KEY_PATH)
|
||||
|
||||
conn: asyncssh.SSHClientConnection | None = None
|
||||
try:
|
||||
conn = await asyncssh.connect(
|
||||
host,
|
||||
username=user,
|
||||
client_keys=[key_path],
|
||||
known_hosts=None, # 如需主机指纹校验,可移除此参数
|
||||
connect_timeout=15, # 避免长时间挂起
|
||||
)
|
||||
yield SoftBVContext(
|
||||
ssh_connection=conn,
|
||||
workdir=DEFAULT_WORKDIR,
|
||||
profile=SOFTBV_PROFILE,
|
||||
bin_path=SOFTBV_BIN,
|
||||
)
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
async def run_in_softbv_env(
|
||||
conn: asyncssh.SSHClientConnection,
|
||||
profile_path: str,
|
||||
cmd: str,
|
||||
cwd: str | None = None,
|
||||
check: bool = True,
|
||||
) -> asyncssh.SSHCompletedProcess:
|
||||
"""
|
||||
在远端 bash 会话中激活 softBV 环境并执行 cmd。
|
||||
如提供 cwd,则先 cd 到该目录。
|
||||
"""
|
||||
parts = []
|
||||
if cwd:
|
||||
parts.append(f"cd {shell_quote(cwd)}")
|
||||
parts.append(f"source {shell_quote(profile_path)}")
|
||||
parts.append(cmd)
|
||||
composite = "; ".join(parts)
|
||||
full = f"bash -lc {shell_quote(composite)}"
|
||||
return await conn.run(full, check=check)
|
||||
|
||||
# ===== MCP 服务器 =====
|
||||
mcp = FastMCP(
|
||||
name="softBV Tools",
|
||||
instructions="在远程服务器上激活 softBV 环境并执行相关计算的工具集。",
|
||||
lifespan=softbv_lifespan,
|
||||
streamable_http_path="/",
|
||||
)
|
||||
|
||||
# ===== 辅助:列目录用于识别新生成文件 =====
|
||||
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
|
||||
async with conn.start_sftp_client() as sftp:
|
||||
try:
|
||||
return await sftp.listdir(path)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# ===== 结构化输入模型:--md =====
|
||||
class SoftBVMDArgs(BaseModel):
|
||||
input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --md 的输入")
|
||||
# 位置参数(按帮助中的顺序;None 表示不提供,让程序使用默认)
|
||||
type: str | None = Field(default=None, description="conducting ion 类型(例如 'Li')")
|
||||
os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
|
||||
sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
|
||||
temperature: float | None = Field(default=None, description="温度 K(非正值默认 300)")
|
||||
t_end: float | None = Field(default=None, description="生产时间 ps(非正值默认 10.0)")
|
||||
t_equil: float | None = Field(default=None, description="平衡时间 ps(非正值默认 2.0)")
|
||||
dt: float | None = Field(default=None, description="时间步长 ps(非正值默认 0.001)")
|
||||
t_log: float | None = Field(default=None, description="采样间隔 ps(非正值每 100 步)")
|
||||
cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
|
||||
|
||||
def _build_md_cmd(bin_path: str, args: SoftBVMDArgs, workdir: str) -> str:
|
||||
input_abs = args.input_cif
|
||||
if not input_abs.startswith("/"):
|
||||
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||||
parts: list[str] = [shell_quote(bin_path), shell_quote("--md"), shell_quote(input_abs)]
|
||||
for val in [args.type, args.os, args.sf, args.temperature, args.t_end, args.t_equil, args.dt, args.t_log]:
|
||||
if val is not None:
|
||||
parts.append(shell_quote(str(val)))
|
||||
return " ".join(parts)
|
||||
|
||||
# ===== 结构化输入模型:--gen-cube =====
|
||||
class SoftBVGenCubeArgs(BaseModel):
|
||||
input_cif: str = Field(description="远程 CIF 文件路径(相对或绝对),作为 --gen-cube 的输入")
|
||||
type: str | None = Field(default=None, description="conducting ion 类型(如 'Li')")
|
||||
os: int | None = Field(default=None, description="conducting ion 氧化态(整数)")
|
||||
sf: float | None = Field(default=None, description="screening factor(非正值使用默认)")
|
||||
resolution: float | None = Field(default=None, description="体素分辨率(默认约 0.1)")
|
||||
ignore_conducting_ion: bool = Field(default=False, description="flag:ignore_conducting_ion")
|
||||
periodic: bool = Field(default=True, description="flag:periodic(默认 True)")
|
||||
output_name: str | None = Field(default=None, description="输出文件名前缀(可选)")
|
||||
cwd: str | None = Field(default=None, description="远程工作目录(默认使用生命周期中的 workdir)")
|
||||
|
||||
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str) -> str:
|
||||
input_abs = args.input_cif
|
||||
if not input_abs.startswith("/"):
|
||||
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||||
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
|
||||
for val in [args.type, args.os, args.sf, args.resolution]:
|
||||
if val is not None:
|
||||
parts.append(shell_quote(str(val)))
|
||||
if args.ignore_conducting_ion:
|
||||
parts.append(shell_quote("--flag:ignore_conducting_ion"))
|
||||
if args.periodic:
|
||||
parts.append(shell_quote("--flag:periodic"))
|
||||
if args.output_name:
|
||||
parts.append(shell_quote(args.output_name))
|
||||
return " ".join(parts)
|
||||
|
||||
# ===== 工具:环境信息检查 =====
|
||||
# 工具:环境信息检查(修复版,避免超时)
|
||||
@mcp.tool()
|
||||
async def softbv_info(ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||||
"""
|
||||
快速自检:
|
||||
- SFTP 检查工作目录、激活脚本、二进制是否存在/可执行(无需运行 softBV.x)
|
||||
- 激活环境后仅输出标记与当前工作目录,避免长输出或阻塞
|
||||
"""
|
||||
app = ctx.request_context.lifespan_context
|
||||
conn = app.ssh_connection
|
||||
|
||||
# 1) 通过 SFTP 快速检查文件与目录状态(不会长时间阻塞)
|
||||
def stat_path_safe(path: str) -> dict[str, Any]:
|
||||
return {"exists": False, "is_exec": False, "size": None}
|
||||
|
||||
workdir_info = stat_path_safe(app.workdir)
|
||||
profile_info = stat_path_safe(app.profile)
|
||||
bin_info = stat_path_safe(app.bin_path)
|
||||
|
||||
try:
|
||||
async with conn.start_sftp_client() as sftp:
|
||||
# workdir
|
||||
try:
|
||||
attrs = await sftp.stat(app.workdir)
|
||||
workdir_info["exists"] = True
|
||||
workdir_info["size"] = int(attrs.size or 0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# profile
|
||||
try:
|
||||
attrs = await sftp.stat(app.profile)
|
||||
profile_info["exists"] = True
|
||||
profile_info["size"] = int(attrs.size or 0)
|
||||
perms = int(attrs.permissions or 0)
|
||||
profile_info["is_exec"] = bool(perms & 0o111)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# bin
|
||||
try:
|
||||
attrs = await sftp.stat(app.bin_path)
|
||||
bin_info["exists"] = True
|
||||
bin_info["size"] = int(attrs.size or 0)
|
||||
perms = int(attrs.permissions or 0)
|
||||
bin_info["is_exec"] = bool(perms & 0o111)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
await ctx.warning(f"SFTP 检查失败: {e}")
|
||||
|
||||
# 2) 激活环境并做极简命令(避免 softBV.x --help 的长输出)
|
||||
# 仅返回当前用户、PWD 与二进制可执行判断;不实际运行 softBV.x
|
||||
cmd = "echo __SOFTBV_READY__ && echo $USER && pwd && (test -x " + shell_quote(app.bin_path) + " && echo __BIN_OK__ || echo __BIN_NOT_EXEC__)"
|
||||
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=app.workdir, check=False)
|
||||
|
||||
# 解析输出行
|
||||
lines = proc.stdout.splitlines() if proc.stdout else []
|
||||
ready = "__SOFTBV_READY__" in lines
|
||||
user = None
|
||||
pwd = None
|
||||
bin_ok = "__BIN_OK__" in lines
|
||||
|
||||
# 尝试定位 user/pwd(ready 之后的两行)
|
||||
if ready:
|
||||
idx = lines.index("__SOFTBV_READY__")
|
||||
if len(lines) > idx + 1:
|
||||
user = lines[idx + 1].strip()
|
||||
if len(lines) > idx + 2:
|
||||
pwd = lines[idx + 2].strip()
|
||||
|
||||
result = {
|
||||
"host": os.getenv("REMOTE_HOST", REMOTE_HOST),
|
||||
"user": os.getenv("REMOTE_USER", REMOTE_USER),
|
||||
"workdir": app.workdir,
|
||||
"profile": app.profile,
|
||||
"bin_path": app.bin_path,
|
||||
"sftp_check": {
|
||||
"workdir": workdir_info,
|
||||
"profile": profile_info,
|
||||
"bin": bin_info,
|
||||
},
|
||||
"activate_ready": ready,
|
||||
"pwd": pwd,
|
||||
"bin_is_executable": bin_ok or bin_info["is_exec"],
|
||||
"exit_status": proc.exit_status,
|
||||
"stderr_head": "\n".join(proc.stderr.splitlines()[:10]) if proc.stderr else "",
|
||||
}
|
||||
|
||||
# 友好日志
|
||||
if not ready:
|
||||
await ctx.warning("softBV 环境未就绪(可能 source 脚本路径问题或权限不足)")
|
||||
if not result["bin_is_executable"]:
|
||||
await ctx.warning("softBV 二进制不可执行或不存在,请检查 bin_path 与权限(chmod +x)")
|
||||
if proc.exit_status != 0 and not proc.stderr:
|
||||
await ctx.debug("命令非零退出但无 stderr,可能是某些子测试返回非零导致")
|
||||
|
||||
return result
|
||||
# ===== 工具:--md =====
|
||||
@mcp.tool()
|
||||
async def softbv_md(req: SoftBVMDArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
|
||||
"""
|
||||
执行 softBV.x --md,返回结构化结果:
|
||||
- cmd/cwd/exit_status/stdout/stderr
|
||||
- new_files:执行后新增文件列表,便于定位输出
|
||||
"""
|
||||
app = ctx.request_context.lifespan_context
|
||||
conn = app.ssh_connection
|
||||
workdir = req.cwd or app.workdir
|
||||
cmd = _build_md_cmd(app.bin_path, req, workdir)
|
||||
|
||||
await ctx.info(f"softBV --md 执行: {cmd} (cwd={workdir})")
|
||||
pre_list = await _listdir(conn, workdir)
|
||||
|
||||
proc = await run_in_softbv_env(conn, app.profile, cmd=cmd, cwd=workdir, check=False)
|
||||
|
||||
post_list = await _listdir(conn, workdir)
|
||||
new_files = sorted(set(post_list) - set(pre_list))
|
||||
if proc.exit_status == 0:
|
||||
await ctx.debug(f"--md 成功,新文件 {len(new_files)} 个")
|
||||
else:
|
||||
await ctx.warning(f"--md 非零退出: {proc.exit_status}")
|
||||
|
||||
return {
|
||||
"cmd": cmd,
|
||||
"cwd": workdir,
|
||||
"exit_status": proc.exit_status,
|
||||
"stdout": proc.stdout,
|
||||
"stderr": proc.stderr,
|
||||
"new_files": new_files,
|
||||
}
|
||||
|
||||
async def run_in_softbv_env_stream(
|
||||
conn: asyncssh.SSHClientConnection,
|
||||
profile_path: str,
|
||||
cmd: str,
|
||||
cwd: str | None = None,
|
||||
) -> asyncssh.SSHClientProcess:
|
||||
parts = []
|
||||
if cwd:
|
||||
parts.append(f"cd {shell_quote(cwd)}")
|
||||
parts.append(f"source {shell_quote(profile_path)} >/dev/null 2>&1 || true")
|
||||
parts.append(cmd)
|
||||
composite = "; ".join(parts)
|
||||
full = f"bash -lc {shell_quote(composite)}"
|
||||
# 不阻塞,返回进程句柄
|
||||
proc = await conn.create_process(full)
|
||||
return proc
|
||||
|
||||
# 轮询目录,识别新生成文件
|
||||
async def _listdir(conn: asyncssh.SSHClientConnection, path: str) -> list[str]:
|
||||
async with conn.start_sftp_client() as sftp:
|
||||
try:
|
||||
return await sftp.listdir(path)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# 轮询日志文件大小(作为心跳/粗略进度依据)
|
||||
async def _stat_size(conn: asyncssh.SSHClientConnection, path: str) -> int | None:
|
||||
async with conn.start_sftp_client() as sftp:
|
||||
try:
|
||||
attrs = await sftp.stat(path)
|
||||
return int(attrs.size or 0)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# 构造 gen-cube 命令(保持你之前的参数拼接逻辑)
|
||||
def _build_gen_cube_cmd(bin_path: str, args: SoftBVGenCubeArgs, workdir: str, log_path: str | None = None) -> str:
|
||||
input_abs = args.input_cif
|
||||
if not input_abs.startswith("/"):
|
||||
input_abs = posixpath.normpath(posixpath.join(workdir, args.input_cif))
|
||||
parts: list[str] = [shell_quote(bin_path), shell_quote("--gen-cube"), shell_quote(input_abs)]
|
||||
for val in [args.type, args.os, args.sf, args.resolution]:
|
||||
if val is not None:
|
||||
parts.append(shell_quote(str(val)))
|
||||
if args.ignore_conducting_ion:
|
||||
parts.append(shell_quote("--flag:ignore_conducting_ion"))
|
||||
if args.periodic:
|
||||
parts.append(shell_quote("--flag:periodic"))
|
||||
if args.output_name:
|
||||
parts.append(shell_quote(args.output_name))
|
||||
cmd = " ".join(parts)
|
||||
# 将输出重定向到日志,便于轮询
|
||||
if log_path:
|
||||
cmd = f"{cmd} > {shell_quote(log_path)} 2>&1"
|
||||
return cmd

# Fixed version: long-running softbv_gen_cube with heartbeat and timeout guard
@mcp.tool()
async def softbv_gen_cube(req: SoftBVGenCubeArgs, ctx: Context[ServerSession, SoftBVContext]) -> dict[str, Any]:
    """
    Run softBV.x --gen-cube with heartbeats for long jobs, so clients with
    timeouts under ~25 min do not kill the call.
    - Reports progress every 10 s (heartbeat): elapsed time / log size / new-file count
    - On completion returns stdout_head and stdout_tail (slices of the log file),
      stderr_head (if any), exit_status, and the list of new files
    """
    app = ctx.request_context.lifespan_context
    conn = app.ssh_connection
    workdir = req.cwd or app.workdir

    # Snapshot the directory contents so we can diff them afterwards
    before = await _listdir(conn, workdir)

    # Remote log file path (timestamped)
    import time
    log_name = f"softbv_gencube_{int(time.time())}.log"
    log_path = posixpath.join(workdir, log_name)

    # Start the long-running job without blocking this coroutine
    cmd = _build_gen_cube_cmd(app.bin_path, req, workdir, log_path=log_path)
    await ctx.info(f"Starting --gen-cube: {cmd}")
    proc = await run_in_softbv_env_stream(conn, app.profile, cmd=cmd, cwd=workdir)

    # Heartbeat loop: run until the process exits
    start_ts = time.time()
    heartbeat_sec = 10  # send a heartbeat every 10 seconds
    max_guard_min = 60  # guard ceiling (the server never kills the job; raise it if the client allows)
    try:
        while True:
            # Has the process already exited?
            if proc.exit_status is not None:
                break

            # Collect status: elapsed time, log size, new-file count
            elapsed = time.time() - start_ts
            log_size = await _stat_size(conn, log_path)
            now_files = await _listdir(conn, workdir)
            new_files_count = len(set(now_files) - set(before))

            # No real percentage is available, so report a heartbeat/elapsed-time hint
            await ctx.report_progress(
                progress=min(elapsed / (25 * 60), 0.99),  # approximate scale against a 25 min target
                total=1.0,
                message=f"gen-cube running: {int(elapsed)}s elapsed, log {log_size or 0}B, {new_files_count} new file(s)",
            )

            # Steady heartbeats are enough to keep the client from timing out. [1]
            await asyncio.sleep(heartbeat_sec)

            # Soft guard: past max_guard_min we warn but do not kill the job (the remote side decides)
            if elapsed > max_guard_min * 60:
                await ctx.warning("Job has exceeded the guard ceiling but is still running (not killed). Raise the ceiling if more time is needed.")
                # No break: keep waiting for the remote job to finish
    finally:
        # Wait for the process to actually finish (returns immediately if it already has)
        await proc.wait()

    # Collect results after completion
    exit_status = proc.exit_status
    after = await _listdir(conn, workdir)
    new_files = sorted(set(after) - set(before))

    # Read head/tail slices of the log to help locate the output
    async with conn.start_sftp_client() as sftp:
        head = ""
        tail = ""
        try:
            async with sftp.open(log_path, "rb") as f:
                content = await f.read()
            text = content.decode("utf-8", errors="replace")
            lines = text.splitlines()
            head = "\n".join(lines[:40])
            tail = "\n".join(lines[-40:])
        except Exception:
            pass

    # Structured result
    result = {
        "cmd": cmd,
        "cwd": workdir,
        "exit_status": exit_status,
        "log_file": log_path,
        "stdout_head": head,  # replaces a one-shot stdout, avoiding huge payloads
        "stdout_tail": tail,  # tail slice of the log, usually where errors land
        "stderr_head": "",    # everything is redirected to the log file, so stderr_head stays empty
        "new_files": new_files,
        "elapsed_sec": int(time.time() - start_ts),
    }

    if exit_status == 0:
        await ctx.info(f"gen-cube finished in {result['elapsed_sec']}s, {len(new_files)} new file(s)")
    else:
        await ctx.warning(f"gen-cube exit code {exit_status}; see log {log_path}")

    return result


def create_softbv_mcp() -> FastMCP:
    """Factory function for external (Starlette) import."""
    return mcp


if __name__ == "__main__":
    mcp.run(transport="streamable-http")
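
The factory functions in this diff (create_softbv_mcp above, create_system_mcp and create_topological_analysis_mcp below) are all written to be mounted under one ASGI server. A minimal sketch of that wiring, assuming the streamable_http_app() and session_manager helpers from the MCP Python SDK's FastMCP and hypothetical module/file names:

# serve_all.py -- minimal mounting sketch (module names are assumptions)
import contextlib
from starlette.applications import Starlette
from starlette.routing import Mount

from system_tools import create_system_mcp
from topological_analysis_models import create_topological_analysis_mcp

system_mcp = create_system_mcp()
analysis_mcp = create_topological_analysis_mcp()

@contextlib.asynccontextmanager
async def lifespan(app):
    # Each FastMCP instance manages its own sessions (and its own SSH connection
    # via the lifespan it was constructed with).
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(system_mcp.session_manager.run())
        await stack.enter_async_context(analysis_mcp.session_manager.run())
        yield

app = Starlette(
    routes=[
        Mount("/system", app=system_mcp.streamable_http_app()),
        Mount("/analysis", app=analysis_mcp.streamable_http_app()),
    ],
    lifespan=lifespan,
)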
1659
mcp/softBV_remake.py
Normal file
File diff suppressed because it is too large
Load Diff
407
mcp/system_tools.py
Normal file
@@ -0,0 +1,407 @@
# system_tools.py
import posixpath
import stat as pystat
from typing import Any
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator

import asyncssh
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession


# Added at the top of system_tools.py
def shell_quote(arg: str) -> str:
    """
    Safely pass a string as a single shell argument:
    - wrap the whole string in single quotes
    - replace each inner single quote with the '"'"' sequence
    Intended for building remote Linux shell command lines.
    """
    return "'" + arg.replace("'", "'\"'\"'") + "'"

# Reuse the shared config and dataclasses
from lifespan import (
    SharedAppContext,
    REMOTE_HOST,
    REMOTE_USER,
    PRIVATE_KEY_PATH,
    INITIAL_WORKING_DIRECTORY,
)


# -- 1) FastMCP lifespan: open the SSH connection on startup, close it on shutdown --
@asynccontextmanager
async def system_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    FastMCP lifespan: create and inject the shared SSH connection and sandbox root.
    Note: this is the MCP server's own lifespan; tools reach it via
    ctx.request_context.lifespan_context.
    """
    conn: asyncssh.SSHClientConnection | None = None
    try:
        # Open the SSH connection
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        # Inject the type-safe shared context into the MCP lifespan
        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        # Close the connection
        if conn:
            conn.close()
            await conn.wait_closed()


def create_system_mcp() -> FastMCP:
    """Create an MCP instance exposing basic system tools."""
    system_mcp = FastMCP(
        name="System Tools",
        instructions="Tools for basic file and directory operations on a remote server.",
        streamable_http_path="/",
        lifespan=system_lifespan,  # key: hand the lifespan to FastMCP [1]
    )

    def _safe_join(sandbox_root: str, relative_path: str) -> str:
        """
        Map a user-supplied relative path to a normalized absolute path inside the sandbox root.
        - POSIX semantics throughout (remote Linux)
        - absolute paths starting with '/' are not honored
        - '..' escapes are rejected, so the final path always stays inside the sandbox
        """
        rel = (relative_path or ".").strip()

        # Reject absolute paths by converting them to relative
        if rel.startswith("/"):
            rel = rel.lstrip("/")

        # Normalize the joined path
        combined = posixpath.normpath(posixpath.join(sandbox_root, rel))

        # Normalize the trailing slash so the boundary check cannot be bypassed
        root_norm = sandbox_root.rstrip("/")

        # Make sure we are still inside the sandbox
        if combined != root_norm and not combined.startswith(root_norm + "/"):
            raise ValueError("Path escapes sandbox: only paths inside the sandbox directory are allowed")

        # Reject any '..' left in the path (extra hardening)
        parts = [p for p in combined.split("/") if p]
        if ".." in parts:
            raise ValueError("Illegal path: '..' traversal is not allowed")

        return combined
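
    # For example, with sandbox_root="/data/sandbox" (a hypothetical root):
    #   _safe_join("/data/sandbox", "runs/a.cif")  -> "/data/sandbox/runs/a.cif"
    #   _safe_join("/data/sandbox", "/etc/passwd") -> "/data/sandbox/etc/passwd" (leading '/' stripped)
    #   _safe_join("/data/sandbox", "../secrets")  -> raises ValueError (escapes the sandbox)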

    # -- Refactored tools: all use the typed ctx / app_ctx to reach shared resources --

    @system_mcp.tool()
    async def list_files(ctx: Context[ServerSession, SharedAppContext], path: str = ".") -> list[dict[str, Any]]:
        """
        List files and subdirectories in the remote sandbox directory (structured output).
        Returns:
            list[dict]: [{name, path, is_dir, size, permissions, mtime}]
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, path)
            await ctx.info(f"Listing directory: {target}")

            items: list[dict[str, Any]] = []
            async with conn.start_sftp_client() as sftp:
                names = await sftp.listdir(target)  # list[str]
                for name in names:
                    item_abs = posixpath.join(target, name)
                    attrs = None
                    try:
                        attrs = await sftp.stat(item_abs)
                    except Exception:
                        pass

                    perms = int(attrs.permissions or 0) if attrs else 0
                    is_dir = bool(pystat.S_ISDIR(perms))
                    size = int(attrs.size or 0) if attrs and attrs.size is not None else 0
                    mtime = int(attrs.mtime or 0) if attrs and attrs.mtime is not None else 0

                    items.append({
                        "name": name,
                        "path": posixpath.join(path.rstrip("/"), name) if path != "/" else name,
                        "is_dir": is_dir,
                        "size": size,
                        "permissions": perms,
                        "mtime": mtime,
                    })

            await ctx.debug(f"Directory entries: {len(items)}")
            return items

        except FileNotFoundError:
            msg = f"Directory does not exist or is not accessible: {path}"
            await ctx.error(msg)
            return [{"error": msg}]
        except Exception as e:
            msg = f"list_files failed: {e}"
            await ctx.error(msg)
            return [{"error": msg}]

    @system_mcp.tool()
    async def read_file(ctx: Context[ServerSession, SharedAppContext], file_path: str, encoding: str = "utf-8") -> str:
        """
        Read the contents of a file inside the remote sandbox.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, file_path)
            await ctx.info(f"Reading file: {target}")

            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target, "rb") as f:
                    data = await f.read()
            try:
                content = data.decode(encoding)
            except Exception:
                content = data.decode(encoding, errors="replace")
            return content

        except FileNotFoundError:
            msg = f"Read failed: file does not exist '{file_path}'"
            await ctx.error(msg)
            return msg
        except Exception as e:
            msg = f"read_file failed: {e}"
            await ctx.error(msg)
            return msg

    @system_mcp.tool()
    async def write_file(
        ctx: Context[ServerSession, SharedAppContext],
        file_path: str,
        content: str,
        encoding: str = "utf-8",
        create_parents: bool = True,
    ) -> dict[str, Any]:
        """
        Write a file inside the remote sandbox (parent directories created on demand by default).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context  # typed SharedAppContext
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, file_path)
            await ctx.info(f"Writing file: {target}")

            if create_parents:
                parent = posixpath.dirname(target)
                if parent and parent != sandbox_root:
                    await conn.run(f"mkdir -p {shell_quote(parent)}", check=True)

            data = content.encode(encoding)
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target, "wb") as f:
                    await f.write(data)

            await ctx.debug(f"Write complete: {len(data)} bytes")
            return {"path": file_path, "bytes_written": len(data)}
        except Exception as e:
            msg = f"write_file failed: {e}"
            await ctx.error(msg)
            return {"error": msg}

    @system_mcp.tool()
    async def make_dir(ctx: Context[ServerSession, SharedAppContext], path: str, parents: bool = True) -> str:
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, path)

            if parents:
                # One-liner: mkdir -p
                await conn.run(f"mkdir -p {shell_quote(target)}", check=True)
            else:
                async with conn.start_sftp_client() as sftp:
                    await sftp.mkdir(target)

            await ctx.info(f"Directory created: {target}")
            return f"Directory created: {path}"
        except Exception as e:
            await ctx.error(f"Failed to create directory: {e}")
            return f"Error: {e}"

    @system_mcp.tool()
    async def delete_file(ctx: Context[ServerSession, SharedAppContext], file_path: str) -> str:
        """
        Delete a file (not a directory).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, file_path)
            async with conn.start_sftp_client() as sftp:
                await sftp.remove(target)

            await ctx.info(f"File deleted: {target}")
            return f"File deleted: {file_path}"
        except FileNotFoundError:
            msg = f"Delete failed: file does not exist '{file_path}'"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Failed to delete file: {e}")
            return f"Error: {e}"

    @system_mcp.tool()
    async def delete_dir(
        ctx: Context[ServerSession, SharedAppContext],
        dir_path: str,
        recursive: bool = False,
    ) -> str:
        """
        Delete a directory.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context  # typed SharedAppContext
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, dir_path)

            if recursive:
                await conn.run(f"rm -rf {shell_quote(target)}", check=True)
            else:
                async with conn.start_sftp_client() as sftp:
                    await sftp.rmdir(target)

            await ctx.info(f"Directory deleted: {target}")
            return f"Directory deleted: {dir_path}"
        except FileNotFoundError:
            msg = f"Delete failed: directory '{dir_path}' does not exist or is not empty"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Failed to delete directory: {e}")
            return f"Error: {e}"

    @system_mcp.tool()
    async def move_path(ctx: Context[ServerSession, SharedAppContext], src: str, dst: str,
                        overwrite: bool = True) -> str:
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            src_abs = _safe_join(sandbox_root, src)
            dst_abs = _safe_join(sandbox_root, dst)

            flag = "-f" if overwrite else ""
            # One-liner: mv
            cmd = f"mv {flag} {shell_quote(src_abs)} {shell_quote(dst_abs)}".strip()
            await conn.run(cmd, check=True)

            await ctx.info(f"Moved: {src_abs} -> {dst_abs}")
            return f"Moved: {src} -> {dst}"
        except FileNotFoundError:
            msg = f"Move failed: source does not exist '{src}'"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Move failed: {e}")
            return f"Error: {e}"

    @system_mcp.tool()
    async def copy_path(
        ctx: Context[ServerSession, SharedAppContext],
        src: str,
        dst: str,
        recursive: bool = True,
        overwrite: bool = True,
    ) -> str:
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            src_abs = _safe_join(sandbox_root, src)
            dst_abs = _safe_join(sandbox_root, dst)

            flags = []
            if recursive:
                flags.append("-r")
            if overwrite:
                flags.append("-f")

            # One-liner: cp
            cmd = " ".join(["cp"] + flags + [shell_quote(src_abs), shell_quote(dst_abs)])
            await conn.run(cmd, check=True)

            await ctx.info(f"Copied: {src_abs} -> {dst_abs}")
            return f"Copied: {src} -> {dst}"
        except FileNotFoundError:
            msg = f"Copy failed: source does not exist '{src}'"
            await ctx.error(msg)
            return msg
        except Exception as e:
            await ctx.error(f"Copy failed: {e}")
            return f"Error: {e}"

    @system_mcp.tool()
    async def exists(ctx: Context[ServerSession, SharedAppContext], path: str) -> bool:
        """
        Check whether a path (file or directory) exists.
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, path)
            async with conn.start_sftp_client() as sftp:
                await sftp.stat(target)
            return True
        except FileNotFoundError:
            return False
        except Exception as e:
            await ctx.error(f"exists check failed: {e}")
            return False

    @system_mcp.tool()
    async def stat_path(ctx: Context[ServerSession, SharedAppContext], path: str) -> dict:
        """
        Inspect a remote path's attributes (structured output).
        """
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path

            target = _safe_join(sandbox_root, path)
            async with conn.start_sftp_client() as sftp:
                attrs = await sftp.stat(target)

            perms = attrs.permissions or 0
            return {
                "path": target,
                "size": int(attrs.size or 0),
                "is_dir": bool(pystat.S_ISDIR(perms)),
                "permissions": perms,  # octal permission bits, e.g. 0o755
                "mtime": int(attrs.mtime or 0),  # seconds
            }
        except FileNotFoundError:
            msg = f"Path does not exist: {path}"
            await ctx.error(msg)
            return {"error": msg}
        except Exception as e:
            await ctx.error(f"stat failed: {e}")
            return {"error": str(e)}

    return system_mcp
19
mcp/test_tools.py
Normal file
@@ -0,0 +1,19 @@
# test_tools.py
from mcp.server.fastmcp import FastMCP


def create_test_mcp() -> FastMCP:
    """Create an MCP instance with a single trivial tool, used to test connectivity."""

    test_mcp = FastMCP(
        name="Test Tools",
        instructions="Used to verify that the server connection is working.",
        streamable_http_path="/"
    )

    @test_mcp.tool()
    def ping() -> str:
        """A trivial tool that confirms the server responds."""
        return "pong"

    return test_mcp
185
mcp/topological_analysis_models.py
Normal file
@@ -0,0 +1,185 @@
# topological_analysis_models.py

import posixpath
import re
from typing import Any, Literal
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator

import asyncssh
from pydantic import BaseModel, Field

# Core components from the MCP SDK [1]
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession

# Shared config and dataclasses from lifespan.py [3]
# Note: we no longer need to import system_lifespan
from lifespan import (
    SharedAppContext,
    REMOTE_HOST,
    REMOTE_USER,
    PRIVATE_KEY_PATH,
    INITIAL_WORKING_DIRECTORY,
)

# ==============================================================================
# 1. Helpers (duplicated here to keep this module self-contained)
# ==============================================================================

def shell_quote(arg: str) -> str:
    """Safely escape a string as a single shell argument [5]."""
    return "'" + arg.replace("'", "'\"'\"'") + "'"


def _safe_join(sandbox_root: str, relative_path: str) -> str:
    """Safely join a user-supplied relative path onto the sandbox root [5]."""
    rel = (relative_path or ".").strip()
    if rel.startswith("/"):
        rel = rel.lstrip("/")
    combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
    root_norm = sandbox_root.rstrip("/")
    if combined != root_norm and not combined.startswith(root_norm + "/"):
        raise ValueError("Path escapes sandbox: only paths inside the sandbox directory are allowed")
    if ".." in combined.split("/"):
        raise ValueError("Illegal path: '..' traversal is not allowed")
    return combined


# ==============================================================================
# 2. Standalone lifespan manager
# ==============================================================================

@asynccontextmanager
async def analysis_lifespan(_server: FastMCP) -> AsyncIterator[SharedAppContext]:
    """
    Manage an independent SSH connection lifecycle for the topological analysis tools.
    """
    conn: asyncssh.SSHClientConnection | None = None
    print("Analysis tools: opening a dedicated SSH connection...")
    try:
        conn = await asyncssh.connect(
            REMOTE_HOST,
            username=REMOTE_USER,
            client_keys=[PRIVATE_KEY_PATH],
        )
        print(f"Analysis tools: dedicated SSH connection to {REMOTE_HOST} established!")
        yield SharedAppContext(ssh_connection=conn, sandbox_path=INITIAL_WORKING_DIRECTORY)
    finally:
        if conn:
            conn.close()
            await conn.wait_closed()
            print("Analysis tools: dedicated SSH connection closed.")


# ==============================================================================
# 3. Data models and parsing (unchanged from the previous version)
# ==============================================================================
# (An earlier draft omitted the Pydantic models and _parse_analysis_output here;
# the full definitions follow below.)
VALID_CONFIGS = Literal["O.yaml", "S.yaml", "Cl.yaml", "Br.yaml"]


class SubmitTopologicalAnalysisParams(BaseModel):
    cif_file_path: str = Field(description="Relative path of the CIF file on the remote server.")
    config_files: list[VALID_CONFIGS] = Field(description="One or more YAML config files to use for the analysis.")
    output_file_path: str = Field(description="Relative path of the output file that receives the results.")


class AnalysisResult(BaseModel):
    percolation_diameter_a: float | None = Field(None, description="Percolation diameter, in Å.")
    connectivity_distance: float | None = Field(None, description="Connectivity distance ('the minium of d' in the tool output), in Å.")
    max_node_length_a: float | None = Field(None, description="Maximum node length, in Å.")
    min_cation_distance_a: float | None = Field(None, description="Minimum distance to cations, in Å.")
    total_time_s: float | None = Field(None, description="Total compute time, in seconds.")
    long_node_count: int | None = Field(None, description="Long node count (Long node number).")
    short_node_count: int | None = Field(None, description="Short node count (Short node number).")
    warnings: list[str] = Field(default_factory=list, description="Warnings raised while parsing, including missing key values.")


def _parse_analysis_output(log_content: str) -> AnalysisResult:
    def find_float(pattern: str) -> float | None:
        match = re.search(pattern, log_content)
        return float(match.group(1)) if match else None

    def find_int(pattern: str) -> int | None:
        match = re.search(pattern, log_content)
        return int(match.group(1)) if match else None

    data = {
        "percolation_diameter_a": find_float(r"Percolation diameter \(A\):\s*([\d\.]+)"),
        "max_node_length_a": find_float(r"Maximum node length detected:\s*([\d\.]+)\s*A"),
        "min_cation_distance_a": find_float(r"minimum distance to cations is\s*([\d\.]+)\s*A"),
        "total_time_s": find_float(r"Total used time:\s*([\d\.]+)"),
        "long_node_count": find_int(r"Long node number:\s*(\d+)"),
        "short_node_count": find_int(r"Short node number:\s*(\d+)"),
    }
    # Note: "minium" matches the tool's own spelling in its log output.
    conn_dist_match = re.search(r"the minium of d\s*([\d\.]+)", log_content, re.MULTILINE)
    data["connectivity_distance"] = float(conn_dist_match.group(1)) if conn_dist_match else None
    warnings = [k + " not found" for k, v in data.items() if v is None and k not in ["min_cation_distance_a", "long_node_count", "short_node_count"]]
    return AnalysisResult(**data, warnings=warnings)
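
# A hypothetical log snippet illustrating what the parser extracts:
#
#   sample = (
#       "Percolation diameter (A): 1.84\n"
#       "the minium of d 2.31 #\n"
#       "Maximum node length detected: 5.67 A\n"
#       "Total used time: 12.5\n"
#   )
#   r = _parse_analysis_output(sample)
#   # r.percolation_diameter_a == 1.84, r.connectivity_distance == 2.31,
#   # r.max_node_length_a == 5.67; long/short node counts stay None and are
#   # deliberately excluded from the warnings list.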

# ==============================================================================
# 4. MCP server factory
# ==============================================================================

def create_topological_analysis_mcp() -> FastMCP:
    """
    Factory function for mounting under Starlette.
    This function now creates and uses its own standalone lifespan manager.
    """
    analysis_mcp = FastMCP(
        name="Topological Analysis Tools",
        instructions="Tools for submitting and analyzing topological computation jobs on a remote server.",
        streamable_http_path="/",
        lifespan=analysis_lifespan,  # key: use the standalone lifespan defined in this file
    )

    @analysis_mcp.tool()
    async def submit_topological_analysis(
        params: SubmitTopologicalAnalysisParams,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> dict[str, Any]:
        """[Step 1/2] Submit a topological analysis job asynchronously."""
        # ... (implementation identical to the previous version)
        try:
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            home_dir = f"/cluster/home/{REMOTE_USER}"
            tool_dir = f"{home_dir}/tool/Topological_analyse"
            cif_abs_path = _safe_join(sandbox_root, params.cif_file_path)
            output_abs_path = _safe_join(sandbox_root, params.output_file_path)
            config_args = " ".join([shell_quote(f"{tool_dir}/{cfg}") for cfg in params.config_files])
            config_flags = f"-i {config_args}"
            # Caution: shell_quote emits single-quoted strings that splice into the
            # single-quoted sh -c body below; this only holds together while the
            # quoted paths contain no whitespace.
            command = f"""
nohup sh -c '
source {shell_quote(f"{home_dir}/.bashrc")} && \\
conda activate {shell_quote(f"{home_dir}/anaconda3/envs/zeo")} && \\
python {shell_quote(f"{tool_dir}/analyze_voronoi_nodes.py")} \\
{shell_quote(cif_abs_path)} \\
{config_flags}
' > {shell_quote(output_abs_path)} 2>&1 &
""".strip()
            await ctx.info("Submitting background topological analysis job...")
            await conn.run(command, check=True)
            return {"success": True, "message": "Topological analysis job submitted successfully.", "output_file": params.output_file_path}
        except Exception as e:
            msg = f"Failed to submit background job: {e}"
            await ctx.error(msg)
            return {"success": False, "error": msg}

    @analysis_mcp.tool()
    async def analyze_topological_results(
        output_file_path: str,
        ctx: Context[ServerSession, SharedAppContext],
    ) -> AnalysisResult:
        """[Step 2/2] Read and parse the output file of a topological analysis job."""
        # ... (implementation identical to the previous version)
        try:
            await ctx.info(f"Analyzing result file: {output_file_path}")
            app_ctx = ctx.request_context.lifespan_context
            conn = app_ctx.ssh_connection
            sandbox_root = app_ctx.sandbox_path
            target_path = _safe_join(sandbox_root, output_file_path)
            async with conn.start_sftp_client() as sftp:
                async with sftp.open(target_path, "r") as f:
                    log_content = await f.read()
            if not log_content:
                raise ValueError("Result file is empty.")
            analysis_data = _parse_analysis_output(log_content)
            return analysis_data
        except FileNotFoundError:
            msg = f"Analysis failed: result file '{output_file_path}' does not exist."
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])
        except Exception as e:
            msg = f"Error while analyzing results: {e}"
            await ctx.error(msg)
            return AnalysisResult(warnings=[msg])

    return analysis_mcp
10
mcp/uvicorn/main.py
Normal file
@@ -0,0 +1,10 @@
# main.py

from fastapi import FastAPI

app = FastAPI()

@app.get("/")
async def read_root():
    return {"message": "Hello, World!"}
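
# This FastAPI app is typically served with the standard uvicorn CLI, e.g.:
#   uvicorn main:app --host 0.0.0.0 --port 8000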
19
softBV/out/Br.csv
Normal file
@@ -0,0 +1,19 @@
filename,formula,conductivity(e-3S/m)
1080005.cif,Cs3Li2Br5,21.23
1147619.cif,Li3YBr6,20.00
1147621.cif,Li3InBr6,7.23
1211043.cif,LiFeBr4,11.60
1222492.cif,Li3ErBr6,29.61
1222679.cif,Li2MnBr4,9.68
2033990.cif,CsLi2Br3,9.71
22967.cif,Li2FeBr4,11.21
23057.cif,CsLiBr2,5.12
2763849.cif,Cs3Li2Br5,121.34
28237.cif,RbLiBr2,3.46
28250.cif,Li2MnBr4,12.25
28326.cif,LiGaBr4,1.56
28327.cif,LiGaBr3,8.32
28829.cif,Li2ZnBr4,15.71
37873.cif,Li3ErBr6,10.90
580554.cif,CsLi3Br4,38.40
606680.cif,CsLi2Br3,20.63
165
softBV/run_softBV.py
Normal file
@@ -0,0 +1,165 @@
import os
import sys
import subprocess
import re
import csv
import argparse
from pathlib import Path

# --- Configuration ---
# Path to the softBV.x executable; adjust to your environment
SOFTBV_EXECUTABLE = Path.home() / "tool/softBV-GUI_linux/bin/softBV.x"

# Fixed command arguments
MOBILE_ION = "Li"
ION_VALENCE = "1"

# Name of the output CSV file
OUTPUT_CSV_FILE = "conductivity_results.csv"


# --- End of configuration ---

def check_executable():
    """Check that softBV.x exists and is executable."""
    if not SOFTBV_EXECUTABLE.is_file():
        print(f"Error: executable not found: {SOFTBV_EXECUTABLE}")
        print("Check the SOFTBV_EXECUTABLE path in this script.")
        sys.exit(1)
    if not os.access(SOFTBV_EXECUTABLE, os.X_OK):
        print(f"Error: file exists but is not executable: {SOFTBV_EXECUTABLE}")
        print("Grant execute permission with 'chmod +x'.")
        sys.exit(1)


def get_formula_from_cif_line2(cif_path):
    """
    Extract the chemical formula from the second line of a CIF file.
    The expected format is 'data_FORMULA'.
    """
    try:
        with open(cif_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        if len(lines) > 1:
            second_line = lines[1].strip()  # strip() removes surrounding whitespace and the newline
            if second_line.lower().startswith('data_'):
                # Everything after 'data_' is the formula
                return second_line[5:]
            else:
                print(f"\nWarning: line 2 of {cif_path.name} is malformed (does not start with 'data_').", file=sys.stderr)
                return "FormatError"
        else:
            print(f"\nWarning: {cif_path.name} has fewer than 2 lines.", file=sys.stderr)
            return "FileTooShort"
    except Exception as e:
        print(f"\nWarning: error while reading {cif_path.name}: {e}", file=sys.stderr)
        return "ReadError"

def parse_conductivity(output_text):
    """Parse and format the conductivity from softBV.x output text."""
    pattern = re.compile(r"^\s*MD: conductivity\s*=\s*([-\d.eE+]+)", re.MULTILINE)
    match = pattern.search(output_text)

    if match:
        conductivity_str = match.group(1)
        try:
            conductivity_val = float(conductivity_str)
            # Convert from S/m to 1e-3 S/m (mS/m), i.e. multiply by 1000
            conductivity_ms_m = conductivity_val * 1000
            # Format as a string with two decimal places
            return f"{conductivity_ms_m:.2f}"
        except ValueError:
            return "ValueError"
    else:
        return None
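
# For example, a hypothetical output line "MD: conductivity = 2.123e-02"
# (in S/m) parses to the string "21.23", i.e. 21.23 x 10^-3 S/m.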

def main():
    """Entry point."""
    parser = argparse.ArgumentParser(
        description="Run softBV.x to compute conductivities and collect them into a CSV file.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("folder", type=str, help="Path to the folder containing CIF files.")
    parser.add_argument(
        "-t", "--temperature",
        type=int,
        default=1000,
        help="Simulation temperature (K), the last argument passed to softBV.x.\nDefault: 1000"
    )
    args = parser.parse_args()

    target_folder = Path(args.folder)
    temperature = str(args.temperature)

    if not target_folder.is_dir():
        print(f"Error: folder '{target_folder}' does not exist.")
        sys.exit(1)

    check_executable()

    cif_files = sorted(list(target_folder.glob("*.cif")))

    if not cif_files:
        print(f"Warning: no .cif files found in folder '{target_folder}'.")
        return

    print(f"Found {len(cif_files)} .cif file(s) in '{target_folder}', starting...")
    print(f"Simulation temperature: {temperature} K")

    results = []

    for cif_file in cif_files:
        print(f" -> Processing: {cif_file} ...", end="", flush=True)

        # Get the formula with the helper above
        formula = get_formula_from_cif_line2(cif_file)

        command = [
            str(SOFTBV_EXECUTABLE), "--md", str(cif_file),
            MOBILE_ION, ION_VALENCE, temperature
        ]

        try:
            process_result = subprocess.run(
                command, capture_output=True, text=True, check=True, timeout=300
            )
            conductivity = parse_conductivity(process_result.stdout)

            if conductivity is not None:
                results.append([cif_file.name, formula, conductivity])
                print(f" OK, Formula: {formula}, Conductivity: {conductivity} (x10^-3 S/m)")
            else:
                print(" FAILED (conductivity not found in output)")
                results.append([cif_file.name, formula, "NotFound"])

        except subprocess.CalledProcessError as e:
            print(" FAILED (command error)")
            print(f"    stderr: {e.stderr.strip()}", file=sys.stderr)
            results.append([cif_file.name, formula, "ExecError"])
        except subprocess.TimeoutExpired:
            print(" FAILED (command timed out)")
            results.append([cif_file.name, formula, "Timeout"])
        except Exception as e:
            print(f" FAILED (unexpected error: {e})", file=sys.stderr)
            results.append([cif_file.name, formula, "UnknownError"])

    if not results:
        print("No files were processed; skipping CSV generation.")
        return

    print(f"\nDone. Writing {len(results)} result(s) to {OUTPUT_CSV_FILE} ...")

    try:
        with open(OUTPUT_CSV_FILE, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['filename', 'formula', 'conductivity(e-3S/m)'])
            writer.writerows(results)
        print("CSV file written successfully!")
    except IOError as e:
        print(f"Error: could not write CSV file: {e}")


if __name__ == "__main__":
    main()
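
A typical invocation of this script, assuming a folder of CIF files named cifs/ (the folder name is an assumption):

    python run_softBV.py cifs -t 1000

This writes conductivity_results.csv into the current working directory, matching the Br.csv layout above.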