13 Commits

Author SHA1 Message Date
koko
f78298e803 Phase 1 high-throughput screening completed 2025-12-16 11:36:49 +08:00
6ea96c81d6 Add supercell expansion logic 2025-12-14 18:33:54 +08:00
9bde3e1229 Add supercell expansion logic 2025-12-14 18:11:00 +08:00
83647c2218 Add supercell expansion logic 2025-12-14 18:01:11 +08:00
9b36aa10ff Add supercell expansion logic 2025-12-14 17:57:42 +08:00
2378a3f2a2 Add supercell expansion logic 2025-12-14 16:59:01 +08:00
da8c85b830 Add supercell expansion logic 2025-12-14 16:57:57 +08:00
72cf0a79e1 Add supercell expansion logic 2025-12-14 16:52:14 +08:00
f27fd3e3ce Add parallel computation to preprocessing 2025-12-14 15:53:11 +08:00
1fee324c90 Add parallel computation to preprocessing 2025-12-14 15:47:27 +08:00
ae4e7280b4 Add parallel computation to preprocessing 2025-12-14 15:42:13 +08:00
c91998662a Refactor preprocessing 2025-12-14 15:10:18 +08:00
6eeb40d222 Refactor preprocessing 2025-12-14 14:34:26 +08:00
38 changed files with 4892 additions and 63 deletions

.gitignore (vendored, new file, 21 lines)

@@ -0,0 +1,21 @@
# --- General Python ignores ---
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.env
.venv/
# --- VS Code settings (optional, recommended to ignore) ---
.vscode/
# --- JetBrains (your previous configuration) ---
.idea/
/shelf/
/workspace.xml
/httpRequests/
/dataSources/
/dataSources.local.xml

.idea/.gitignore (generated, vendored, 8 lines)

@@ -1,8 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

.idea/Screen.iml (generated, 8 lines)

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>


@@ -1,23 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="9">
<item index="0" class="java.lang.String" itemvalue="torchvision" />
<item index="1" class="java.lang.String" itemvalue="torch" />
<item index="2" class="java.lang.String" itemvalue="tqdm" />
<item index="3" class="java.lang.String" itemvalue="scipy" />
<item index="4" class="java.lang.String" itemvalue="h5py" />
<item index="5" class="java.lang.String" itemvalue="matplotlib" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="opencv_python" />
<item index="8" class="java.lang.String" itemvalue="Pillow" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>


@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

.idea/modules.xml (generated, 8 lines)

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Screen.iml" filepath="$PROJECT_DIR$/.idea/Screen.iml" />
</modules>
</component>
</project>

.idea/vcs.xml (generated, 6 lines)

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

config/settings.yaml (new, empty file)


main.py (new file, 643 lines)

@@ -0,0 +1,643 @@
"""
高通量筛选与扩胞项目 - 主入口(支持断点续做)
"""
import os
import sys
import json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
from src.core.executor import TaskExecutor
from src.preprocessing.processor import StructureProcessor
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
from src.computation.result_processor import ResultProcessor, FilterCriteria
def print_banner():
print("""
╔═══════════════════════════════════════════════════════════════════╗
║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.2 ║
║ 支持断点续做与高性能并行计算 ║
╚═══════════════════════════════════════════════════════════════════╝
""")
def detect_and_show_environment():
"""检测并显示环境信息"""
env = TaskExecutor.detect_environment()
print("【运行环境检测】")
print(f" 主机名: {env['hostname']}")
print(f" 本地CPU核数: {env['total_cores']}")
print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}")
if env['has_slurm'] and env['slurm_partitions']:
print(f" 可用分区: {', '.join(env['slurm_partitions'])}")
if env['conda_env']:
print(f" 当前Conda: {env['conda_env']}")
return env
def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict:
"""
检测工作流程状态,确定可以从哪一步继续
Returns:
dict: {
'has_processed_data': bool, # 是否有扩胞处理后的数据
'has_zeo_results': bool, # 是否有 Zeo++ 计算结果
'total_structures': int, # 总结构数
'structures_with_log': int, # 有 log.txt 的结构数
'workspace_info': object, # 工作区信息
'ws_manager': object # 工作区管理器
}
"""
status = {
'has_processed_data': False,
'has_zeo_results': False,
'total_structures': 0,
'structures_with_log': 0,
'workspace_info': None,
'ws_manager': None
}
# 创建工作区管理器
ws_manager = WorkspaceManager(
workspace_path=workspace_path,
tool_dir="tool",
target_cation=target_cation
)
status['ws_manager'] = ws_manager
# 检查是否有处理后的数据
data_dir = os.path.join(workspace_path, "data")
if os.path.exists(data_dir):
existing = ws_manager.check_existing_workspace()
if existing and existing.total_structures > 0:
status['has_processed_data'] = True
status['total_structures'] = existing.total_structures
status['workspace_info'] = existing
# 检查有多少结构有 log.txt(即已完成 Zeo++ 计算)
log_count = 0
for anion_key in os.listdir(data_dir):
anion_dir = os.path.join(data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if os.path.isdir(struct_dir):
log_path = os.path.join(struct_dir, "log.txt")
if os.path.exists(log_path):
log_count += 1
status['structures_with_log'] = log_count
# 如果大部分结构都有 log.txt,认为 Zeo++ 计算已完成
if log_count > 0 and log_count >= status['total_structures'] * 0.5:
status['has_zeo_results'] = True
return status
def print_workflow_status(status: dict):
"""打印工作流程状态"""
print("\n" + "" * 50)
print("【工作流程状态检测】")
print("" * 50)
if not status['has_processed_data']:
print(" Step 1 (扩胞+化合价): ❌ 未完成")
print(" Step 2-4 (Zeo++ 计算): ❌ 未完成")
print(" Step 5 (结果筛选): ❌ 未完成")
else:
print(f" Step 1 (扩胞+化合价): ✅ 已完成 ({status['total_structures']} 个结构)")
if status['has_zeo_results']:
print(f" Step 2-4 (Zeo++ 计算): ✅ 已完成 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
print(" Step 5 (结果筛选): ⏳ 可执行")
else:
print(f" Step 2-4 (Zeo++ 计算): ⏳ 可执行 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
print(" Step 5 (结果筛选): ❌ 需先完成 Zeo++ 计算")
print("" * 50)
def get_user_choice(status: dict) -> str:
"""
根据工作流程状态获取用户选择
Returns:
'step1': 从头开始(数据库分析 + 扩胞)
'step2': 从 Zeo++ 计算开始
'step5': 直接进行结果筛选
'exit': 退出
"""
print("\n请选择操作:")
options = []
if status['has_zeo_results']:
options.append(('5', '直接进行结果筛选 (Step 5)'))
options.append(('2', '重新运行 Zeo++ 计算 (Step 2-4)'))
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
elif status['has_processed_data']:
options.append(('2', '运行 Zeo++ 计算 (Step 2-4)'))
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
else:
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
options.append(('q', '退出'))
for key, desc in options:
print(f" {key}. {desc}")
choice = input("\n请选择 [默认: " + options[0][0] + "]: ").strip().lower()
if not choice:
choice = options[0][0]
if choice == 'q':
return 'exit'
elif choice == '5':
return 'step5'
elif choice == '2':
return 'step2'
else:
return 'step1'
def run_step1_database_analysis(env: dict, cation: str) -> dict:
"""
Step 1: 数据库分析与扩胞处理
Returns:
处理参数字典,如果用户取消则返回 None
"""
print("\n" + "" * 60)
print("【Step 1: 数据库分析与扩胞处理】")
print("" * 60)
# 数据库路径
while True:
db_path = input("\n📂 请输入数据库路径: ").strip()
if os.path.exists(db_path):
break
print(f"❌ 路径不存在: {db_path}")
# 检测当前Conda环境路径
default_conda = "/cluster/home/koko125/anaconda3/envs/screen"
conda_env_path = env.get('conda_env', '') or default_conda
print(f"\n检测到Conda环境: {conda_env_path}")
custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
if custom_env.lower() == 'n':
conda_env_path = input("请输入Conda环境完整路径: ").strip()
elif custom_env and custom_env.lower() != 'y':
conda_env_path = custom_env
# 目标阴离子
anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip()
anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
# 阴离子模式
print("\n阴离子模式:")
print(" 1. 仅单一阴离子")
print(" 2. 仅复合阴离子")
print(" 3. 全部 (默认)")
mode_choice = input("请选择 [1/2/3]: ").strip()
anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all')
# 并行配置
print("\n" + "" * 50)
print("【并行计算配置】")
default_cores = min(env['total_cores'], 32)
cores_input = input(f"💻 最大可用核数/Worker数 [默认: {default_cores}]: ").strip()
max_workers = int(cores_input) if cores_input.isdigit() else default_cores
params = {
'database_path': db_path,
'target_cation': cation,
'target_anions': anions,
'anion_mode': anion_mode,
'max_workers': max_workers,
'conda_env': conda_env_path,
}
print("\n" + "" * 60)
print("开始数据库分析...")
print("" * 60)
# 创建分析器
analyzer = DatabaseAnalyzer(
database_path=params['database_path'],
target_cation=params['target_cation'],
target_anions=params['target_anions'],
anion_mode=params['anion_mode'],
max_cores=params['max_workers'],
task_complexity='medium'
)
print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件")
# 执行分析
report = analyzer.analyze(show_progress=True)
# 打印报告
ReportGenerator.print_report(report, detailed=True)
# 保存选项
save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
if save_choice == 'y':
output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
output_path = output_path or "analysis_report.json"
report.save(output_path)
print(f"✅ 报告已保存到: {output_path}")
# CSV导出
csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
if csv_choice == 'y':
csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
csv_path = csv_path or "analysis_details.csv"
ReportGenerator.export_to_csv(report, csv_path)
# 生成最终数据库
process_choice = input("\n是否生成最终可用的数据库(扩胞+添加化合价)? [Y/n]: ").strip().lower()
if process_choice == 'n':
print("\n已跳过扩胞处理,可稍后继续")
return params
# 输出目录设置
print("\n输出目录设置:")
flat_dir = input(" 原始格式输出目录 [默认: workspace/processed]: ").strip()
flat_dir = flat_dir or "workspace/processed"
analysis_dir = input(" 分析格式输出目录 [默认: workspace/data]: ").strip()
analysis_dir = analysis_dir or "workspace/data"
# 扩胞保存数量
keep_input = input("\n扩胞结构保存数量 [默认: 1]: ").strip()
keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1
# 扩胞精度选择
print("\n扩胞计算精度:")
print(" 1. 高精度 (精确分数)")
print(" 2. 普通精度 (分母≤100)")
print(" 3. 低精度 (分母≤10) [默认]")
print(" 4. 极低精度 (分母≤5)")
precision_choice = input("请选择 [1/2/3/4]: ").strip()
calculate_type = {
'1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low'
}.get(precision_choice, 'low')
# 获取可处理文件
processable = report.get_processable_files(include_needs_expansion=True)
if not processable:
print("⚠️ 没有可处理的文件")
return params
print(f"\n发现 {len(processable)} 个可处理的文件")
# 准备文件列表和扩胞标记
input_files = [info.file_path for info in processable]
needs_expansion_flags = [info.needs_expansion for info in processable]
anion_types_list = [info.anion_types for info in processable]
direct_count = sum(1 for f in needs_expansion_flags if not f)
expansion_count = sum(1 for f in needs_expansion_flags if f)
print(f" - 可直接处理: {direct_count}")
print(f" - 需要扩胞: {expansion_count}")
print(f" - 扩胞保存数: {keep_number}")
confirm = input("\n确认开始处理? [Y/n]: ").strip().lower()
if confirm == 'n':
print("\n已取消处理")
return params
print("\n" + "" * 60)
print("开始处理结构文件...")
print("" * 60)
# 创建处理器
processor = StructureProcessor(
calculate_type=calculate_type,
keep_number=keep_number,
target_cation=params['target_cation']
)
# 创建输出目录
os.makedirs(flat_dir, exist_ok=True)
os.makedirs(analysis_dir, exist_ok=True)
results = []
total = len(input_files)
import shutil
for i, (input_file, needs_exp, anion_types) in enumerate(
zip(input_files, needs_expansion_flags, anion_types_list)
):
print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
# 处理文件到原始格式目录
result = processor.process_file(input_file, flat_dir, needs_exp)
results.append(result)
if result.success:
# 同时保存到分析格式目录
# 按阴离子类型创建子目录
anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other'
anion_dir = os.path.join(analysis_dir, anion_key)
os.makedirs(anion_dir, exist_ok=True)
# 获取基础文件名
base_name = os.path.splitext(os.path.basename(input_file))[0]
# 创建以文件名命名的子目录
file_dir = os.path.join(anion_dir, base_name)
os.makedirs(file_dir, exist_ok=True)
# 复制生成的文件到分析格式目录
for output_file in result.output_files:
dst_path = os.path.join(file_dir, os.path.basename(output_file))
shutil.copy2(output_file, dst_path)
print()
# 统计结果
success_count = sum(1 for r in results if r.success)
fail_count = sum(1 for r in results if not r.success)
total_output = sum(len(r.output_files) for r in results if r.success)
print("\n" + "-" * 60)
print("【处理结果统计】")
print("-" * 60)
print(f" 成功处理: {success_count}")
print(f" 处理失败: {fail_count}")
print(f" 生成文件: {total_output}")
print(f"\n 原始格式目录: {flat_dir}")
print(f" 分析格式目录: {analysis_dir}")
print(f" └── 结构: data/阴离子类型/文件名/文件名.cif")
# 显示失败的文件
if fail_count > 0:
print("\n失败的文件:")
for r in results:
if not r.success:
print(f" - {os.path.basename(r.input_file)}: {r.error_message}")
print("\n✅ Step 1 完成!")
return params
def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info = None):
"""
Step 2-4: Zeo++ Voronoi 分析
Args:
params: 参数字典,包含 target_cation 等
ws_manager: 工作区管理器(可选,如果已有则直接使用)
workspace_info: 工作区信息(可选,如果已有则直接使用)
"""
print("\n" + "" * 60)
print("【Step 2-4: Zeo++ Voronoi 分析】")
print("" * 60)
# 如果没有传入 ws_manager则创建新的
if ws_manager is None:
# 工作区路径
workspace_path = input("\n工作区路径 [默认: workspace]: ").strip() or "workspace"
tool_dir = input("工具目录路径 [默认: tool]: ").strip() or "tool"
# 创建工作区管理器
ws_manager = WorkspaceManager(
workspace_path=workspace_path,
tool_dir=tool_dir,
target_cation=params.get('target_cation', 'Li')
)
# 如果没有传入 workspace_info则检查现有工作区
if workspace_info is None:
existing = ws_manager.check_existing_workspace()
if existing and existing.total_structures > 0:
ws_manager.print_workspace_summary(existing)
workspace_info = existing
else:
print("⚠️ 工作区数据目录不存在或为空")
print(f" 请先运行 Step 1 生成处理后的数据到: {ws_manager.data_dir}")
return
# 检查是否需要创建软链接
if workspace_info.linked_structures < workspace_info.total_structures:
print("\n正在创建配置文件软链接...")
workspace_info = ws_manager.setup_workspace(force_relink=False)
# 获取计算任务
tasks = ws_manager.get_computation_tasks(workspace_info)
if not tasks:
print("⚠️ 没有找到可计算的任务")
return
print(f"\n发现 {len(tasks)} 个计算任务")
# Zeo++ 环境配置
print("\n" + "-" * 50)
print("【Zeo++ 计算配置】")
default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo"
zeo_env = input(f"Zeo++ Conda环境 [默认: {default_zeo_env}]: ").strip()
zeo_env = zeo_env or default_zeo_env
# SLURM 配置
partition = input("SLURM分区 [默认: cpu]: ").strip() or "cpu"
max_concurrent_input = input("最大并发任务数 [默认: 50]: ").strip()
max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50
time_limit = input("单任务时间限制 [默认: 2:00:00]: ").strip() or "2:00:00"
# 创建配置
zeo_config = ZeoConfig(
conda_env=zeo_env,
partition=partition,
max_concurrent=max_concurrent,
time_limit=time_limit
)
# 确认执行
print("\n" + "-" * 50)
print("【计算任务确认】")
print(f" 总任务数: {len(tasks)}")
print(f" Conda环境: {zeo_config.conda_env}")
print(f" SLURM分区: {zeo_config.partition}")
print(f" 最大并发: {zeo_config.max_concurrent}")
print(f" 时间限制: {zeo_config.time_limit}")
confirm = input("\n确认提交计算任务? [Y/n]: ").strip().lower()
if confirm == 'n':
print("已取消")
return
# 创建执行器并运行
executor = ZeoExecutor(zeo_config)
log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs")
results = executor.run_batch(tasks, output_dir=log_dir)
# 打印结果摘要
executor.print_results_summary(results)
# 保存结果
results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json")
results_data = [
{
'task_id': r.task_id,
'structure_name': r.structure_name,
'cif_path': r.cif_path,
'success': r.success,
'output_files': r.output_files,
'error_message': r.error_message
}
for r in results
]
with open(results_file, 'w') as f:
json.dump(results_data, f, indent=2)
print(f"\n结果已保存到: {results_file}")
print("\n✅ Step 2-4 完成!")
def run_step5_result_processing(workspace_path: str = "workspace"):
"""
Step 5: 结果处理与筛选
Args:
workspace_path: 工作区路径
"""
print("\n" + "" * 60)
print("【Step 5: 结果处理与筛选】")
print("" * 60)
# 创建结果处理器
processor = ResultProcessor(workspace_path=workspace_path)
# 获取筛选条件
print("\n设置筛选条件:")
print(" (直接回车使用默认值,输入 0 表示不限制)")
# 最小渗透直径(默认改为 1.0)
perc_input = input(" 最小渗透直径 (Å) [默认: 1.0]: ").strip()
min_percolation = float(perc_input) if perc_input else 1.0
# 最小 d 值
d_input = input(" 最小 d 值 [默认: 2.0]: ").strip()
min_d = float(d_input) if d_input else 2.0
# 最大节点长度
node_input = input(" 最大节点长度 (Å) [默认: 不限制]: ").strip()
max_node = float(node_input) if node_input else float('inf')
# 创建筛选条件
criteria = FilterCriteria(
min_percolation_diameter=min_percolation,
min_d_value=min_d,
max_node_length=max_node
)
# 执行处理
results, stats = processor.process_and_filter(
criteria=criteria,
save_csv=True,
copy_passed=True
)
# 打印摘要
processor.print_summary(results, stats)
# 显示输出位置
print("\n输出文件位置:")
print(f" 汇总 CSV: {workspace_path}/results/summary.csv")
print(f" 分类 CSV: {workspace_path}/results/阴离子类型/阴离子类型.csv")
print(f" 通过筛选: {workspace_path}/passed/阴离子类型/结构名/")
print("\n✅ Step 5 完成!")
def main():
"""主函数"""
print_banner()
# 环境检测
env = detect_and_show_environment()
# 询问目标阳离子
cation = input("\n🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
# 检测工作流程状态
status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
print_workflow_status(status)
# 获取用户选择
choice = get_user_choice(status)
if choice == 'exit':
print("\n👋 再见!")
return
# 根据选择执行相应步骤
params = {'target_cation': cation}
if choice == 'step1':
# 从头开始
params = run_step1_database_analysis(env, cation)
if params is None:
return
# 询问是否继续
continue_choice = input("\n是否继续进行 Zeo++ 计算? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续 Zeo++ 计算")
return
# 重新检测状态
status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
# 询问是否继续筛选
continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续结果筛选")
return
run_step5_result_processing("workspace")
elif choice == 'step2':
# 从 Zeo++ 计算开始
run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
# 询问是否继续筛选
continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续结果筛选")
return
run_step5_result_processing("workspace")
elif choice == 'step5':
# 直接进行结果筛选
run_step5_result_processing("workspace")
print("\n✅ 全部完成!")
if __name__ == "__main__":
main()

readme.md (349 lines changed)

@@ -1,4 +1,4 @@
# High-Throughput Screening and Supercell Expansion Project
# High-Throughput Screening and Supercell Expansion Project v2.3
## Environment Requirements
@@ -12,9 +12,29 @@
### 2. screen environment (for logical screening and data processing)
* **Python**: 3.11.4
* **Core libraries**: `pymatgen==2024.11.13`, `pandas` (new; used for CSV handling)
* **Path**: `/cluster/home/koko125/anaconda3/envs/screen`
## Quick Start
### Option 1: New main program (recommended)
```bash
# Activate the screen environment
conda activate /cluster/home/koko125/anaconda3/envs/screen
# Run the main program
python main.py
```
The main program provides an interactive interface and supports:
- Database analysis and screening
- Local multiprocessing parallelism
- Direct SLURM submission (no script files need to be generated)
- Real-time progress bars
- Supercell expansion and oxidation-state assignment
### Option 2: Legacy workflow
1. **Data preparation**:
   * If the data comes from **Materials Project (MP)**, put the CIF files in `data/input_pre`
   * If the data comes from **ICSD**, put the CIF files directly in `data/input`
@@ -25,6 +45,52 @@
bash main.sh
```
## New Features (v2.1)
### Execution modes
1. **Local multiprocessing mode** (`local`)
   - Runs in parallel on the local machine via Python multiprocessing
   - Suitable for small jobs or testing
2. **Direct SLURM submission mode** (`slurm`)
   - Submits SLURM jobs directly from Python
   - No script files need to be generated
   - Monitors job status and progress in real time
   - Suitable for large-scale high-throughput computation
### Progress display
```
Analyzing CIF files: |████████████████████░░░░░░░░░░░░░░░░░░░░| 500/1000 (50.0%) [0:05:23<0:05:20, 1.6it/s] ✓480 ✗20
```
### Output formats
Processed files can be written in two output formats:
1. **Flat format** (all files in one directory)
```
workspace/processed/
├── 1514027.cif
├── 1514072.cif
└── ...
```
2. **Analysis format** (grouped by anion type)
```
workspace/data/
├── O/
│   ├── 1514027/
│   │   └── 1514027.cif
│   └── 1514072/
│       └── 1514072.cif
├── S/
│   └── ...
└── Cl+O/
    └── ...
```
## Processing Pipeline in Detail
### Stage 1: Preprocessing and basic screening (Step 1)
@@ -51,9 +117,12 @@
---
## Supercell Expansion Logic (Step 5 - to be run later)
## Supercell Expansion Logic (Step 5)
The expansion logic itself is unchanged and operates on the screened structures.
Supercell expansion is now integrated into the new main program and supports:
- Automatic calculation of the expansion factor (a sketch of the rule follows this section)
- A configurable number of saved structures (no suffix is added when it is 1)
- Automatic addition of oxidation-state information
### Algorithm breakdown
1. **Read the structure**: parse the CIF file.
@@ -71,4 +140,276 @@
### Assumptions
* Only co-occupancy of two atoms on the same site is considered.
* Co-occupancy of Li atoms is not considered; Li atoms are left untouched.
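
The expansion factor is derived from the partial occupancies. Below is a minimal sketch of the rule implemented in `src/analysis/structure_inspector.py`: each occupancy is approximated by a fraction whose denominator is capped by the chosen precision, and the supercell factor is the least common multiple of those denominators (the inspector additionally rejects factors above 64 as not expandable).

```python
from fractions import Fraction
from functools import reduce
import math

def expansion_factor(occupancies, max_denominator=10):
    """Least common multiple of the occupancy denominators.

    occupancies: partial site occupancies, e.g. [0.5, 0.25] -> factor 4.
    max_denominator: precision cap (corresponds to the 'low' precision option).
    """
    denominators = [
        Fraction(occu).limit_denominator(max_denominator).denominator
        for occu in occupancies
    ]
    return reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)

# expansion_factor([0.5, 0.333]) -> 6 (0.333 is approximated as 1/3)
```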
## Project Structure
```
screen/
├── main.py                      # Main entry point (new)
├── main.sh                      # Legacy script entry point
├── readme.md                    # This document
├── config/                      # Configuration files
│   ├── settings.yaml
│   └── valence_states.yaml
├── src/                         # Source code
│   ├── analysis/                # Analysis module
│   │   ├── database_analyzer.py
│   │   ├── report_generator.py
│   │   ├── structure_inspector.py
│   │   └── worker.py
│   ├── core/                    # Core module
│   │   ├── executor.py          # Task executor (new)
│   │   ├── scheduler.py         # Scheduler
│   │   └── progress.py          # Progress management
│   ├── preprocessing/           # Preprocessing module
│   │   ├── processor.py         # Structure processor
│   │   └── ...
│   └── utils/                   # Utility functions
├── py/                          # Legacy scripts
├── tool/                        # Tools and configuration
│   ├── analyze_voronoi_nodes.py
│   └── Li/                      # Valence configuration
└── workspace/                   # Workspace
    ├── data/                    # Analysis-format output
    └── processed/               # Flat-format output
```
## API Usage Examples
### Using the executor
```python
from src.core.executor import create_executor, TaskExecutor

# Create an executor
executor = create_executor(
    mode="slurm",  # or "local"
    max_workers=32,
    conda_env="/cluster/home/koko125/anaconda3/envs/screen"
)

# Define the tasks
tasks = [
    (file_path, "Li", {"O", "S"})
    for file_path in cif_files
]

# Run
from src.analysis.worker import analyze_single_file
results = executor.run(tasks, analyze_single_file, desc="分析CIF文件")
```
### Using the database analyzer
```python
from src.analysis.database_analyzer import DatabaseAnalyzer

analyzer = DatabaseAnalyzer(
    database_path="/path/to/cif/files",
    target_cation="Li",
    target_anions={"O", "S", "Cl", "Br"},
    anion_mode="all"
)
report = analyzer.analyze(show_progress=True)
report.save("analysis_report.json")
```
### Using the Zeo++ executor
```python
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig

# Set up the workspace
ws_manager = WorkspaceManager(
    workspace_path="workspace",
    tool_dir="tool",
    target_cation="Li"
)

# Create the symlinks
workspace_info = ws_manager.setup_workspace()

# Collect the computation tasks
tasks = ws_manager.get_computation_tasks(workspace_info)

# Configure the Zeo++ executor
config = ZeoConfig(
    conda_env="/cluster/home/koko125/anaconda3/envs/zeo",
    partition="cpu",
    max_concurrent=50,
    time_limit="2:00:00"
)

# Run the computations
executor = ZeoExecutor(config)
results = executor.run_batch(tasks, output_dir="slurm_logs")
executor.print_results_summary(results)
```
## New Features (v2.2)
### Zeo++ Voronoi analysis
Added SLURM job-array support for efficiently scheduling large numbers of Zeo++ computations:
1. **Automatic workspace management**
   - Detects existing workspace data
   - Automatically creates symlinks to the configuration files
   - Organizes the directory structure by anion type
2. **SLURM job arrays**
   - Submits tasks in batches via the `--array` option
   - Supports a maximum-concurrency limit (e.g. `%50`)
   - Automatically splits very large task sets into batches
3. **Real-time progress monitoring**
   - Tracks task completion via status files
   - Monitoring can be interrupted with Ctrl+C (the jobs keep running)
   - Collects output files automatically
### Workflow
```
Step 1:   Database analysis
Step 1.5: Supercell expansion + oxidation-state assignment
Step 2-4: Zeo++ Voronoi analysis
├── Create symlinks (yaml configs + analysis script)
├── Submit the SLURM job array
└── Monitor progress and collect results
Step 5:   Result processing and screening
├── Extract key parameters from log.txt
├── Aggregate them into CSV files
├── Apply the filter criteria
└── Copy structures that pass the filter to the passed/ directory
```
### Filter criteria
Step 5 supports the following filter criteria (a usage sketch follows this list):
- **Minimum percolation diameter** (Percolation Diameter): default 1.0 Å
- **Minimum d value** (Minimum of d): default 2.0
- **Maximum node length** (Maximum Node Length): no limit by default
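
As a reference, here is a minimal sketch of running Step 5 programmatically with these defaults, mirroring what `main.py` does interactively (the workspace path is an assumption):

```python
from src.computation.result_processor import ResultProcessor, FilterCriteria

processor = ResultProcessor(workspace_path="workspace")
criteria = FilterCriteria(
    min_percolation_diameter=1.0,   # Å
    min_d_value=2.0,
    max_node_length=float("inf"),   # no limit
)
# Extract values from every log.txt, write the CSV summaries and copy
# the structures that pass the filter to workspace/passed/.
results, stats = processor.process_and_filter(
    criteria=criteria,
    save_csv=True,
    copy_passed=True,
)
processor.print_summary(results, stats)
```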
### Log output
The output of each Zeo++ run is redirected to a `log.txt` file inside the corresponding structure directory:
```bash
python analyze_voronoi_nodes.py *.cif -i O.yaml > log.txt 2>&1
```
Key information contained in the log (see the parsing sketch below):
- `Percolation diameter (A): X.XX` - percolation diameter
- `the minium of d\nX.XX` - minimum d value (the tool really spells it "minium")
- `Maximum node length detected: X.XX A` - maximum node length
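
A small, self-contained sketch of that extraction step; the regular expressions are the ones used by `ResultProcessor.extract_from_log` in `src/computation/result_processor.py`:

```python
import re

def parse_zeo_log(content: str):
    """Return (percolation_diameter, min_d, max_node_length); missing values are None."""
    patterns = {
        "perc": r"Percolation diameter \(A\):\s*([\d\.]+)",
        "min_d": r"the minium of d\s*\n\s*([\d\.]+)",   # the tool prints "minium"
        "node": r"Maximum node length detected:\s*([\d\.]+)\s*A",
    }
    values = {}
    for key, pattern in patterns.items():
        match = re.search(pattern, content)
        values[key] = float(match.group(1)) if match else None
    return values["perc"], values["min_d"], values["node"]

# with open("log.txt", encoding="utf-8") as f:
#     perc, min_d, max_node = parse_zeo_log(f.read())
```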
### Directory structure
```
workspace/
├── data/                          # Analysis-format data
│   ├── O/                         # Oxides
│   │   ├── O.yaml -> tool/Li/O.yaml
│   │   ├── analyze_voronoi_nodes.py -> tool/analyze_voronoi_nodes.py
│   │   ├── 1514027/
│   │   │   ├── 1514027.cif
│   │   │   ├── 1514027_all_accessed_node.cif   # Zeo++ output
│   │   │   ├── 1514027_bond_valence_filtered.cif
│   │   │   └── 1514027_bv_info.csv
│   │   └── ...
│   ├── S/                         # Sulfides
│   └── Cl+O/                      # Mixed anions
├── processed/                     # Flat-format data
├── slurm_logs/                    # SLURM logs
│   ├── tasks.json
│   ├── submit_array.sh
│   ├── task_0.out
│   ├── task_0.err
│   ├── status_0.txt
│   └── ...
└── zeo_results.json               # Aggregated computation results
```
### Configuration files
Example `tool/Li/O.yaml`:
```yaml
SPECIE: Li+
ANION: O
PERCO_R: 0.5
NEIGHBOR: 1.8
LONG: 2.2
```
Parameters (a loading sketch follows this list):
- `SPECIE`: target diffusing ion (including its charge)
- `ANION`: anion type
- `PERCO_R`: percolation radius threshold
- `NEIGHBOR`: neighbor distance threshold
- `LONG`: threshold for flagging long nodes
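
These files are consumed by `analyze_voronoi_nodes.py` via its `-i` option. Purely as an illustration, a hypothetical helper (not part of the project; assumes PyYAML is available in the zeo environment) for loading and sanity-checking such a config:

```python
import yaml  # PyYAML, assumed to be installed

REQUIRED_KEYS = {"SPECIE", "ANION", "PERCO_R", "NEIGHBOR", "LONG"}

def load_zeo_config(path: str) -> dict:
    """Load a Zeo++ parameter file and check that all expected keys are present."""
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    missing = REQUIRED_KEYS - set(config)
    if missing:
        raise ValueError(f"{path} is missing keys: {sorted(missing)}")
    return config

# cfg = load_zeo_config("tool/Li/O.yaml")   # {'SPECIE': 'Li+', 'ANION': 'O', ...}
```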
## New Features (v2.3)
### Checkpoint/resume support
v2.3 adds smart checkpoint/resume support, so the workflow can be continued from any step (a minimal usage sketch follows this list):
1. **Automatic status detection**
   - The workflow status is detected automatically at startup
   - Whether `workspace/data/` exists and contains structures → determines whether Step 1 is finished
   - Whether the structure directories contain a `log.txt` → determines whether the Zeo++ computation is finished
   - If more than 50% of the structures have a log.txt, the Zeo++ computation is considered complete
2. **Smart step skipping**
   - If the Zeo++ computation is done → screening can be run directly
   - If supercell expansion is done → the Zeo++ computation can be run directly
   - Detection works backwards from the last step, and completed steps are skipped automatically
3. **Step-by-step execution and interruption**
   - After each major step the program asks whether to continue
   - The run can be stopped at any point; progress is re-detected on the next run
   - The three major steps: expansion + oxidation states → Zeo++ computation → screening
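
A minimal sketch of driving that detection yourself, using the functions defined in `main.py`:

```python
from main import detect_workflow_status, print_workflow_status

status = detect_workflow_status(workspace_path="workspace", target_cation="Li")
print_workflow_status(status)

# Pick the next step from the detected state, as the interactive menu does.
if status["has_zeo_results"]:
    next_step = "step5"   # go straight to result screening
elif status["has_processed_data"]:
    next_step = "step2"   # structures exist, run the Zeo++ computation next
else:
    next_step = "step1"   # start from database analysis + expansion
```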
### Example workflow-status report
```
Workflow status detection
Existing workspace data detected:
- Total structures: 1234
- Zeo++ computation finished: 1200 (97.2%)
- Zeo++ computation pending: 34
Current status: Zeo++ computation complete
Available actions:
[1] Proceed directly to result screening (Step 5)
[2] Re-run the Zeo++ computation (Step 2-4)
[3] Start from scratch (Step 1)
[0] Exit
Select [1]:
```
### Usage Scenarios
1. **First run**
   - The full pipeline is executed starting from Step 1
   - After each step you can choose to continue or exit
2. **Resuming after an interruption**
   - Completed steps are detected automatically
   - The program offers to continue from the current progress
3. **Re-screening**
   - After changing the filter criteria
   - Step 5 can be re-run directly without recomputing anything
4. **Partial recomputation**
   - If some structures need to be recomputed
   - The Zeo++ computation can be re-run

src/__init__.py (new file, 12 lines)

@@ -0,0 +1,12 @@
"""
高通量筛选与扩胞项目 - 源代码包
"""
from . import analysis
from . import core
from . import preprocessing
from . import computation
from . import utils
__version__ = "2.2.0"
__all__ = ['analysis', 'core', 'preprocessing', 'computation', 'utils']

src/analysis/__init__.py (new, empty file)

src/analysis/database_analyzer.py (new file, 468 lines)

@@ -0,0 +1,468 @@
"""
数据库分析器:支持高性能并行分析
"""
import os
import pickle
import json
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Set, Optional
from pathlib import Path
from .structure_inspector import StructureInspector, StructureInfo
from .worker import analyze_single_file
from ..core.scheduler import ParallelScheduler, ResourceConfig
# 在 DatabaseReport 类中添加缺失的字段
@dataclass
class DatabaseReport:
"""数据库分析报告"""
# 基础统计
database_path: str = ""
total_files: int = 0
valid_files: int = 0
invalid_files: int = 0
# 目标元素统计
target_cation: str = ""
target_anions: Set[str] = field(default_factory=set)
anion_mode: str = ""
# 含目标阳离子的统计
cation_containing_count: int = 0
cation_containing_ratio: float = 0.0
# 阴离子分布
anion_distribution: Dict[str, int] = field(default_factory=dict)
anion_ratios: Dict[str, float] = field(default_factory=dict)
single_anion_count: int = 0
mixed_anion_count: int = 0
# 数据质量统计
with_oxidation_states: int = 0
without_oxidation_states: int = 0
needs_expansion_count: int = 0
cation_with_vacancy_count: int = 0 # Li与空位共占位新增
cation_with_other_cation_count: int = 0 # Li与其他阳离子共占位新增
anion_partial_occupancy_count: int = 0
binary_compound_count: int = 0
has_water_count: int = 0
has_radioactive_count: int = 0
# 可处理性统计
directly_processable: int = 0
needs_preprocessing: int = 0
cannot_process: int = 0
# 详细信息
all_structures: List[StructureInfo] = field(default_factory=list)
skip_reasons_summary: Dict[str, int] = field(default_factory=dict)
# 扩胞相关统计(新增)
expansion_stats: Dict[str, int] = field(default_factory=lambda: {
'no_expansion_needed': 0,
'expansion_factor_2': 0,
'expansion_factor_3': 0,
'expansion_factor_4_8': 0,
'expansion_factor_large': 0,
'cannot_expand': 0,
})
expansion_factor_distribution: Dict[int, int] = field(default_factory=dict)
def to_dict(self) -> dict:
"""转换为可序列化的字典"""
from dataclasses import fields as dataclass_fields
def convert_value(val):
"""递归转换值为可序列化类型"""
if isinstance(val, set):
return list(val)
elif isinstance(val, dict):
return {k: convert_value(v) for k, v in val.items()}
elif isinstance(val, list):
return [convert_value(item) for item in val]
elif hasattr(val, '__dataclass_fields__'):
# 处理 dataclass 对象
return {k: convert_value(v) for k, v in asdict(val).items()}
else:
return val
result = {}
for f in dataclass_fields(self):
value = getattr(self, f.name)
result[f.name] = convert_value(value)
return result
def save(self, path: str):
"""保存报告到JSON文件"""
with open(path, 'w', encoding='utf-8') as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
print(f"✅ 报告已保存到: {path}")
@classmethod
def load(cls, path: str) -> 'DatabaseReport':
"""从JSON文件加载报告"""
with open(path, 'r', encoding='utf-8') as f:
d = json.load(f)
# 处理 set 类型
if 'target_anions' in d:
d['target_anions'] = set(d['target_anions'])
# 处理 StructureInfo 列表(简化处理,不恢复完整对象)
if 'all_structures' in d:
d['all_structures'] = []
return cls(**d)
def get_processable_files(self, include_needs_expansion: bool = True) -> List[StructureInfo]:
"""
获取可处理的文件列表
Args:
include_needs_expansion: 是否包含需要扩胞的文件
Returns:
可处理的 StructureInfo 列表
"""
result = []
for info in self.all_structures:
if info is None or not info.is_valid:
continue
if not info.contains_target_cation:
continue
if not info.can_process:
continue
if not include_needs_expansion and info.needs_expansion:
continue
result.append(info)
return result
def copy_processable_files(
self,
output_dir: str,
include_needs_expansion: bool = True,
organize_by_anion: bool = True
) -> Dict[str, int]:
"""
将可处理的CIF文件复制到工作区
Args:
output_dir: 输出目录(如 workspace/data
include_needs_expansion: 是否包含需要扩胞的文件
organize_by_anion: 是否按阴离子类型组织子目录
Returns:
复制统计信息 {类别: 数量}
"""
import shutil
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 获取可处理文件
processable = self.get_processable_files(include_needs_expansion)
stats = {
'direct': 0, # 可直接处理
'needs_expansion': 0, # 需要扩胞
'total': 0
}
# 按类型创建子目录
if organize_by_anion:
anion_dirs = {}
for info in processable:
# 确定目标目录
if organize_by_anion and info.anion_types:
# 使用主要阴离子作为目录名
anion_key = '+'.join(sorted(info.anion_types))
if anion_key not in anion_dirs:
anion_dir = os.path.join(output_dir, anion_key)
os.makedirs(anion_dir, exist_ok=True)
anion_dirs[anion_key] = anion_dir
target_dir = anion_dirs[anion_key]
else:
target_dir = output_dir
# 进一步按处理类型分类
if info.needs_expansion:
sub_dir = os.path.join(target_dir, 'needs_expansion')
stats['needs_expansion'] += 1
else:
sub_dir = os.path.join(target_dir, 'direct')
stats['direct'] += 1
os.makedirs(sub_dir, exist_ok=True)
# 复制文件
src_path = info.file_path
dst_path = os.path.join(sub_dir, info.file_name)
try:
shutil.copy2(src_path, dst_path)
stats['total'] += 1
except Exception as e:
print(f"⚠️ 复制失败 {info.file_name}: {e}")
# 打印统计
print(f"\n📁 文件已复制到: {output_dir}")
print(f" 可直接处理: {stats['direct']}")
print(f" 需要扩胞: {stats['needs_expansion']}")
print(f" 总计: {stats['total']}")
return stats
class DatabaseAnalyzer:
"""数据库分析器 - 支持高性能并行"""
def __init__(
self,
database_path: str,
target_cation: str = "Li",
target_anions: Set[str] = None,
anion_mode: str = "all",
max_cores: int = 4,
task_complexity: str = "medium"
):
"""
初始化分析器
Args:
database_path: 数据库路径
target_cation: 目标阳离子
target_anions: 目标阴离子集合
anion_mode: 阴离子模式
max_cores: 最大可用核数
task_complexity: 任务复杂度 ('low', 'medium', 'high')
"""
self.database_path = database_path
self.target_cation = target_cation
self.target_anions = target_anions or {'O', 'S', 'Cl', 'Br'}
self.anion_mode = anion_mode
self.max_cores = max_cores
self.task_complexity = task_complexity
# 获取文件列表
self.cif_files = self._get_cif_files()
# 配置调度器
self.resource_config = ParallelScheduler.recommend_config(
num_tasks=len(self.cif_files),
task_complexity=task_complexity,
max_cores=max_cores
)
self.scheduler = ParallelScheduler(self.resource_config)
def _get_cif_files(self) -> List[str]:
"""获取所有CIF文件路径"""
cif_files = []
if os.path.isfile(self.database_path):
if self.database_path.endswith('.cif'):
cif_files.append(self.database_path)
else:
for root, dirs, files in os.walk(self.database_path):
for f in files:
if f.endswith('.cif'):
cif_files.append(os.path.join(root, f))
return sorted(cif_files)
def analyze(self, show_progress: bool = True) -> DatabaseReport:
"""
执行并行分析
Args:
show_progress: 是否显示进度
Returns:
DatabaseReport: 分析报告
"""
report = DatabaseReport(
database_path=self.database_path,
target_cation=self.target_cation,
target_anions=self.target_anions,
anion_mode=self.anion_mode,
total_files=len(self.cif_files)
)
if report.total_files == 0:
print(f"⚠️ 警告: 在 {self.database_path} 中未找到CIF文件")
return report
# 准备任务
tasks = [
(f, self.target_cation, self.target_anions)
for f in self.cif_files
]
# 执行并行分析
results = self.scheduler.run_local(
tasks=tasks,
worker_func=analyze_single_file,
desc="分析CIF文件"
)
# 过滤有效结果
report.all_structures = [r for r in results if r is not None]
# 统计结果
self._compute_statistics(report)
return report
def analyze_slurm(
self,
output_dir: str,
job_name: str = "cif_analysis"
) -> str:
"""
提交SLURM作业进行分析
Args:
output_dir: 输出目录
job_name: 作业名称
Returns:
作业ID
"""
os.makedirs(output_dir, exist_ok=True)
# 保存任务配置
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump({
'files': self.cif_files,
'target_cation': self.target_cation,
'target_anions': list(self.target_anions),
'anion_mode': self.anion_mode
}, f)
# 生成SLURM脚本
worker_script = os.path.join(
os.path.dirname(__file__), 'worker.py'
)
script = self.scheduler.generate_slurm_script(
tasks_file=tasks_file,
worker_script=worker_script,
output_dir=output_dir,
job_name=job_name
)
# 保存并提交
script_path = os.path.join(output_dir, "submit.sh")
return self.scheduler.submit_slurm_job(script, script_path)
# 更新 _compute_statistics 方法
def _compute_statistics(self, report: DatabaseReport):
"""计算统计数据"""
for info in report.all_structures:
# 确保 info 不是 None
if info is None:
report.invalid_files += 1
continue
# 检查有效性
if info.is_valid:
report.valid_files += 1
else:
report.invalid_files += 1
continue # 无效文件不继续统计
# 关键修复:只有当结构确实含有目标阳离子时才计入统计
if not info.contains_target_cation:
continue # 不含目标阳离子的文件不继续统计
report.cation_containing_count += 1
for anion in info.anion_types:
report.anion_distribution[anion] = \
report.anion_distribution.get(anion, 0) + 1
if info.anion_mode == "single":
report.single_anion_count += 1
elif info.anion_mode == "mixed":
report.mixed_anion_count += 1
# 根据阴离子模式过滤
if self.anion_mode == "single" and info.anion_mode != "single":
continue
if self.anion_mode == "mixed" and info.anion_mode != "mixed":
continue
if info.anion_mode == "none":
continue
# 各项统计
if info.has_oxidation_states:
report.with_oxidation_states += 1
else:
report.without_oxidation_states += 1
# Li共占位统计修改
if info.cation_with_vacancy:
report.cation_with_vacancy_count += 1
if info.cation_with_other_cation:
report.cation_with_other_cation_count += 1
if info.anion_has_partial_occupancy:
report.anion_partial_occupancy_count += 1
if info.is_binary_compound:
report.binary_compound_count += 1
if info.has_water_molecule:
report.has_water_count += 1
if info.has_radioactive_elements:
report.has_radioactive_count += 1
# 可处理性
if info.can_process:
if info.needs_expansion:
report.needs_preprocessing += 1
else:
report.directly_processable += 1
else:
report.cannot_process += 1
if info.skip_reason:
for reason in info.skip_reason.split("; "):
report.skip_reasons_summary[reason] = \
report.skip_reasons_summary.get(reason, 0) + 1
# 扩胞统计(新增)
exp_info = info.expansion_info
factor = exp_info.expansion_factor
if not exp_info.needs_expansion:
report.expansion_stats['no_expansion_needed'] += 1
elif not exp_info.can_expand:
report.expansion_stats['cannot_expand'] += 1
elif factor == 2:
report.expansion_stats['expansion_factor_2'] += 1
elif factor == 3:
report.expansion_stats['expansion_factor_3'] += 1
elif 4 <= factor <= 8:
report.expansion_stats['expansion_factor_4_8'] += 1
else:
report.expansion_stats['expansion_factor_large'] += 1
# 详细分布
if exp_info.needs_expansion and exp_info.can_expand:
report.expansion_factor_distribution[factor] = \
report.expansion_factor_distribution.get(factor, 0) + 1
report.needs_expansion_count += 1
# 计算比例
if report.valid_files > 0:
report.cation_containing_ratio = \
report.cation_containing_count / report.valid_files
if report.cation_containing_count > 0:
for anion, count in report.anion_distribution.items():
report.anion_ratios[anion] = \
count / report.cation_containing_count

src/analysis/report_generator.py (new file, 159 lines)

@@ -0,0 +1,159 @@
"""
报告生成器:生成格式化的分析报告
"""
from typing import Optional
from .database_analyzer import DatabaseReport
class ReportGenerator:
"""报告生成器"""
@staticmethod
def print_report(report: DatabaseReport, detailed: bool = False):
"""打印分析报告"""
print("\n" + "=" * 70)
print(" 数据库分析报告")
print("=" * 70)
# 基础信息
print(f"\n📁 数据库路径: {report.database_path}")
print(f"🎯 目标阳离子: {report.target_cation}")
print(f"🎯 目标阴离子: {', '.join(sorted(report.target_anions))}")
print(f"🎯 阴离子模式: {report.anion_mode}")
# 基础统计
print("\n" + "-" * 70)
print("【1. 基础统计】")
print("-" * 70)
print(f" 总 CIF 文件数: {report.total_files}")
print(f" 有效文件数: {report.valid_files}")
print(f" 无效文件数: {report.invalid_files}")
print(f"{report.target_cation} 化合物数: {report.cation_containing_count}")
print(f"{report.target_cation} 化合物占比: {report.cation_containing_ratio:.1%}")
# 阴离子分布
print("\n" + "-" * 70)
print(f"【2. 阴离子分布】(在含 {report.target_cation} 的化合物中)")
print("-" * 70)
if report.anion_distribution:
for anion in sorted(report.anion_distribution.keys()):
count = report.anion_distribution[anion]
ratio = report.anion_ratios.get(anion, 0)
bar = "" * int(ratio * 30)
print(f" {anion:5s}: {count:6d} ({ratio:6.1%}) {bar}")
print(f"\n 单一阴离子化合物: {report.single_anion_count}")
print(f" 复合阴离子化合物: {report.mixed_anion_count}")
# 数据质量
print("\n" + "-" * 70)
print("【3. 数据质量检查】")
print("-" * 70)
total_target = report.cation_containing_count
if total_target > 0:
print(f" 含化合价信息: {report.with_oxidation_states:6d} "
f"({report.with_oxidation_states / total_target:.1%})")
print(f" 缺化合价信息: {report.without_oxidation_states:6d} "
f"({report.without_oxidation_states / total_target:.1%})")
print()
print(f" {report.target_cation}与空位共占位(无需处理): {report.cation_with_vacancy_count:6d}")
print(f" {report.target_cation}与阳离子共占位(需扩胞): {report.cation_with_other_cation_count:6d}")
print(f" 阴离子共占位: {report.anion_partial_occupancy_count:6d}")
print(f" 需扩胞处理(总计): {report.needs_expansion_count:6d}")
print()
print(f" 二元化合物: {report.binary_compound_count:6d}")
print(f" 含水分子: {report.has_water_count:6d}")
print(f" 含放射性元素: {report.has_radioactive_count:6d}")
# 可处理性评估
print("\n" + "-" * 70)
print("【4. 可处理性评估】")
print("-" * 70)
total_processable = report.directly_processable + report.needs_preprocessing
print(f" ✅ 可直接处理: {report.directly_processable:6d}")
print(f" ⚠️ 需预处理(扩胞): {report.needs_preprocessing:6d}")
print(f" ❌ 无法处理: {report.cannot_process:6d}")
print(f" ─────────────────────────────")
print(f" 📊 可处理总数: {total_processable:6d}")
# 跳过原因汇总
if report.skip_reasons_summary:
print("\n" + "-" * 70)
print("【5. 无法处理的原因统计】")
print("-" * 70)
sorted_reasons = sorted(
report.skip_reasons_summary.items(),
key=lambda x: x[1],
reverse=True
)
for reason, count in sorted_reasons:
print(f" {reason:35s}: {count:6d}")
# 扩胞分析
print("\n" + "-" * 70)
print("【6. 扩胞需求分析】")
print("-" * 70)
exp = report.expansion_stats
if total_processable > 0:
print(f" 无需扩胞: {exp['no_expansion_needed']:6d}")
print(f" 扩胞因子=2: {exp['expansion_factor_2']:6d}")
print(f" 扩胞因子=3: {exp['expansion_factor_3']:6d}")
print(f" 扩胞因子=4~8: {exp['expansion_factor_4_8']:6d}")
print(f" 扩胞因子>8: {exp['expansion_factor_large']:6d}")
print(f" 无法扩胞(因子过大): {exp['cannot_expand']:6d}")
# 详细分布
if detailed and report.expansion_factor_distribution:
print("\n 扩胞因子分布:")
for factor in sorted(report.expansion_factor_distribution.keys()):
count = report.expansion_factor_distribution[factor]
bar = "" * min(count, 30)
print(f" {factor:3d}x: {count:5d} {bar}")
print("\n" + "=" * 70)
@staticmethod
def export_to_csv(report: DatabaseReport, output_path: str):
"""导出详细结果到CSV"""
import csv
with open(output_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
# 写入表头
headers = [
'file_name', 'is_valid', 'contains_target_cation',
'anion_types', 'anion_mode', 'has_oxidation_states',
'has_partial_occupancy', 'cation_partial_occupancy',
'anion_partial_occupancy', 'needs_expansion',
'is_binary', 'has_water', 'has_radioactive',
'can_process', 'skip_reason'
]
writer.writerow(headers)
# 写入数据
for info in report.all_structures:
row = [
info.file_name,
info.is_valid,
info.contains_target_cation,
'+'.join(sorted(info.anion_types)) if info.anion_types else '',
info.anion_mode,
info.has_oxidation_states,
info.has_partial_occupancy,
info.cation_with_other_cation, # 修复:使用正确的属性名
info.anion_has_partial_occupancy,
info.needs_expansion,
info.is_binary_compound,
info.has_water_molecule,
info.has_radioactive_elements,
info.can_process,
info.skip_reason
]
writer.writerow(row)
print(f"详细结果已导出到: {output_path}")

src/analysis/structure_inspector.py (new file, 443 lines)

@@ -0,0 +1,443 @@
"""
结构检查器对单个CIF文件进行深度分析含扩胞需求判断
"""
from dataclasses import dataclass, field
from typing import Set, Dict, List, Optional, Tuple
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element, Specie
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import re
import os
@dataclass
class OccupancyInfo:
"""共占位信息"""
occupation: float # 占据率
atom_serials: List[int] = field(default_factory=list) # 原子序号
elements: List[str] = field(default_factory=list) # 涉及的元素
numerator: int = 0 # 分子
denominator: int = 1 # 分母
involves_target_cation: bool = False # 是否涉及目标阳离子
involves_anion: bool = False # 是否涉及阴离子
@dataclass
class ExpansionInfo:
"""扩胞信息"""
needs_expansion: bool = False # 是否需要扩胞
expansion_factor: int = 1 # 扩胞因子(最小公倍数)
occupancy_details: List[OccupancyInfo] = field(default_factory=list) # 共占位详情
problematic_sites: int = 0 # 问题位点数
can_expand: bool = True # 是否可以扩胞处理
skip_reason: str = "" # 无法扩胞的原因
@dataclass
class StructureInfo:
"""单个结构的分析结果"""
file_path: str
file_name: str
# 基础信息
is_valid: bool = False
error_message: str = ""
# 元素组成
elements: Set[str] = field(default_factory=set)
num_sites: int = 0
formula: str = ""
# 阳离子/阴离子信息
contains_target_cation: bool = False
anion_types: Set[str] = field(default_factory=set)
anion_mode: str = "" # "single", "mixed", "none"
# 数据质量标记
has_oxidation_states: bool = False
has_partial_occupancy: bool = False # 是否有共占位
has_water_molecule: bool = False
has_radioactive_elements: bool = False
is_binary_compound: bool = False
# 共占位详细分析(新增)
cation_with_vacancy: bool = False # Li与空位共占位不需处理
cation_with_other_cation: bool = False # Li与其他阳离子共占位需扩胞
anion_has_partial_occupancy: bool = False # 阴离子共占位
other_has_partial_occupancy: bool = False # 其他元素共占位(需扩胞)
expansion_info: ExpansionInfo = field(default_factory=ExpansionInfo)
# 可处理性
needs_expansion: bool = False
can_process: bool = False
skip_reason: str = ""
class StructureInspector:
"""结构检查器(含扩胞分析)"""
# 预定义的阴离子集合
VALID_ANIONS = {'O', 'S', 'Cl', 'Br'}
# 放射性元素
RADIOACTIVE_ELEMENTS = {
'U', 'Th', 'Pu', 'Ra', 'Rn', 'Po', 'Np', 'Am',
'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'
}
# 扩胞精度模式
PRECISION_LIMITS = {
'high': None, # 精确分数
'normal': 100, # 分母≤100
'low': 10, # 分母≤10
'very_low': 5 # 分母≤5
}
def __init__(
self,
target_cation: str = "Li",
target_anions: Set[str] = None,
expansion_precision: str = "low"
):
"""
初始化检查器
Args:
target_cation: 目标阳离子 (如 "Li", "Na")
target_anions: 目标阴离子集合 (如 {"O", "S"})
expansion_precision: 扩胞计算精度 ('high', 'normal', 'low', 'very_low')
"""
self.target_cation = target_cation
self.target_anions = target_anions or self.VALID_ANIONS
self.expansion_precision = expansion_precision
# 目标阳离子的各种可能表示形式
self.target_cation_variants = {
target_cation,
f"{target_cation}+",
f"{target_cation}1+",
}
def inspect(self, file_path: str) -> StructureInfo:
"""分析单个CIF文件"""
import os
info = StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path)
)
# 尝试读取结构
try:
structure = Structure.from_file(file_path)
except Exception as e:
info.is_valid = False
info.error_message = f"读取CIF失败: {str(e)}"
return info
info.is_valid = True
try:
# ===== 关键修复:正确提取元素符号 =====
# structure.composition.elements 返回 Element 对象列表
# 需要用 .symbol 属性获取字符串
element_symbols = set()
for el in structure.composition.elements:
# el 是 Element 对象el.symbol 是字符串如 "Li"
element_symbols.add(el.symbol)
info.elements = element_symbols
info.num_sites = structure.num_sites
info.formula = structure.composition.reduced_formula
# 检查是否为二元化合物
info.is_binary_compound = len(element_symbols) == 2
# ===== 关键修复:直接比较字符串 =====
info.contains_target_cation = self.target_cation in element_symbols
# 检查阴离子类型
info.anion_types = element_symbols.intersection(self.target_anions)
if len(info.anion_types) == 0:
info.anion_mode = "none"
elif len(info.anion_types) == 1:
info.anion_mode = "single"
else:
info.anion_mode = "mixed"
# 检查氧化态
info.has_oxidation_states = self._check_oxidation_states(structure)
# 检查共占位(核心分析)
try:
self._analyze_partial_occupancy(structure, info)
except Exception as e:
# 共占位分析失败,记录但继续
pass
# 检查水分子
try:
info.has_water_molecule = self._check_water_molecule(structure)
except:
info.has_water_molecule = False
# 检查放射性元素
info.has_radioactive_elements = bool(
info.elements.intersection(self.RADIOACTIVE_ELEMENTS)
)
# 判断可处理性
self._evaluate_processability(info)
except Exception as e:
# 分析过程出错,但文件本身是有效的
# 保留 is_valid = True但记录错误
info.error_message = f"分析过程出错: {str(e)}"
return info
def _check_oxidation_states(self, structure: Structure) -> bool:
"""检查结构是否包含氧化态信息"""
try:
for site in structure.sites:
for specie in site.species.keys():
if isinstance(specie, Specie):
return True
return False
except:
return False
def _get_element_from_species_string(self, species_str: str) -> str:
"""从物种字符串提取纯元素符号"""
match = re.match(r'([A-Z][a-z]?)', species_str)
return match.group(1) if match else ""
def _get_occupancy_from_species_string(self, species_str: str, exclude_elements: Set[str]) -> Optional[float]:
"""
从物种字符串获取非目标元素的占据率
格式如: "Li+:0.689, Sc3+:0.311"
"""
if ':' not in species_str:
return None
parts = [p.strip() for p in species_str.split(',')]
for part in parts:
if ':' in part:
element_part, occu_part = part.split(':')
element = self._get_element_from_species_string(element_part.strip())
if element and element not in exclude_elements:
try:
return float(occu_part.strip())
except ValueError:
continue
return None
# 在 StructureInspector 类中,替换 _analyze_partial_occupancy 方法
def _analyze_partial_occupancy(self, structure: Structure, info: StructureInfo):
"""
分析共占位情况(修正版)
关键规则:
- Li与空位共占位 → 不需要处理cation_with_vacancy
- Li与其他阳离子共占位 → 需要扩胞cation_with_other_cation
- 阴离子共占位 → 通常不处理
- 其他阳离子共占位 → 需要扩胞
"""
occupancy_dict = defaultdict(list) # {occupation: [site_indices]}
occupancy_elements = {} # {occupation: [elements]}
for i, site in enumerate(structure.sites):
site_species = site.species
species_string = str(site.species)
# 提取各元素及其占据率
species_occu = {} # {element: occupancy}
for sp, occu in site_species.items():
elem = sp.symbol if hasattr(sp, 'symbol') else str(sp)
elem = self._get_element_from_species_string(elem)
if elem:
species_occu[elem] = occu
total_occupancy = sum(species_occu.values())
elements_at_site = list(species_occu.keys())
# 检查是否有部分占据
has_partial = any(occu < 1.0 for occu in species_occu.values()) or len(species_occu) > 1
if not has_partial:
continue
info.has_partial_occupancy = True
# 判断Li的共占位情况
if self.target_cation in elements_at_site:
li_occu = species_occu.get(self.target_cation, 0)
other_elements = [e for e in elements_at_site if e != self.target_cation]
if not other_elements and li_occu < 1.0:
# Li与空位共占位Li占据率<1但没有其他元素
info.cation_with_vacancy = True
elif other_elements:
# Li与其他元素共占位
other_are_anions = all(e in self.target_anions for e in other_elements)
if other_are_anions:
# Li与阴离子共占位罕见标记为阴离子共占位
info.anion_has_partial_occupancy = True
else:
# Li与其他阳离子共占位 → 需要扩胞
info.cation_with_other_cation = True
# 记录需要扩胞的占据率取非Li元素的占据率
for elem in other_elements:
if elem not in self.target_anions:
occu = species_occu.get(elem, 0)
if occu > 0 and occu < 1.0:
occupancy_dict[occu].append(i)
occupancy_elements[occu] = elements_at_site
else:
# 不涉及Li的位点
# 判断是否涉及阴离子
if any(elem in self.target_anions for elem in elements_at_site):
info.anion_has_partial_occupancy = True
else:
# 其他阳离子的共占位 → 需要扩胞
info.other_has_partial_occupancy = True
# 获取占据率
for elem, occu in species_occu.items():
if occu > 0 and occu < 1.0:
occupancy_dict[occu].append(i)
occupancy_elements[occu] = elements_at_site
break # 只记录一次
# 计算扩胞信息
self._calculate_expansion_info(info, occupancy_dict, occupancy_elements)
def _evaluate_processability(self, info: StructureInfo):
"""评估可处理性(修正版)"""
skip_reasons = []
if not info.is_valid:
skip_reasons.append("无法解析CIF文件")
if not info.contains_target_cation:
skip_reasons.append(f"不含{self.target_cation}")
if info.anion_mode == "none":
skip_reasons.append("不含目标阴离子")
if info.is_binary_compound:
skip_reasons.append("二元化合物")
if info.has_radioactive_elements:
skip_reasons.append("含放射性元素")
# Li与空位共占位 → 不需要处理不加入skip_reasons
# info.cation_with_vacancy 不影响可处理性
# Li与其他阳离子共占位 → 需要扩胞(如果扩胞因子合理则可处理)
if info.cation_with_other_cation:
if info.expansion_info.can_expand:
info.needs_expansion = True
else:
skip_reasons.append(f"{self.target_cation}与其他阳离子共占位且{info.expansion_info.skip_reason}")
# 阴离子共占位 → 不处理
if info.anion_has_partial_occupancy:
skip_reasons.append("阴离子存在共占位")
if info.has_water_molecule:
skip_reasons.append("含水分子")
# 其他阳离子共占位不涉及Li→ 需要扩胞
if info.other_has_partial_occupancy:
if info.expansion_info.can_expand:
info.needs_expansion = True
else:
skip_reasons.append(info.expansion_info.skip_reason)
if skip_reasons:
info.can_process = False
info.skip_reason = "; ".join(skip_reasons)
else:
info.can_process = True
def _calculate_expansion_info(
self,
info: StructureInfo,
occupancy_dict: Dict[float, List[int]],
occupancy_elements: Dict[float, List[str]]
):
"""计算扩胞相关信息"""
expansion_info = ExpansionInfo()
if not occupancy_dict:
info.expansion_info = expansion_info
return
# 需要扩胞(有非目标阳离子的共占位)
expansion_info.needs_expansion = True
expansion_info.problematic_sites = sum(len(v) for v in occupancy_dict.values())
# 转换为OccupancyInfo列表
occupancy_list = []
for occu, serials in occupancy_dict.items():
elements = occupancy_elements.get(occu, [])
# 根据精度计算分数
limit = self.PRECISION_LIMITS.get(self.expansion_precision)
if limit:
fraction = Fraction(occu).limit_denominator(limit)
else:
fraction = Fraction(occu).limit_denominator()
occ_info = OccupancyInfo(
occupation=occu,
atom_serials=[s + 1 for s in serials], # 转为1-based
elements=elements,
numerator=fraction.numerator,
denominator=fraction.denominator,
involves_target_cation=self.target_cation in elements,
involves_anion=any(e in self.target_anions for e in elements)
)
occupancy_list.append(occ_info)
expansion_info.occupancy_details = occupancy_list
# 计算最小公倍数(扩胞因子)
denominators = [occ.denominator for occ in occupancy_list]
if denominators:
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
expansion_info.expansion_factor = lcm
# 判断是否可以扩胞(因子过大则不可处理)
if lcm > 64: # 扩胞超过64倍通常不可行
expansion_info.can_expand = False
expansion_info.skip_reason = f"扩胞因子过大({lcm})"
info.expansion_info = expansion_info
info.needs_expansion = expansion_info.needs_expansion and expansion_info.can_expand
def _check_water_molecule(self, structure: Structure) -> bool:
"""检查是否含有水分子"""
try:
oxygen_sites = []
hydrogen_sites = []
for site in structure.sites:
species_str = str(site.species)
if 'O' in species_str:
oxygen_sites.append(site)
if 'H' in species_str:
hydrogen_sites.append(site)
for o_site in oxygen_sites:
nearby_h = [h for h in hydrogen_sites if o_site.distance(h) < 1.2]
if len(nearby_h) >= 2:
return True
return False
except:
return False

src/analysis/worker.py (new file, 199 lines)

@@ -0,0 +1,199 @@
"""
工作进程:处理单个分析任务
设计为可以独立运行用于SLURM作业数组
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict, fields
from .structure_inspector import StructureInspector, StructureInfo
def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
"""
分析单个CIF文件Worker函数
Args:
args: (file_path, target_cation, target_anions)
Returns:
StructureInfo 或 None如果分析失败
"""
file_path, target_cation, target_anions = args
try:
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
result = inspector.inspect(file_path)
return result
except Exception as e:
# 返回一个标记失败的结果(而不是 None
return StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=f"Worker异常: {str(e)}"
)
def structure_info_to_dict(info: StructureInfo) -> dict:
"""
将 StructureInfo 转换为可序列化的字典
处理 set、dataclass 等特殊类型
"""
result = {}
for field in fields(info):
value = getattr(info, field.name)
# 处理 set 类型
if isinstance(value, set):
result[field.name] = list(value)
# 处理嵌套的 dataclass (如 ExpansionInfo)
elif hasattr(value, '__dataclass_fields__'):
result[field.name] = asdict(value)
# 处理 list 中可能包含的 dataclass
elif isinstance(value, list):
result[field.name] = [
asdict(item) if hasattr(item, '__dataclass_fields__') else item
for item in value
]
else:
result[field.name] = value
return result
def dict_to_structure_info(d: dict) -> StructureInfo:
"""
从字典恢复 StructureInfo 对象
"""
from .structure_inspector import ExpansionInfo, OccupancyInfo
# 处理 set 类型字段
if 'elements' in d and isinstance(d['elements'], list):
d['elements'] = set(d['elements'])
if 'anion_types' in d and isinstance(d['anion_types'], list):
d['anion_types'] = set(d['anion_types'])
if 'target_anions' in d and isinstance(d['target_anions'], list):
d['target_anions'] = set(d['target_anions'])
# 处理 ExpansionInfo
if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
exp_dict = d['expansion_info']
# 处理 OccupancyInfo 列表
if 'occupancy_details' in exp_dict:
exp_dict['occupancy_details'] = [
OccupancyInfo(**occ) if isinstance(occ, dict) else occ
for occ in exp_dict['occupancy_details']
]
d['expansion_info'] = ExpansionInfo(**exp_dict)
return StructureInfo(**d)
def batch_analyze(
file_paths: List[str],
target_cation: str,
target_anions: set,
output_file: str = None
) -> List[StructureInfo]:
"""
批量分析文件用于SLURM子任务
Args:
file_paths: CIF文件路径列表
target_cation: 目标阳离子
target_anions: 目标阴离子集合
output_file: 输出文件路径pickle格式
Returns:
StructureInfo列表
"""
results = []
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
for file_path in file_paths:
try:
info = inspector.inspect(file_path)
results.append(info)
except Exception as e:
results.append(StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=str(e)
))
# 保存结果
if output_file:
serializable_results = [structure_info_to_dict(r) for r in results]
with open(output_file, 'wb') as f:
pickle.dump(serializable_results, f)
return results
def load_results(result_file: str) -> List[StructureInfo]:
"""
从pickle文件加载结果
"""
with open(result_file, 'rb') as f:
data = pickle.load(f)
return [dict_to_structure_info(d) for d in data]
def merge_results(result_files: List[str]) -> List[StructureInfo]:
"""
合并多个结果文件用于汇总SLURM作业数组的输出
"""
all_results = []
for f in result_files:
if os.path.exists(f):
all_results.extend(load_results(f))
return all_results
# 用于SLURM作业数组的命令行入口
if __name__ == "__main__":
import argparse
import json
parser = argparse.ArgumentParser(description="CIF Analysis Worker")
parser.add_argument("--tasks-file", required=True, help="任务文件路径(JSON)")
parser.add_argument("--output-dir", required=True, help="输出目录")
parser.add_argument("--task-id", type=int, default=0, help="任务ID(用于数组作业)")
parser.add_argument("--num-workers", type=int, default=1, help="并行worker数")
args = parser.parse_args()
# 加载任务
with open(args.tasks_file, 'r') as f:
task_config = json.load(f)
file_paths = task_config['files']
target_cation = task_config['target_cation']
target_anions = set(task_config['target_anions'])
# 如果是数组作业,只处理分配的部分
if args.task_id >= 0:
chunk_size = len(file_paths) // args.num_workers + 1
start_idx = args.task_id * chunk_size
end_idx = min(start_idx + chunk_size, len(file_paths))
file_paths = file_paths[start_idx:end_idx]
# 输出文件
output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")
# 执行分析
print(f"Worker {args.task_id}: 处理 {len(file_paths)} 个文件")
results = batch_analyze(file_paths, target_cation, target_anions, output_file)
print(f"Worker {args.task_id}: 完成,结果保存到 {output_file}")

src/computation/__init__.py (new file, 15 lines)

@@ -0,0 +1,15 @@
"""
计算模块Zeo++ Voronoi 分析
"""
from .workspace_manager import WorkspaceManager
from .zeo_executor import ZeoExecutor, ZeoConfig
from .result_processor import ResultProcessor, FilterCriteria, StructureResult
__all__ = [
'WorkspaceManager',
'ZeoExecutor',
'ZeoConfig',
'ResultProcessor',
'FilterCriteria',
'StructureResult'
]

src/computation/result_processor.py (new file, 426 lines)

@@ -0,0 +1,426 @@
"""
Zeo++ 计算结果处理器:提取数据、筛选结构
"""
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import pandas as pd
@dataclass
class FilterCriteria:
"""筛选条件"""
min_percolation_diameter: float = 1.0 # 最小渗透直径 (Å),默认 1.0
min_d_value: float = 2.0 # 最小 d 值,默认 2.0
max_node_length: float = float('inf') # 最大节点长度 (Å)
@dataclass
class StructureResult:
"""单个结构的计算结果"""
structure_name: str
anion_type: str
work_dir: str
# 提取的参数
percolation_diameter: Optional[float] = None
min_d: Optional[float] = None
max_node_length: Optional[float] = None
# 筛选结果
passed_filter: bool = False
filter_reason: str = ""
class ResultProcessor:
"""
Zeo++ 计算结果处理器
功能:
1. 从每个结构目录的 log.txt 提取关键参数
2. 汇总所有结果到 CSV 文件
3. 根据筛选条件筛选结构
4. 将通过筛选的结构复制到新文件夹
"""
def __init__(
self,
workspace_path: str = "workspace",
data_dir: str = None,
output_dir: str = None
):
"""
初始化结果处理器
Args:
workspace_path: 工作区根目录
data_dir: 数据目录(默认 workspace/data
output_dir: 输出目录(默认 workspace/results
"""
self.workspace_path = os.path.abspath(workspace_path)
self.data_dir = data_dir or os.path.join(self.workspace_path, "data")
self.output_dir = output_dir or os.path.join(self.workspace_path, "results")
def extract_from_log(self, log_path: str) -> Tuple[Optional[float], Optional[float], Optional[float]]:
"""
从 log.txt 中提取三个关键参数
Args:
log_path: log.txt 文件路径
Returns:
(percolation_diameter, min_d, max_node_length)
"""
if not os.path.exists(log_path):
return None, None, None
try:
with open(log_path, 'r', encoding='utf-8') as f:
content = f.read()
except Exception:
return None, None, None
# 正则表达式 - 与 py/extract_data.py 保持一致
# 1. Percolation diameter: "# Percolation diameter (A): 1.06"
re_percolation = r"Percolation diameter \(A\):\s*([\d\.]+)"
# 2. Minimum of d: "the minium of d\n3.862140561244235"
# 注意:这是 Topological_Analysis 库输出的格式
re_min_d = r"the minium of d\s*\n\s*([\d\.]+)"
# 3. Maximum node length: "# Maximum node length detected: 1.332 A"
re_max_node = r"Maximum node length detected:\s*([\d\.]+)\s*A"
# 提取数据
match_perc = re.search(re_percolation, content)
match_d = re.search(re_min_d, content)
match_node = re.search(re_max_node, content)
val_perc = float(match_perc.group(1)) if match_perc else None
val_d = float(match_d.group(1)) if match_d else None
val_node = float(match_node.group(1)) if match_node else None
return val_perc, val_d, val_node
def process_all_structures(self) -> List[StructureResult]:
"""
处理所有结构,提取计算结果
Returns:
StructureResult 列表
"""
results = []
if not os.path.exists(self.data_dir):
print(f"⚠️ 数据目录不存在: {self.data_dir}")
return results
print("\n正在提取计算结果...")
# 遍历阴离子目录
for anion_key in os.listdir(self.data_dir):
anion_dir = os.path.join(self.data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
# 遍历结构目录
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 查找 log.txt
log_path = os.path.join(struct_dir, "log.txt")
# 提取参数
perc, min_d, max_node = self.extract_from_log(log_path)
result = StructureResult(
structure_name=struct_name,
anion_type=anion_key,
work_dir=struct_dir,
percolation_diameter=perc,
min_d=min_d,
max_node_length=max_node
)
results.append(result)
print(f" 共处理 {len(results)} 个结构")
return results
def apply_filter(
self,
results: List[StructureResult],
criteria: FilterCriteria
) -> List[StructureResult]:
"""
应用筛选条件
Args:
results: 结构结果列表
criteria: 筛选条件
Returns:
更新后的结果列表(包含筛选状态)
"""
print("\n应用筛选条件...")
print(f" 最小渗透直径: {criteria.min_percolation_diameter} Å")
print(f" 最小 d 值: {criteria.min_d_value}")
print(f" 最大节点长度: {criteria.max_node_length} Å")
passed_count = 0
for result in results:
# 检查是否有有效数据
if result.percolation_diameter is None or result.min_d is None:
result.passed_filter = False
result.filter_reason = "数据缺失"
continue
# 检查渗透直径
if result.percolation_diameter < criteria.min_percolation_diameter:
result.passed_filter = False
result.filter_reason = f"渗透直径 {result.percolation_diameter:.3f} < {criteria.min_percolation_diameter}"
continue
# 检查 d 值
if result.min_d < criteria.min_d_value:
result.passed_filter = False
result.filter_reason = f"d 值 {result.min_d:.3f} < {criteria.min_d_value}"
continue
# 检查节点长度(如果有数据)
if result.max_node_length is not None:
if result.max_node_length > criteria.max_node_length:
result.passed_filter = False
result.filter_reason = f"节点长度 {result.max_node_length:.3f} > {criteria.max_node_length}"
continue
# 通过所有筛选
result.passed_filter = True
result.filter_reason = "通过"
passed_count += 1
print(f" 通过筛选: {passed_count}/{len(results)}")
return results
def save_summary_csv(
self,
results: List[StructureResult],
output_path: str = None
) -> str:
"""
保存汇总 CSV 文件
Args:
results: 结构结果列表
output_path: 输出路径(默认 workspace/results/summary.csv
Returns:
CSV 文件路径
"""
if output_path is None:
os.makedirs(self.output_dir, exist_ok=True)
output_path = os.path.join(self.output_dir, "summary.csv")
# 构建数据
data = []
for r in results:
data.append({
'Structure': r.structure_name,
'Anion_Type': r.anion_type,
'Percolation_Diameter_A': r.percolation_diameter,
'Min_d': r.min_d,
'Max_Node_Length_A': r.max_node_length,
'Passed_Filter': r.passed_filter,
'Filter_Reason': r.filter_reason
})
df = pd.DataFrame(data)
# 按阴离子类型和结构名排序
df = df.sort_values(['Anion_Type', 'Structure'])
# 保存
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_csv(output_path, index=False)
print(f"\n汇总 CSV 已保存: {output_path}")
return output_path
def save_anion_csv(
self,
results: List[StructureResult],
output_dir: str = None
) -> List[str]:
"""
按阴离子类型分别保存 CSV 文件
Args:
results: 结构结果列表
output_dir: 输出目录
Returns:
生成的 CSV 文件路径列表
"""
if output_dir is None:
output_dir = self.output_dir
# 按阴离子类型分组
anion_groups: Dict[str, List[StructureResult]] = {}
for r in results:
if r.anion_type not in anion_groups:
anion_groups[r.anion_type] = []
anion_groups[r.anion_type].append(r)
csv_files = []
for anion_type, group_results in anion_groups.items():
# 构建数据
data = []
for r in group_results:
data.append({
'Structure': r.structure_name,
'Percolation_Diameter_A': r.percolation_diameter,
'Min_d': r.min_d,
'Max_Node_Length_A': r.max_node_length,
'Passed_Filter': r.passed_filter,
'Filter_Reason': r.filter_reason
})
df = pd.DataFrame(data)
df = df.sort_values('Structure')
# 保存到对应目录
anion_output_dir = os.path.join(output_dir, anion_type)
os.makedirs(anion_output_dir, exist_ok=True)
csv_path = os.path.join(anion_output_dir, f"{anion_type}.csv")
df.to_csv(csv_path, index=False)
csv_files.append(csv_path)
print(f" {anion_type}: {len(group_results)} 个结构 -> {csv_path}")
return csv_files
def copy_passed_structures(
self,
results: List[StructureResult],
output_dir: str = None
) -> int:
"""
将通过筛选的结构复制到新文件夹
Args:
results: 结构结果列表
output_dir: 输出目录(默认 workspace/passed
Returns:
复制的结构数量
"""
if output_dir is None:
output_dir = os.path.join(self.workspace_path, "passed")
passed_results = [r for r in results if r.passed_filter]
if not passed_results:
print("\n没有通过筛选的结构")
return 0
print(f"\n正在复制 {len(passed_results)} 个通过筛选的结构...")
copied = 0
for r in passed_results:
# 目标目录passed/阴离子类型/结构名/
dst_dir = os.path.join(output_dir, r.anion_type, r.structure_name)
try:
# 如果目标已存在,先删除
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
# 复制整个目录
shutil.copytree(r.work_dir, dst_dir)
copied += 1
except Exception as e:
print(f" ⚠️ 复制失败 {r.structure_name}: {e}")
print(f" 已复制 {copied} 个结构到: {output_dir}")
return copied
def process_and_filter(
self,
criteria: FilterCriteria = None,
save_csv: bool = True,
copy_passed: bool = True
) -> Tuple[List[StructureResult], Dict]:
"""
完整的处理流程:提取数据 -> 筛选 -> 保存 CSV -> 复制通过的结构
Args:
criteria: 筛选条件(如果为 None则不筛选
save_csv: 是否保存 CSV
copy_passed: 是否复制通过筛选的结构
Returns:
(结果列表, 统计信息字典)
"""
# 1. 提取所有结构的计算结果
results = self.process_all_structures()
if not results:
return results, {'total': 0, 'passed': 0, 'failed': 0}
# 2. 应用筛选条件
if criteria is not None:
results = self.apply_filter(results, criteria)
# 3. 保存 CSV
if save_csv:
print("\n保存结果 CSV...")
self.save_summary_csv(results)
self.save_anion_csv(results)
# 4. 复制通过筛选的结构
if copy_passed and criteria is not None:
self.copy_passed_structures(results)
# 统计
stats = {
'total': len(results),
'passed': sum(1 for r in results if r.passed_filter),
'failed': sum(1 for r in results if not r.passed_filter),
'missing_data': sum(1 for r in results if r.filter_reason == "数据缺失")
}
return results, stats
def print_summary(self, results: List[StructureResult], stats: Dict):
"""打印结果摘要"""
print("\n" + "=" * 60)
print("【计算结果摘要】")
print("=" * 60)
print(f" 总结构数: {stats['total']}")
print(f" 通过筛选: {stats['passed']}")
print(f" 未通过筛选: {stats['failed']}")
print(f" 数据缺失: {stats.get('missing_data', 0)}")
# 按阴离子类型统计
anion_stats: Dict[str, Dict] = {}
for r in results:
if r.anion_type not in anion_stats:
anion_stats[r.anion_type] = {'total': 0, 'passed': 0}
anion_stats[r.anion_type]['total'] += 1
if r.passed_filter:
anion_stats[r.anion_type]['passed'] += 1
print("\n 按阴离子类型:")
for anion, s in sorted(anion_stats.items()):
print(f" {anion}: {s['passed']}/{s['total']} 通过")
print("=" * 60)


@@ -0,0 +1,288 @@
"""
工作区管理器:管理计算工作区的创建和软链接
"""
import os
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
@dataclass
class WorkspaceInfo:
"""工作区信息"""
workspace_path: str
data_dir: str # workspace/data
tool_dir: str # tool 目录
target_cation: str
target_anions: Set[str]
# 统计信息
total_structures: int = 0
anion_counts: Dict[str, int] = field(default_factory=dict)
linked_structures: int = 0 # 已创建软链接的结构数
class WorkspaceManager:
"""
工作区管理器
负责:
1. 检测现有工作区
2. 创建软链接yaml 配置文件和计算脚本放在每个结构目录下)
3. 准备计算任务
"""
# 支持的阴离子及其配置文件
SUPPORTED_ANIONS = {'O', 'S', 'Cl', 'Br'}
def __init__(
self,
workspace_path: str = "workspace",
tool_dir: str = "tool",
target_cation: str = "Li"
):
"""
初始化工作区管理器
Args:
workspace_path: 工作区根目录
tool_dir: 工具目录(包含 yaml 配置和计算脚本)
target_cation: 目标阳离子
"""
self.workspace_path = os.path.abspath(workspace_path)
self.tool_dir = os.path.abspath(tool_dir)
self.target_cation = target_cation
# 数据目录
self.data_dir = os.path.join(self.workspace_path, "data")
def check_existing_workspace(self) -> Optional[WorkspaceInfo]:
"""
检查现有工作区
Returns:
WorkspaceInfo 如果存在,否则 None
"""
if not os.path.exists(self.data_dir):
return None
# 扫描数据目录
anion_counts = {}
total = 0
linked = 0
for item in os.listdir(self.data_dir):
item_path = os.path.join(self.data_dir, item)
if os.path.isdir(item_path):
# 可能是阴离子目录(如 O, S, O+S
# 统计其中的结构数量
count = 0
for sub_item in os.listdir(item_path):
sub_path = os.path.join(item_path, sub_item)
if os.path.isdir(sub_path):
# 检查是否包含 CIF 文件
cif_files = [f for f in os.listdir(sub_path) if f.endswith('.cif')]
if cif_files:
count += 1
# 检查是否已有软链接
yaml_files = [f for f in os.listdir(sub_path) if f.endswith('.yaml')]
if yaml_files:
linked += 1
if count > 0:
anion_counts[item] = count
total += count
if total == 0:
return None
return WorkspaceInfo(
workspace_path=self.workspace_path,
data_dir=self.data_dir,
tool_dir=self.tool_dir,
target_cation=self.target_cation,
target_anions=set(anion_counts.keys()),
total_structures=total,
anion_counts=anion_counts,
linked_structures=linked
)
def setup_workspace(
self,
target_anions: Set[str] = None,
force_relink: bool = False
) -> WorkspaceInfo:
"""
设置工作区:在每个结构目录下创建软链接
软链接规则:
- yaml 文件:使用与阴离子目录同名的 yaml如 O 目录用 O.yamlCl+O 目录用 Cl+O.yaml
- python 脚本analyze_voronoi_nodes.py
Args:
target_anions: 目标阴离子集合
force_relink: 是否强制重新创建软链接
Returns:
WorkspaceInfo
"""
if target_anions is None:
target_anions = self.SUPPORTED_ANIONS
# 确保数据目录存在
if not os.path.exists(self.data_dir):
raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
# 获取计算脚本路径
analyze_script = os.path.join(self.tool_dir, "analyze_voronoi_nodes.py")
if not os.path.exists(analyze_script):
raise FileNotFoundError(f"计算脚本不存在: {analyze_script}")
anion_counts = {}
total = 0
linked = 0
print("\n正在设置工作区软链接...")
# 遍历数据目录中的阴离子子目录
for anion_key in os.listdir(self.data_dir):
anion_dir = os.path.join(self.data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
# 确定使用哪个 yaml 配置文件
# 使用与阴离子目录同名的 yaml 文件(如 O.yaml, Cl+O.yaml
yaml_name = f"{anion_key}.yaml"
yaml_source = os.path.join(self.tool_dir, self.target_cation, yaml_name)
if not os.path.exists(yaml_source):
print(f" ⚠️ 配置文件不存在: {yaml_source}")
continue
# 统计并处理该阴离子目录下的所有结构
count = 0
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 检查是否包含 CIF 文件
cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
if not cif_files:
continue
count += 1
# 在结构目录下创建软链接
yaml_link = os.path.join(struct_dir, yaml_name)
script_link = os.path.join(struct_dir, "analyze_voronoi_nodes.py")
# 创建 yaml 软链接
if os.path.exists(yaml_link) or os.path.islink(yaml_link):
if force_relink:
os.remove(yaml_link)
os.symlink(yaml_source, yaml_link)
linked += 1
else:
os.symlink(yaml_source, yaml_link)
linked += 1
# 创建计算脚本软链接
if os.path.exists(script_link) or os.path.islink(script_link):
if force_relink:
os.remove(script_link)
os.symlink(analyze_script, script_link)
else:
os.symlink(analyze_script, script_link)
if count > 0:
anion_counts[anion_key] = count
total += count
print(f"{anion_key}: {count} 个结构, 配置 -> {yaml_name}")
print(f"\n 总计: {total} 个结构, 新建软链接: {linked}")
return WorkspaceInfo(
workspace_path=self.workspace_path,
data_dir=self.data_dir,
tool_dir=self.tool_dir,
target_cation=self.target_cation,
target_anions=set(anion_counts.keys()),
total_structures=total,
anion_counts=anion_counts,
linked_structures=linked
)
def get_computation_tasks(
self,
workspace_info: WorkspaceInfo = None
) -> List[Dict]:
"""
获取所有计算任务
Returns:
任务列表,每个任务包含:
- cif_path: CIF 文件路径
- yaml_name: YAML 配置文件名(如 O.yaml
- work_dir: 工作目录(结构目录)
- anion_type: 阴离子类型
- structure_name: 结构名称
"""
if workspace_info is None:
workspace_info = self.check_existing_workspace()
if workspace_info is None:
return []
tasks = []
for anion_key in workspace_info.anion_counts.keys():
anion_dir = os.path.join(self.data_dir, anion_key)
yaml_name = f"{anion_key}.yaml"
# 遍历该阴离子目录下的所有结构
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 查找 CIF 文件
cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
# 检查是否有 yaml 软链接
yaml_path = os.path.join(struct_dir, yaml_name)
if not os.path.exists(yaml_path):
continue
for cif_file in cif_files:
cif_path = os.path.join(struct_dir, cif_file)
tasks.append({
'cif_path': cif_path,
'yaml_name': yaml_name,
'work_dir': struct_dir,
'anion_type': anion_key,
'structure_name': struct_name,
'cif_name': cif_file
})
return tasks
def print_workspace_summary(self, workspace_info: WorkspaceInfo):
"""打印工作区摘要"""
print("\n" + "=" * 60)
print("【工作区摘要】")
print("=" * 60)
print(f" 工作区路径: {workspace_info.workspace_path}")
print(f" 数据目录: {workspace_info.data_dir}")
print(f" 目标阳离子: {workspace_info.target_cation}")
print(f" 总结构数: {workspace_info.total_structures}")
print(f" 已配置软链接: {workspace_info.linked_structures}")
print()
print(" 阴离子分布:")
for anion, count in sorted(workspace_info.anion_counts.items()):
print(f" - {anion}: {count} 个结构")
print("=" * 60)


@@ -0,0 +1,446 @@
"""
Zeo++ 计算执行器:使用 SLURM 作业数组高效调度大量计算任务
"""
import os
import subprocess
import time
import json
import tempfile
from typing import List, Dict, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import threading
from ..core.progress import ProgressManager
@dataclass
class ZeoConfig:
"""Zeo++ 计算配置"""
# 环境配置
conda_env: str = "/cluster/home/koko125/anaconda3/envs/zeo"
# SLURM 配置
partition: str = "cpu"
time_limit: str = "2:00:00" # 单个任务时间限制
memory_per_task: str = "4G"
# 作业数组配置
max_array_size: int = 1000 # SLURM 作业数组最大大小
max_concurrent: int = 50 # 最大并发任务数
# 轮询配置
poll_interval: float = 5.0 # 状态检查间隔(秒)
# 过滤器配置
filters: List[str] = field(default_factory=lambda: [
"Ordered", "PropOxi", "VoroPerco", "Coulomb", "VoroBV", "VoroInfo", "MergeSite"
])
@dataclass
class ZeoTaskResult:
"""单个任务结果"""
task_id: int
structure_name: str
cif_path: str
success: bool
output_files: List[str] = field(default_factory=list)
error_message: str = ""
duration: float = 0.0
class ZeoExecutor:
"""
Zeo++ 计算执行器
使用 SLURM 作业数组高效调度大量 Voronoi 分析任务
"""
def __init__(self, config: ZeoConfig = None):
self.config = config or ZeoConfig()
self.progress_manager = None
self._stop_event = threading.Event()
def run_batch(
self,
tasks: List[Dict],
output_dir: str = None,
desc: str = "Zeo++ 计算"
) -> List[ZeoTaskResult]:
"""
批量执行 Zeo++ 计算
Args:
tasks: 任务列表,每个任务包含 cif_path, yaml_path, work_dir 等
output_dir: SLURM 日志输出目录
desc: 进度条描述
Returns:
ZeoTaskResult 列表
"""
if not tasks:
print("⚠️ 没有任务需要执行")
return []
total = len(tasks)
# 创建输出目录
if output_dir is None:
output_dir = os.path.join(os.getcwd(), "slurm_logs")
os.makedirs(output_dir, exist_ok=True)
print(f"\n{'='*60}")
print(f"【Zeo++ 批量计算】")
print(f"{'='*60}")
print(f" 总任务数: {total}")
print(f" Conda环境: {self.config.conda_env}")
print(f" SLURM分区: {self.config.partition}")
print(f" 最大并发: {self.config.max_concurrent}")
print(f" 日志目录: {output_dir}")
print(f"{'='*60}\n")
# 保存任务列表到文件
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump(tasks, f, indent=2)
# 生成并提交作业数组
if total <= self.config.max_array_size:
# 单个作业数组
return self._submit_array_job(tasks, output_dir, desc)
else:
# 分批提交多个作业数组
return self._submit_batched_arrays(tasks, output_dir, desc)
def _submit_array_job(
self,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""提交单个作业数组"""
total = len(tasks)
# 保存任务列表
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump(tasks, f, indent=2)
# 生成作业脚本
script_content = self._generate_array_script(
tasks_file=tasks_file,
output_dir=output_dir,
array_range=f"0-{total-1}%{self.config.max_concurrent}"
)
script_path = os.path.join(output_dir, "submit_array.sh")
with open(script_path, 'w') as f:
f.write(script_content)
os.chmod(script_path, 0o755)
print(f"生成作业脚本: {script_path}")
# 提交作业
result = subprocess.run(
['sbatch', script_path],
capture_output=True,
text=True
)
if result.returncode != 0:
print(f"❌ 作业提交失败: {result.stderr}")
return [ZeoTaskResult(
task_id=i,
structure_name=t.get('structure_name', ''),
cif_path=t.get('cif_path', ''),
success=False,
error_message=f"提交失败: {result.stderr}"
) for i, t in enumerate(tasks)]
# 提取作业 ID
job_id = result.stdout.strip().split()[-1]
print(f"✓ 作业已提交: {job_id}")
print(f" 作业数组范围: 0-{total-1}")
print(f" 最大并发: {self.config.max_concurrent}")
# 监控作业进度
return self._monitor_array_job(job_id, tasks, output_dir, desc)
def _submit_batched_arrays(
self,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""分批提交多个作业数组"""
total = len(tasks)
batch_size = self.config.max_array_size
num_batches = (total + batch_size - 1) // batch_size
print(f"任务数超过作业数组限制,分 {num_batches} 批提交")
all_results = []
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, total)
batch_tasks = tasks[start_idx:end_idx]
batch_output_dir = os.path.join(output_dir, f"batch_{batch_idx}")
os.makedirs(batch_output_dir, exist_ok=True)
print(f"\n--- 批次 {batch_idx + 1}/{num_batches} ---")
print(f"任务范围: {start_idx} - {end_idx - 1}")
batch_results = self._submit_array_job(
batch_tasks,
batch_output_dir,
f"{desc} (批次 {batch_idx + 1}/{num_batches})"
)
# 调整任务 ID
for r in batch_results:
r.task_id += start_idx
all_results.extend(batch_results)
return all_results
def _generate_array_script(
self,
tasks_file: str,
output_dir: str,
array_range: str
) -> str:
"""生成 SLURM 作业数组脚本"""
# 获取项目根目录
project_root = os.getcwd()
script = f"""#!/bin/bash
#SBATCH --job-name=zeo_array
#SBATCH --partition={self.config.partition}
#SBATCH --array={array_range}
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --mem={self.config.memory_per_task}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/task_%a.out
#SBATCH --error={output_dir}/task_%a.err
# ============================================
# Zeo++ Voronoi 分析 - 作业数组
# ============================================
echo "===== 任务信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "数组任务ID: $SLURM_ARRAY_TASK_ID"
echo "节点: $SLURM_NODELIST"
echo "开始时间: $(date)"
echo "===================="
# ============ 环境初始化 ============
# 加载 bashrc
if [ -f ~/.bashrc ]; then
source ~/.bashrc
fi
# 初始化 Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
source /opt/anaconda3/etc/profile.d/conda.sh
fi
# 激活 Zeo++ 环境
conda activate {self.config.conda_env}
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "===================="
echo ""
# ============ 读取任务信息 ============
TASKS_FILE="{tasks_file}"
TASK_ID=$SLURM_ARRAY_TASK_ID
# 使用 Python 解析任务
TASK_INFO=$(python3 -c "
import json
with open('$TASKS_FILE', 'r') as f:
tasks = json.load(f)
if $TASK_ID < len(tasks):
task = tasks[$TASK_ID]
print(task['work_dir'])
print(task['yaml_name'])
else:
print('ERROR')
")
WORK_DIR=$(echo "$TASK_INFO" | sed -n '1p')
YAML_NAME=$(echo "$TASK_INFO" | sed -n '2p')
if [ "$WORK_DIR" == "ERROR" ]; then
echo "错误: 任务ID $TASK_ID 超出范围"
exit 1
fi
echo "工作目录: $WORK_DIR"
echo "配置文件: $YAML_NAME"
echo ""
# ============ 执行计算 ============
cd "$WORK_DIR"
echo "开始 Voronoi 分析..."
# 软链接已在工作目录下,直接使用相对路径
# 将输出重定向到 log.txt 以便后续提取结果
python analyze_voronoi_nodes.py *.cif -i "$YAML_NAME" > log.txt 2>&1
EXIT_CODE=$?
# 显示日志内容(用于调试)
echo ""
echo "===== 计算日志 ====="
cat log.txt
echo "===================="
# ============ 完成 ============
echo ""
echo "===== 任务完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
# 写入状态文件
if [ $EXIT_CODE -eq 0 ]; then
echo "SUCCESS" > "{output_dir}/status_$TASK_ID.txt"
else
echo "FAILED" > "{output_dir}/status_$TASK_ID.txt"
fi
echo "===================="
exit $EXIT_CODE
"""
return script
def _monitor_array_job(
self,
job_id: str,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""监控作业数组进度"""
total = len(tasks)
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = [None] * total
completed = set()
print(f"\n监控作业进度 (每 {self.config.poll_interval} 秒检查一次)...")
print("按 Ctrl+C 可中断监控(作业将继续在后台运行)\n")
try:
while len(completed) < total:
time.sleep(self.config.poll_interval)
# 检查状态文件
for i in range(total):
if i in completed:
continue
status_file = os.path.join(output_dir, f"status_{i}.txt")
if os.path.exists(status_file):
with open(status_file, 'r') as f:
status = f.read().strip()
task = tasks[i]
success = (status == "SUCCESS")
# 收集输出文件
output_files = []
if success:
work_dir = task['work_dir']
for f in os.listdir(work_dir):
if f.endswith(('.cif', '.csv')) and f != task['cif_name']:
output_files.append(os.path.join(work_dir, f))
results[i] = ZeoTaskResult(
task_id=i,
structure_name=task.get('structure_name', ''),
cif_path=task.get('cif_path', ''),
success=success,
output_files=output_files
)
completed.add(i)
self.progress_manager.update(success=success)
self.progress_manager.display()
# 检查作业是否还在运行
if not self._is_job_running(job_id) and len(completed) < total:
# 作业已结束但有任务未完成
print(f"\n⚠️ 作业已结束,但有 {total - len(completed)} 个任务未完成")
break
except KeyboardInterrupt:
print("\n\n⚠️ 监控已中断,作业将继续在后台运行")
print(f" 可使用 'squeue -j {job_id}' 查看作业状态")
print(f" 可使用 'scancel {job_id}' 取消作业")
self.progress_manager.finish()
# 填充未完成的任务
for i in range(total):
if results[i] is None:
task = tasks[i]
results[i] = ZeoTaskResult(
task_id=i,
structure_name=task.get('structure_name', ''),
cif_path=task.get('cif_path', ''),
success=False,
error_message="任务未完成或状态未知"
)
return results
def _is_job_running(self, job_id: str) -> bool:
"""检查作业是否还在运行"""
try:
result = subprocess.run(
['squeue', '-j', job_id, '-h'],
capture_output=True,
text=True,
timeout=10
)
return bool(result.stdout.strip())
except Exception:
return False
def print_results_summary(self, results: List[ZeoTaskResult]):
"""打印结果摘要"""
total = len(results)
success = sum(1 for r in results if r.success)
failed = total - success
print("\n" + "=" * 60)
print("【计算结果摘要】")
print("=" * 60)
print(f" 总任务数: {total}")
print(f" 成功: {success} ({100*success/total:.1f}%)")
print(f" 失败: {failed} ({100*failed/total:.1f}%)")
if failed > 0 and failed <= 10:
print("\n 失败的任务:")
for r in results:
if not r.success:
print(f" - {r.structure_name}: {r.error_message}")
elif failed > 10:
print(f"\n 失败任务过多,请检查日志文件")
print("=" * 60)

src/core/__init__.py Normal file

@@ -0,0 +1,18 @@
"""
核心模块:调度器、执行器和进度管理
"""
from .scheduler import ParallelScheduler, ResourceConfig, ExecutionMode as SchedulerMode
from .executor import TaskExecutor, ExecutorConfig, ExecutionMode, TaskResult, create_executor
from .progress import ProgressManager
__all__ = [
'ParallelScheduler',
'ResourceConfig',
'SchedulerMode',
'TaskExecutor',
'ExecutorConfig',
'ExecutionMode',
'TaskResult',
'create_executor',
'ProgressManager',
]

src/core/controller.py Normal file

src/core/executor.py Normal file

@@ -0,0 +1,431 @@
"""
任务执行器:支持本地执行和 SLURM 直接提交
不生成脚本文件,直接在 Python 中管理任务
"""
import os
import subprocess
import time
import json
from typing import List, Callable, Any, Optional, Dict, Tuple
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL = "local" # 本地多进程
SLURM_DIRECT = "slurm" # SLURM 直接提交(不生成脚本)
@dataclass
class ExecutorConfig:
"""执行器配置"""
mode: ExecutionMode = ExecutionMode.LOCAL
max_workers: int = 4
conda_env: str = "/cluster/home/koko125/anaconda3/envs/screen"
partition: str = "cpu"
time_limit: str = "7-00:00:00"
memory_per_task: str = "4G"
# SLURM 相关
poll_interval: float = 2.0 # 轮询间隔(秒)
max_concurrent_jobs: int = 50 # 最大并发作业数
@dataclass
class TaskResult:
"""任务结果"""
task_id: Any
success: bool
result: Any = None
error: str = None
duration: float = 0.0
class TaskExecutor:
"""
任务执行器
支持两种模式:
1. LOCAL: 本地多进程执行
2. SLURM_DIRECT: 直接提交 SLURM 作业,实时监控进度
"""
def __init__(self, config: ExecutorConfig = None):
self.config = config or ExecutorConfig()
self.progress_manager = None
self._stop_event = threading.Event()
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""检测运行环境"""
env_info = {
'hostname': os.uname().nodename,
'total_cores': cpu_count(),
'has_slurm': False,
'slurm_partitions': [],
'conda_env': os.environ.get('CONDA_PREFIX', ''),
}
# 检测 SLURM
try:
result = subprocess.run(
['sinfo', '-h', '-o', '%P %a %c %D'],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
env_info['has_slurm'] = True
lines = result.stdout.strip().split('\n')
for line in lines:
parts = line.split()
if len(parts) >= 4:
partition = parts[0].rstrip('*')
avail = parts[1]
if avail == 'up':
env_info['slurm_partitions'].append(partition)
except Exception:
pass
return env_info
def run(
self,
tasks: List[Any],
worker_func: Callable,
desc: str = "Processing"
) -> List[TaskResult]:
"""
执行任务
Args:
tasks: 任务列表
worker_func: 工作函数,接收单个任务,返回结果
desc: 进度条描述
Returns:
TaskResult 列表
"""
if self.config.mode == ExecutionMode.LOCAL:
return self._run_local(tasks, worker_func, desc)
elif self.config.mode == ExecutionMode.SLURM_DIRECT:
return self._run_slurm_direct(tasks, worker_func, desc)
else:
raise ValueError(f"不支持的执行模式: {self.config.mode}")
def _run_local(
self,
tasks: List[Any],
worker_func: Callable,
desc: str
) -> List[TaskResult]:
"""本地多进程执行"""
total = len(tasks)
num_workers = min(self.config.max_workers, total)
print(f"\n{'='*60}")
print(f"本地执行配置:")
print(f" 总任务数: {total}")
print(f" Worker数: {num_workers}")
print(f"{'='*60}\n")
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
# 单进程执行
for i, task in enumerate(tasks):
start_time = time.time()
try:
result = worker_func(task)
duration = time.time() - start_time
results.append(TaskResult(
task_id=i,
success=True,
result=result,
duration=duration
))
self.progress_manager.update(success=True)
except Exception as e:
duration = time.time() - start_time
results.append(TaskResult(
task_id=i,
success=False,
error=str(e),
duration=duration
))
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
# 多进程执行注意imap_unordered 按完成顺序返回,因此这里的 i 是完成序号,不保证与输入任务的下标一一对应
with Pool(processes=num_workers) as pool:
for i, result in enumerate(pool.imap_unordered(worker_func, tasks)):
if result is not None:
results.append(TaskResult(
task_id=i,
success=True,
result=result
))
self.progress_manager.update(success=True)
else:
results.append(TaskResult(
task_id=i,
success=False,
error="Worker returned None"
))
self.progress_manager.update(success=False)
self.progress_manager.display()
self.progress_manager.finish()
return results
def _run_slurm_direct(
self,
tasks: List[Any],
worker_func: Callable,
desc: str
) -> List[TaskResult]:
"""
SLURM 直接提交模式
注意:对于数据库分析这类快速任务,建议使用本地多进程模式
SLURM 模式更适合耗时的计算任务(如 Zeo++ 分析)
这里回退到本地模式,因为 srun 在登录节点直接调用效率不高
"""
print("\n⚠️ 注意:数据库分析阶段自动使用本地多进程模式")
print(" SLURM 模式将在后续耗时计算步骤中使用")
# 回退到本地模式
return self._run_local(tasks, worker_func, desc)
class SlurmJobManager:
"""
SLURM 作业管理器
用于批量提交和监控 SLURM 作业
"""
def __init__(self, config: ExecutorConfig):
self.config = config
self.active_jobs = {} # job_id -> task_info
def submit_batch(
self,
tasks: List[Tuple[str, str, set]], # (file_path, target_cation, target_anions)
output_dir: str,
desc: str = "Processing"
) -> List[TaskResult]:
"""
批量提交任务到 SLURM
使用 sbatch --wrap 直接提交,不生成脚本文件
"""
total = len(tasks)
os.makedirs(output_dir, exist_ok=True)
print(f"\n{'='*60}")
print(f"SLURM 批量提交:")
print(f" 总任务数: {total}")
print(f" 输出目录: {output_dir}")
print(f" Conda环境: {self.config.conda_env}")
print(f"{'='*60}\n")
progress = ProgressManager(total, desc)
progress.start()
results = []
job_ids = []
# 提交所有任务
for i, task in enumerate(tasks):
file_path, target_cation, target_anions = task
# 构建 Python 命令
anions_str = ','.join(target_anions)
python_cmd = (
f"python -c \""
f"import sys; sys.path.insert(0, '{os.getcwd()}'); "
f"from src.analysis.worker import analyze_single_file; "
f"result = analyze_single_file(('{file_path}', '{target_cation}', set('{anions_str}'.split(',')))); "
f"print('SUCCESS' if result and result.is_valid else 'FAILED')"
f"\""
)
# 构建完整的 bash 命令
bash_cmd = (
f"source {os.path.dirname(self.config.conda_env)}/../../etc/profile.d/conda.sh && "
f"conda activate {self.config.conda_env} && "
f"{python_cmd}"
)
# 使用 sbatch --wrap 提交
sbatch_cmd = [
'sbatch',
'--partition', self.config.partition,
'--ntasks', '1',
'--cpus-per-task', '1',
'--mem', self.config.memory_per_task,
'--time', '01:00:00',
'--output', os.path.join(output_dir, f'task_{i}.out'),
'--error', os.path.join(output_dir, f'task_{i}.err'),
'--wrap', bash_cmd
]
try:
result = subprocess.run(
sbatch_cmd,
capture_output=True,
text=True
)
if result.returncode == 0:
# 提取 job_id
job_id = result.stdout.strip().split()[-1]
job_ids.append((i, job_id, file_path))
self.active_jobs[job_id] = {
'task_index': i,
'file_path': file_path,
'status': 'PENDING'
}
else:
results.append(TaskResult(
task_id=i,
success=False,
error=f"提交失败: {result.stderr}"
))
progress.update(success=False)
progress.display()
except Exception as e:
results.append(TaskResult(
task_id=i,
success=False,
error=str(e)
))
progress.update(success=False)
progress.display()
print(f"\n已提交 {len(job_ids)} 个作业,等待完成...")
# 监控作业状态
while self.active_jobs:
time.sleep(self.config.poll_interval)
# 检查作业状态
completed_jobs = self._check_job_status()
for job_id, status in completed_jobs:
job_info = self.active_jobs.pop(job_id, None)
if job_info:
task_idx = job_info['task_index']
if status == 'COMPLETED':
# 检查输出文件
out_file = os.path.join(output_dir, f'task_{task_idx}.out')
success = False
if os.path.exists(out_file):
with open(out_file, 'r') as f:
content = f.read()
success = 'SUCCESS' in content
results.append(TaskResult(
task_id=task_idx,
success=success,
result=job_info['file_path']
))
progress.update(success=success)
else:
# 作业失败
err_file = os.path.join(output_dir, f'task_{task_idx}.err')
error_msg = status
if os.path.exists(err_file):
with open(err_file, 'r') as f:
error_msg = f.read()[:500] # 只取前500字符
results.append(TaskResult(
task_id=task_idx,
success=False,
error=error_msg
))
progress.update(success=False)
progress.display()
progress.finish()
return results
def _check_job_status(self) -> List[Tuple[str, str]]:
"""检查作业状态,返回已完成的作业列表"""
if not self.active_jobs:
return []
job_ids = list(self.active_jobs.keys())
try:
result = subprocess.run(
['sacct', '-j', ','.join(job_ids), '--format=JobID,State', '--noheader', '--parsable2'],
capture_output=True,
text=True,
timeout=30
)
completed = []
if result.returncode == 0:
for line in result.stdout.strip().split('\n'):
if line:
parts = line.split('|')
if len(parts) >= 2:
job_id = parts[0].split('.')[0] # 去掉 .batch 后缀
status = parts[1]
if job_id in self.active_jobs:
if status in ['COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL']:
completed.append((job_id, status))
return completed
except Exception:
return []
def create_executor(
mode: str = "local",
max_workers: int = None,
conda_env: str = None,
**kwargs
) -> TaskExecutor:
"""
创建任务执行器的便捷函数
Args:
mode: "local""slurm"
max_workers: 最大工作进程数
conda_env: Conda 环境路径
**kwargs: 其他配置参数
"""
env = TaskExecutor.detect_environment()
if max_workers is None:
max_workers = min(env['total_cores'], 32)
if conda_env is None:
conda_env = env.get('conda_env') or "/cluster/home/koko125/anaconda3/envs/screen"
exec_mode = ExecutionMode.SLURM_DIRECT if mode.lower() == "slurm" else ExecutionMode.LOCAL
config = ExecutorConfig(
mode=exec_mode,
max_workers=max_workers,
conda_env=conda_env,
**kwargs
)
return TaskExecutor(config)
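A minimal local-mode sketch for the convenience function above (import path assumed); the worker must be a picklable top-level function because the local path uses multiprocessing.Pool.

from src.core import create_executor  # package path assumed

def square(x: int) -> int:
    # stand-in worker; a real one would call the analysis code
    return x * x

if __name__ == "__main__":
    executor = create_executor(mode="local", max_workers=4)
    results = executor.run(list(range(100)), square, desc="demo")
    values = [r.result for r in results if r.success]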

src/core/progress.py Normal file

@@ -0,0 +1,115 @@
"""
进度管理器:支持多进程的实时进度显示
"""
import os
import sys
import time
from multiprocessing import Manager, Value
from typing import Optional
from datetime import datetime, timedelta
class ProgressManager:
"""多进程安全的进度管理器"""
def __init__(self, total: int, desc: str = "Processing"):
self.total = total
self.desc = desc
self.manager = Manager()
self.completed = self.manager.Value('i', 0)
self.failed = self.manager.Value('i', 0)
self.start_time = None
self._lock = self.manager.Lock()
def start(self):
"""开始计时"""
self.start_time = time.time()
def update(self, success: bool = True):
"""更新进度(进程安全)"""
with self._lock:
if success:
self.completed.value += 1
else:
self.failed.value += 1
def get_progress(self) -> dict:
"""获取当前进度"""
completed = self.completed.value
failed = self.failed.value
total_done = completed + failed
elapsed = time.time() - self.start_time if self.start_time else 0
if total_done > 0:
speed = total_done / elapsed # items/sec
remaining = (self.total - total_done) / speed if speed > 0 else 0
else:
speed = 0
remaining = 0
return {
'total': self.total,
'completed': completed,
'failed': failed,
'total_done': total_done,
'percent': total_done / self.total * 100 if self.total > 0 else 0,
'elapsed': elapsed,
'remaining': remaining,
'speed': speed
}
def display(self):
"""显示进度条"""
p = self.get_progress()
# 进度条
bar_width = 40
filled = int(bar_width * p['total_done'] / p['total']) if p['total'] > 0 else 0
bar = '█' * filled + '░' * (bar_width - filled)
# 时间格式化
elapsed_str = str(timedelta(seconds=int(p['elapsed'])))
remaining_str = str(timedelta(seconds=int(p['remaining'])))
# 构建显示字符串
status = (
f"\r{self.desc}: |{bar}| "
f"{p['total_done']}/{p['total']} ({p['percent']:.1f}%) "
f"[{elapsed_str}<{remaining_str}, {p['speed']:.1f}it/s] "
f"{p['completed']}{p['failed']}"
)
sys.stdout.write(status)
sys.stdout.flush()
def finish(self):
"""完成显示"""
self.display()
print() # 换行
class ProgressReporter:
"""进度报告器(用于后台监控)"""
def __init__(self, progress_file: str = ".progress"):
self.progress_file = progress_file
def save(self, progress: dict):
"""保存进度到文件"""
import json
with open(self.progress_file, 'w') as f:
json.dump(progress, f)
def load(self) -> Optional[dict]:
"""从文件加载进度"""
import json
if os.path.exists(self.progress_file):
with open(self.progress_file, 'r') as f:
return json.load(f)
return None
def cleanup(self):
"""清理进度文件"""
if os.path.exists(self.progress_file):
os.remove(self.progress_file)
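A small sketch of how the progress manager above can wrap a multiprocessing pool (import path assumed); updates happen in the parent process as results arrive, which is why the shared counters and lock are created via multiprocessing.Manager.

from multiprocessing import Pool
from src.core.progress import ProgressManager  # import path assumed

def work(x):
    return x * x

if __name__ == "__main__":
    items = list(range(200))
    pm = ProgressManager(total=len(items), desc="demo")
    pm.start()
    with Pool(4) as pool:
        for _ in pool.imap_unordered(work, items):
            pm.update(success=True)
            pm.display()
    pm.finish()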

src/core/scheduler.py Normal file

@@ -0,0 +1,291 @@
"""
并行调度器支持本地多进程和SLURM集群调度
"""
import os
import subprocess
import tempfile
from typing import List, Callable, Any, Optional, Dict
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass
from enum import Enum
import math
from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL_SINGLE = "local_single"
LOCAL_MULTI = "local_multi"
SLURM_SINGLE = "slurm_single"
SLURM_ARRAY = "slurm_array"
@dataclass
class ResourceConfig:
"""资源配置"""
max_cores: int = 4
cores_per_worker: int = 1
memory_per_core: str = "4G"
partition: str = "cpu"
time_limit: str = "7-00:00:00"
conda_env_path: str = "~/anaconda3/envs/screen" # 新增Conda环境路径
@property
def num_workers(self) -> int:
return max(1, self.max_cores // self.cores_per_worker)
class ParallelScheduler:
"""并行调度器"""
COMPLEXITY_CORES = {
'low': 1,
'medium': 2,
'high': 4,
}
def __init__(self, resource_config: ResourceConfig = None):
self.config = resource_config or ResourceConfig()
self.progress_manager = None
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""检测运行环境"""
env_info = {
'hostname': os.uname().nodename,
'total_cores': cpu_count(),
'has_slurm': False,
'slurm_partitions': [],
'available_nodes': 0,
'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
}
# 检测SLURM
try:
result = subprocess.run(
['sinfo', '-h', '-o', '%P %a %c %D'],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
env_info['has_slurm'] = True
lines = result.stdout.strip().split('\n')
for line in lines:
parts = line.split()
if len(parts) >= 4:
partition = parts[0].rstrip('*')
avail = parts[1]
cores = int(parts[2])
nodes = int(parts[3])
if avail == 'up':
env_info['slurm_partitions'].append({
'name': partition,
'cores_per_node': cores,
'nodes': nodes
})
env_info['available_nodes'] += nodes
except Exception:
pass
return env_info
@staticmethod
def recommend_config(
num_tasks: int,
task_complexity: str = 'medium',
max_cores: int = None,
conda_env_path: str = None
) -> ResourceConfig:
"""根据任务量和复杂度推荐配置"""
env = ParallelScheduler.detect_environment()
if max_cores is None:
max_cores = min(env['total_cores'], 32)
cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
max_workers = max_cores // cores_per_worker
optimal_workers = min(num_tasks, max_workers)
actual_cores = optimal_workers * cores_per_worker
# 自动检测Conda环境路径
if conda_env_path is None:
conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen')
config = ResourceConfig(
max_cores=actual_cores,
cores_per_worker=cores_per_worker,
conda_env_path=conda_env_path,
)
return config
def run_local(
self,
tasks: List[Any],
worker_func: Callable,
desc: str = "Processing"
) -> List[Any]:
"""本地多进程执行"""
num_workers = self.config.num_workers
total = len(tasks)
print(f"\n{'='*60}")
print(f"并行配置:")
print(f" 总任务数: {total}")
print(f" Worker数: {num_workers}")
print(f" 每Worker核数: {self.config.cores_per_worker}")
print(f" 总使用核数: {self.config.max_cores}")
print(f"{'='*60}\n")
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
for task in tasks:
try:
result = worker_func(task)
results.append(result)
self.progress_manager.update(success=True)
except Exception as e:
results.append(None)
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
with Pool(processes=num_workers) as pool:
for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
if result is not None:
results.append(result)
self.progress_manager.update(success=True)
else:
self.progress_manager.update(success=False)
self.progress_manager.display()
self.progress_manager.finish()
return results
def generate_slurm_script(
self,
tasks_file: str,
worker_script: str,
output_dir: str,
job_name: str = "analysis"
) -> str:
"""生成SLURM作业脚本修复版"""
# 获取当前工作目录的绝对路径
submit_dir = os.getcwd()
# 转换为绝对路径
abs_tasks_file = os.path.abspath(tasks_file)
abs_worker_script = os.path.abspath(worker_script)
abs_output_dir = os.path.abspath(output_dir)
script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err
# ============================================
# SLURM作业脚本 - 自动生成
# ============================================
echo "===== 作业信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "节点: $SLURM_NODELIST"
echo "CPU数: $SLURM_CPUS_PER_TASK"
echo "开始时间: $(date)"
echo "===================="
# ============ 环境初始化 ============
# 关键确保bashrc被加载
if [ -f ~/.bashrc ]; then
source ~/.bashrc
fi
# 初始化Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
source /opt/anaconda3/etc/profile.d/conda.sh
else
echo "错误: 找不到conda.sh"
exit 1
fi
# 激活环境 (使用完整路径)
conda activate {self.config.conda_env_path}
# 验证环境
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "Python版本: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
echo "===================="
echo ""
# 设置工作目录
cd {submit_dir}
export PYTHONPATH={submit_dir}:$PYTHONPATH
echo "工作目录: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""
# ============ 运行分析 ============
echo "开始执行分析任务..."
python {abs_worker_script} \\
--tasks-file {abs_tasks_file} \\
--output-dir {abs_output_dir} \\
--num-workers {self.config.num_workers}
EXIT_CODE=$?
# ============ 完成 ============
echo ""
echo "===== 作业完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
echo "===================="
exit $EXIT_CODE
"""
return script
def submit_slurm_job(self, script_content: str, script_path: str = None) -> str:
"""提交SLURM作业"""
if script_path is None:
fd, script_path = tempfile.mkstemp(suffix='.sh')
os.close(fd)
with open(script_path, 'w') as f:
f.write(script_content)
os.chmod(script_path, 0o755)
# 打印脚本内容用于调试
print(f"\n生成的SLURM脚本保存到: {script_path}")
result = subprocess.run(
['sbatch', script_path],
capture_output=True, text=True
)
if result.returncode == 0:
job_id = result.stdout.strip().split()[-1]
return job_id
else:
raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")


@@ -0,0 +1,562 @@
"""
结构预处理器:扩胞和添加化合价
"""
import os
import re
import yaml
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import Specie
from pymatgen.core import Lattice, Species, PeriodicSite
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import random
import spglib
import numpy as np
@dataclass
class ProcessingResult:
"""处理结果"""
input_file: str
output_files: List[str] = field(default_factory=list)
success: bool = False
needs_expansion: bool = False
expansion_factor: int = 1
error_message: str = ""
class StructureProcessor:
"""结构预处理器"""
# 默认化合价配置
DEFAULT_VALENCE_PATH = os.path.join(
os.path.dirname(__file__), '..', '..', 'tool', 'valence_states.yaml'
)
def __init__(
self,
valence_yaml_path: str = None,
calculate_type: str = 'low',
max_expansion_factor: int = 64,
keep_number: int = 3,
target_cation: str = "Li"
):
"""
初始化处理器
Args:
valence_yaml_path: 化合价配置文件路径
calculate_type: 扩胞计算精度 ('high', 'normal', 'low', 'very_low')
max_expansion_factor: 最大扩胞因子
keep_number: 保留的扩胞结构数量
target_cation: 目标阳离子
"""
self.valence_yaml_path = valence_yaml_path or self.DEFAULT_VALENCE_PATH
self.calculate_type = calculate_type
self.max_expansion_factor = max_expansion_factor
self.keep_number = keep_number
self.target_cation = target_cation
self.explict_element = [target_cation, f"{target_cation}+"]
# 加载化合价配置
self.valences = self._load_valences()
def _load_valences(self) -> Dict[str, int]:
"""加载化合价配置"""
if os.path.exists(self.valence_yaml_path):
with open(self.valence_yaml_path, 'r') as f:
return yaml.safe_load(f)
return {}
def process_file(
self,
input_path: str,
output_dir: str,
needs_expansion: bool = False
) -> ProcessingResult:
"""
处理单个CIF文件
Args:
input_path: 输入文件路径
output_dir: 输出目录
needs_expansion: 是否需要扩胞
Returns:
ProcessingResult: 处理结果
"""
result = ProcessingResult(input_file=input_path)
try:
# 读取结构
structure = Structure.from_file(input_path)
base_name = os.path.splitext(os.path.basename(input_path))[0]
# 检查是否需要扩胞
occupation_list = self._process_cif_file(structure)
if occupation_list and needs_expansion:
# 需要扩胞处理
result.needs_expansion = True
output_files = self._expand_and_save(
structure, occupation_list, base_name, output_dir
)
result.output_files = output_files
result.expansion_factor = occupation_list[0].get('denominator', 1) if occupation_list else 1
else:
# 不需要扩胞,直接添加化合价
output_path = os.path.join(output_dir, f"{base_name}.cif")
self._add_oxidation_states(structure)
structure.to(filename=output_path)
result.output_files = [output_path]
result.success = True
except Exception as e:
result.success = False
result.error_message = str(e)
return result
def _process_cif_file(self, structure: Structure) -> List[Dict]:
"""
统计结构中各原子的occupation情况
"""
occupation_dict = defaultdict(list)
split_dict = {}
for i, site in enumerate(structure):
occu = self._get_occu(site.species_string)
if occu != 1.0:
if site.species.chemical_system not in self.explict_element:
occupation_dict[occu].append(i + 1)
# 提取元素名称列表
elements = []
if ':' in site.species_string:
parts = site.species_string.split(',')
for part in parts:
element_with_valence = part.strip().split(':')[0].strip()
element_match = re.match(r'([A-Z][a-z]?)', element_with_valence)
if element_match:
elements.append(element_match.group(1))
else:
element_match = re.match(r'([A-Z][a-z]?)', site.species_string)
if element_match:
elements = [element_match.group(1)]
split_dict[occu] = elements
# 转换为列表格式
occupation_list = [
{
"occupation": occu,
"atom_serial": serials,
"numerator": None,
"denominator": None,
"split": split_dict.get(occu, [])
}
for occu, serials in occupation_dict.items()
]
return occupation_list
def _get_occu(self, s_str: str) -> float:
"""从物种字符串获取占据率"""
if not s_str.strip():
return 1.0
pattern = r'([A-Za-z0-9+-]+):([0-9.]+)'
matches = re.findall(pattern, s_str)
for species, occu in matches:
if species not in self.explict_element:
try:
return float(occu)
except ValueError:
continue
return 1.0
def _calculate_expansion_factor(self, occupation_list: List[Dict]) -> Tuple[int, List[Dict]]:
"""计算扩胞因子"""
if not occupation_list:
return 1, []
precision_limits = {
'high': None,
'normal': 100,
'low': 10,
'very_low': 5
}
limit = precision_limits.get(self.calculate_type)
for entry in occupation_list:
occu = entry["occupation"]
if limit:
fraction = Fraction(occu).limit_denominator(limit)
else:
fraction = Fraction(occu).limit_denominator()
entry["numerator"] = fraction.numerator
entry["denominator"] = fraction.denominator
# 计算最小公倍数
denominators = [entry["denominator"] for entry in occupation_list]
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
# 统一分母
for entry in occupation_list:
denominator = entry["denominator"]
entry["numerator"] = entry["numerator"] * (lcm // denominator)
entry["denominator"] = lcm
return lcm, occupation_list
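# Worked example with illustrative numbers: occupancies 0.5 and 0.75 under calculate_type='low'
# become Fraction(1, 2) and Fraction(3, 4); the least common multiple of the denominators is 4,
# so the entries are rescaled to 2/4 and 3/4 and the cell is expanded 4x, with 2 of the 4 copies
# keeping the first atom occupied and 3 of the 4 keeping the second.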
def _expand_and_save(
self,
structure: Structure,
occupation_list: List[Dict],
base_name: str,
output_dir: str
) -> List[str]:
"""扩胞并保存"""
lcm, oc_list = self._calculate_expansion_factor(occupation_list)
if lcm > self.max_expansion_factor:
raise ValueError(f"扩胞因子 {lcm} 超过最大限制 {self.max_expansion_factor}")
# 获取扩胞策略
strategies = self._strategy_divide(structure, lcm)
if not strategies:
raise ValueError("无法找到合适的扩胞策略")
# 生成结构列表
st_list = self._generate_structure_list(structure, oc_list)
output_files = []
keep_number = min(self.keep_number, len(strategies))
for index in range(keep_number):
merged = self._merge_structures(st_list, strategies[index])
# 添加化合价
self._add_oxidation_states(merged)
# 当只保存1个时不加后缀
if keep_number == 1:
output_filename = f"{base_name}.cif"
else:
suffix = "x{}y{}z{}".format(
strategies[index]["x"],
strategies[index]["y"],
strategies[index]["z"]
)
output_filename = f"{base_name}-{suffix}.cif"
output_path = os.path.join(output_dir, output_filename)
merged.to(filename=output_path, fmt="cif")
output_files.append(output_path)
return output_files
def _add_oxidation_states(self, structure: Structure):
"""添加化合价"""
# 检查是否已有化合价
has_oxidation = all(
all(isinstance(sp, Specie) for sp in site.species.keys())
for site in structure.sites
)
if not has_oxidation and self.valences:
structure.add_oxidation_state_by_element(self.valences)
def _strategy_divide(self, structure: Structure, total: int) -> List[Dict]:
"""根据晶体类型确定扩胞策略"""
try:
space_group_info = structure.get_space_group_info()
space_group_symbol = space_group_info[0]
# 获取空间群类型
all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)]
symbol = all_spacegroup_symbols[0]
for symbol_i in all_spacegroup_symbols:
if space_group_symbol == symbol_i.international_short:
symbol = symbol_i
break
space_type = self._typejudge(symbol.number)
if space_type == "Cubic":
return self._factorize_to_three_factors(total, "xyz")
else:
return self._factorize_to_three_factors(total)
except Exception:
return self._factorize_to_three_factors(total)
def _typejudge(self, number: int) -> str:
"""判断晶体类型"""
if number in [1, 2]:
return "Triclinic"
elif 3 <= number <= 15:
return "Monoclinic"
elif 16 <= number <= 74:
return "Orthorhombic"
elif 75 <= number <= 142:
return "Tetragonal"
elif 143 <= number <= 167:
return "Trigonal"
elif 168 <= number <= 194:
return "Hexagonal"
elif 195 <= number <= 230:
return "Cubic"
else:
return "Unknown"
def _factorize_to_three_factors(self, n: int, type_sym: str = None) -> List[Dict]:
"""分解为三个因子"""
factors = []
if type_sym == "xyz":
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
for y in range(1, remaining_n + 1):
if remaining_n % y == 0 and y <= x:
z = remaining_n // y
if z <= y:
factors.append({'x': x, 'y': y, 'z': z})
else:
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
for y in range(1, remaining_n + 1):
if remaining_n % y == 0:
z = remaining_n // y
factors.append({'x': x, 'y': y, 'z': z})
# 排序
def sort_key(item):
return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x'])
return sorted(factors, key=sort_key)
def _generate_structure_list(
self,
base_structure: Structure,
occupation_list: List[Dict]
) -> List[Structure]:
"""生成结构列表"""
if not occupation_list:
return [base_structure.copy()]
lcm = occupation_list[0]["denominator"]
structure_list = [base_structure.copy() for _ in range(lcm)]
for entry in occupation_list:
numerator = entry["numerator"]
denominator = entry["denominator"]
atom_indices = entry["atom_serial"]
for atom_idx in atom_indices:
occupancy_dict = self._mark_atoms_randomly(numerator, denominator)
original_site = base_structure.sites[atom_idx - 1]
element = self._get_first_non_explicit_element(original_site.species_string)
for copy_idx, occupy in occupancy_dict.items():
structure_list[copy_idx].remove_sites([atom_idx - 1])
oxi_state = self._extract_oxi_state(original_site.species_string, element)
if len(entry["split"]) == 1:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
species_dict = {Species(self.target_cation, 1.0): 0.0}
new_site = PeriodicSite(
species=species_dict,
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
new_site = PeriodicSite(
species=Species(entry['split'][1], oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
return structure_list
def _mark_atoms_randomly(self, numerator: int, denominator: int) -> Dict[int, int]:
"""随机标记原子"""
if numerator > denominator:
raise ValueError(f"numerator ({numerator}) 不能超过 denominator ({denominator})")
atom_dice = list(range(denominator))
selected_atoms = random.sample(atom_dice, numerator)
return {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice}
def _get_first_non_explicit_element(self, species_str: str) -> str:
"""获取第一个非目标元素"""
if not species_str.strip():
return ""
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
element_with_charge = part.split(":")[0].strip()
pure_element = ''.join([c for c in element_with_charge if c.isalpha()])
if pure_element not in self.explict_element:
return pure_element
return ""
def _extract_oxi_state(self, species_str: str, element: str) -> int:
"""提取氧化态"""
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
element_with_charge = part.split(":")[0].strip()
if element in element_with_charge:
charge_part = element_with_charge[len(element):]
if not any(c.isdigit() for c in charge_part):
if "+" in charge_part:
return 1
elif "-" in charge_part:
return -1
else:
return 0
sign = 1
if "-" in charge_part:
sign = -1
digits = ""
for c in charge_part:
if c.isdigit():
digits += c
if digits:
return sign * int(digits)
return 0
def _merge_structures(self, structure_list: List[Structure], merge_dict: Dict) -> Structure:
"""合并结构"""
if not structure_list:
raise ValueError("结构列表不能为空")
ref_lattice = structure_list[0].lattice
total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1)
if len(structure_list) != total_merge:
raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配")
a, b, c = ref_lattice.abc
alpha, beta, gamma = ref_lattice.angles
new_a = a * merge_dict.get("x", 1)
new_b = b * merge_dict.get("y", 1)
new_c = c * merge_dict.get("z", 1)
new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma)
all_sites = []
for i, structure in enumerate(structure_list):
x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1)
y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1)
z_offset = i % merge_dict.get("z", 1)
for site in structure:
coords = site.frac_coords.copy()
coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1)
coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1)
coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1)
all_sites.append({"species": site.species, "coords": coords})
return Structure(
new_lattice,
[site["species"] for site in all_sites],
[site["coords"] for site in all_sites]
)
def process_batch(
input_files: List[str],
output_dir: str,
needs_expansion_flags: List[bool] = None,
valence_yaml_path: str = None,
calculate_type: str = 'low',
target_cation: str = "Li",
show_progress: bool = True
) -> List[ProcessingResult]:
"""
批量处理CIF文件
Args:
input_files: 输入文件列表
output_dir: 输出目录
needs_expansion_flags: 是否需要扩胞的标记列表
valence_yaml_path: 化合价配置文件路径
calculate_type: 扩胞计算精度
target_cation: 目标阳离子
show_progress: 是否显示进度
Returns:
处理结果列表
"""
os.makedirs(output_dir, exist_ok=True)
processor = StructureProcessor(
valence_yaml_path=valence_yaml_path,
calculate_type=calculate_type,
target_cation=target_cation
)
if needs_expansion_flags is None:
needs_expansion_flags = [False] * len(input_files)
results = []
total = len(input_files)
for i, (input_file, needs_exp) in enumerate(zip(input_files, needs_expansion_flags)):
if show_progress:
print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
result = processor.process_file(input_file, output_dir, needs_exp)
results.append(result)
if show_progress:
print()
return results
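A usage sketch for the batch helper above; the module path src.preprocessing.processor, the input file names, and the expansion flags are assumptions made only for illustration.

from src.preprocessing.processor import process_batch  # module path assumed

cifs = ["screened/ICSD_0001.cif", "screened/ICSD_0002.cif"]   # illustrative inputs
flags = [True, False]                                          # which entries need supercell expansion
results = process_batch(
    input_files=cifs,
    output_dir="workspace/preprocessed",
    needs_expansion_flags=flags,
    calculate_type="low",
    target_cation="Li",
)
failed = [r for r in results if not r.success]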


src/utils/__init__.py Normal file

src/utils/io.py Normal file

src/utils/logger.py Normal file

src/utils/structure.py Normal file

tool/Li/Br+O.yaml Normal file

@@ -0,0 +1,5 @@
SPECIE: Li+
ANION: Br
PERCO_R: 0.45
NEIGHBOR: 1.8
LONG: 2.2

tool/Li/Cl+O.yaml Normal file

@@ -0,0 +1,5 @@
SPECIE: Li+
ANION: Cl
PERCO_R: 0.45
NEIGHBOR: 1.8
LONG: 2.2