diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5f987aa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,21 @@
+# --- Python 通用忽略 ---
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.Python
+env/
+venv/
+.env
+.venv/
+
+# --- VS Code 配置 (可选,建议忽略) ---
+.vscode/
+
+# --- JetBrains (你之前的配置) ---
+.idea/
+/shelf/
+/workspace.xml
+/httpRequests/
+/dataSources/
+/dataSources.local.xml
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 35410ca..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# 默认忽略的文件
-/shelf/
-/workspace.xml
-# 基于编辑器的 HTTP 客户端请求
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml
diff --git a/.idea/Screen.iml b/.idea/Screen.iml
deleted file mode 100644
index 909438d..0000000
--- a/.idea/Screen.iml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index bebfaed..0000000
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ /dev/null
@@ -1,23 +0,0 @@
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 287dd8e..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 94a25f7..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/main.py b/main.py
index 4fce51d..9379b05 100644
--- a/main.py
+++ b/main.py
@@ -1,45 +1,188 @@
"""
-高通量筛选与扩胞项目 - 主入口(支持并行)
+高通量筛选与扩胞项目 - 主入口(支持断点续做)
"""
import os
import sys
+import json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
-from src.core.scheduler import ParallelScheduler
+from src.core.executor import TaskExecutor
+from src.preprocessing.processor import StructureProcessor
+from src.computation.workspace_manager import WorkspaceManager
+from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
+from src.computation.result_processor import ResultProcessor, FilterCriteria
def print_banner():
print("""
╔═══════════════════════════════════════════════════════════════════╗
-║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.0 ║
-║ 支持高性能并行计算 ║
+║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.2 ║
+║ 支持断点续做与高性能并行计算 ║
╚═══════════════════════════════════════════════════════════════════╝
""")
def detect_and_show_environment():
"""检测并显示环境信息"""
- env = ParallelScheduler.detect_environment()
+ env = TaskExecutor.detect_environment()
print("【运行环境检测】")
print(f" 主机名: {env['hostname']}")
print(f" 本地CPU核数: {env['total_cores']}")
print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}")
-
+
if env['has_slurm'] and env['slurm_partitions']:
- print(f" 可用分区:")
- for p in env['slurm_partitions']:
- print(f" - {p['name']}: {p['nodes']}节点, {p['cores_per_node']}核/节点")
+ print(f" 可用分区: {', '.join(env['slurm_partitions'])}")
+
+ if env['conda_env']:
+ print(f" 当前Conda: {env['conda_env']}")
return env
-def get_user_input(env: dict):
- """获取用户输入"""
+def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict:
+ """
+ 检测工作流程状态,确定可以从哪一步继续
+
+ Returns:
+ dict: {
+ 'has_processed_data': bool, # 是否有扩胞处理后的数据
+ 'has_zeo_results': bool, # 是否有 Zeo++ 计算结果
+ 'total_structures': int, # 总结构数
+ 'structures_with_log': int, # 有 log.txt 的结构数
+ 'workspace_info': object, # 工作区信息
+ 'ws_manager': object # 工作区管理器
+ }
+ """
+ status = {
+ 'has_processed_data': False,
+ 'has_zeo_results': False,
+ 'total_structures': 0,
+ 'structures_with_log': 0,
+ 'workspace_info': None,
+ 'ws_manager': None
+ }
+
+ # 创建工作区管理器
+ ws_manager = WorkspaceManager(
+ workspace_path=workspace_path,
+ tool_dir="tool",
+ target_cation=target_cation
+ )
+ status['ws_manager'] = ws_manager
+
+ # 检查是否有处理后的数据
+ data_dir = os.path.join(workspace_path, "data")
+ if os.path.exists(data_dir):
+ existing = ws_manager.check_existing_workspace()
+ if existing and existing.total_structures > 0:
+ status['has_processed_data'] = True
+ status['total_structures'] = existing.total_structures
+ status['workspace_info'] = existing
+
+ # 检查有多少结构有 log.txt(即已完成 Zeo++ 计算)
+ log_count = 0
+ for anion_key in os.listdir(data_dir):
+ anion_dir = os.path.join(data_dir, anion_key)
+ if not os.path.isdir(anion_dir):
+ continue
+ for struct_name in os.listdir(anion_dir):
+ struct_dir = os.path.join(anion_dir, struct_name)
+ if os.path.isdir(struct_dir):
+ log_path = os.path.join(struct_dir, "log.txt")
+ if os.path.exists(log_path):
+ log_count += 1
+
+ status['structures_with_log'] = log_count
+
+ # 如果大部分结构都有 log.txt,认为 Zeo++ 计算已完成
+ if log_count > 0 and log_count >= status['total_structures'] * 0.5:
+ status['has_zeo_results'] = True
+
+ return status
+
+def print_workflow_status(status: dict):
+ """打印工作流程状态"""
+ print("\n" + "─" * 50)
+ print("【工作流程状态检测】")
+ print("─" * 50)
+
+ if not status['has_processed_data']:
+ print(" Step 1 (扩胞+化合价): ❌ 未完成")
+ print(" Step 2-4 (Zeo++ 计算): ❌ 未完成")
+ print(" Step 5 (结果筛选): ❌ 未完成")
+ else:
+ print(f" Step 1 (扩胞+化合价): ✅ 已完成 ({status['total_structures']} 个结构)")
+
+ if status['has_zeo_results']:
+ print(f" Step 2-4 (Zeo++ 计算): ✅ 已完成 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
+ print(" Step 5 (结果筛选): ⏳ 可执行")
+ else:
+ print(f" Step 2-4 (Zeo++ 计算): ⏳ 可执行 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
+ print(" Step 5 (结果筛选): ❌ 需先完成 Zeo++ 计算")
+
+ print("─" * 50)
+
+
+def get_user_choice(status: dict) -> str:
+ """
+ 根据工作流程状态获取用户选择
+
+ Returns:
+ 'step1': 从头开始(数据库分析 + 扩胞)
+ 'step2': 从 Zeo++ 计算开始
+ 'step5': 直接进行结果筛选
+ 'exit': 退出
+ """
+ print("\n请选择操作:")
+
+ options = []
+
+ if status['has_zeo_results']:
+ options.append(('5', '直接进行结果筛选 (Step 5)'))
+ options.append(('2', '重新运行 Zeo++ 计算 (Step 2-4)'))
+ options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
+ elif status['has_processed_data']:
+ options.append(('2', '运行 Zeo++ 计算 (Step 2-4)'))
+ options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
+ else:
+ options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
+
+ options.append(('q', '退出'))
+
+ for key, desc in options:
+ print(f" {key}. {desc}")
+
+ choice = input("\n请选择 [默认: " + options[0][0] + "]: ").strip().lower()
+
+    # 非法输入或未列出的选项回退到默认选项
+    valid_keys = {key for key, _ in options}
+    if not choice or choice not in valid_keys:
+        choice = options[0][0]
+
+    if choice == 'q':
+        return 'exit'
+    elif choice == '5':
+        return 'step5'
+    elif choice == '2':
+        return 'step2'
+    else:
+        return 'step1'
+
+
+def run_step1_database_analysis(env: dict, cation: str) -> dict:
+ """
+ Step 1: 数据库分析与扩胞处理
+
+ Returns:
+ 处理参数字典,如果用户取消则返回 None
+ """
+ print("\n" + "═" * 60)
+ print("【Step 1: 数据库分析与扩胞处理】")
+ print("═" * 60)
+
# 数据库路径
while True:
db_path = input("\n📂 请输入数据库路径: ").strip()
@@ -48,9 +191,8 @@ def get_user_input(env: dict):
print(f"❌ 路径不存在: {db_path}")
# 检测当前Conda环境路径
- conda_env_path = env.get('conda_prefix', '')
- if not conda_env_path:
- conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
+ default_conda = "/cluster/home/koko125/anaconda3/envs/screen"
+ conda_env_path = env.get('conda_env', '') or default_conda
print(f"\n检测到Conda环境: {conda_env_path}")
custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
@@ -59,9 +201,6 @@ def get_user_input(env: dict):
elif custom_env and custom_env.lower() != 'y':
conda_env_path = custom_env
- # 目标阳离子
- cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
-
# 目标阴离子
anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip()
anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
@@ -79,43 +218,18 @@ def get_user_input(env: dict):
print("【并行计算配置】")
default_cores = min(env['total_cores'], 32)
- cores_input = input(f"💻 最大可用核数 [默认: {default_cores}]: ").strip()
- max_cores = int(cores_input) if cores_input.isdigit() else default_cores
+ cores_input = input(f"💻 最大可用核数/Worker数 [默认: {default_cores}]: ").strip()
+ max_workers = int(cores_input) if cores_input.isdigit() else default_cores
- print("\n任务复杂度 (影响每个Worker分配的核数):")
- print(" 1. 低 (1核/Worker) - 简单IO操作")
- print(" 2. 中 (2核/Worker) - 结构解析+检查 [默认]")
- print(" 3. 高 (4核/Worker) - 复杂计算")
- complexity_choice = input("请选择 [1/2/3]: ").strip()
- complexity = {'1': 'low', '2': 'medium', '3': 'high', '': 'medium'}.get(complexity_choice, 'medium')
-
- # 执行模式
- use_slurm = False
- if env['has_slurm']:
- slurm_choice = input("\n是否使用SLURM提交作业? [y/N]: ").strip().lower()
- use_slurm = slurm_choice == 'y'
-
- return {
+ params = {
'database_path': db_path,
'target_cation': cation,
'target_anions': anions,
'anion_mode': anion_mode,
- 'max_cores': max_cores,
- 'task_complexity': complexity,
- 'use_slurm': use_slurm
+ 'max_workers': max_workers,
+ 'conda_env': conda_env_path,
}
-
-def main():
- """主函数"""
- print_banner()
-
- # 环境检测
- env = detect_and_show_environment()
-
- # 获取用户输入
- params = get_user_input(env)
-
print("\n" + "═" * 60)
print("开始数据库分析...")
print("═" * 60)
@@ -126,43 +240,404 @@ def main():
target_cation=params['target_cation'],
target_anions=params['target_anions'],
anion_mode=params['anion_mode'],
- max_cores=params['max_cores'],
- task_complexity=params['task_complexity']
+ max_cores=params['max_workers'],
+ task_complexity='medium'
)
print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件")
- if params['use_slurm']:
- # SLURM模式
- output_dir = input("输出目录 [默认: ./slurm_output]: ").strip() or "./slurm_output"
- job_id = analyzer.analyze_slurm(output_dir=output_dir)
- print(f"\n✅ SLURM作业已提交: {job_id}")
- print(f" 使用 'squeue -j {job_id}' 查看状态")
- print(f" 结果将保存到: {output_dir}")
- else:
- # 本地模式
- report = analyzer.analyze(show_progress=True)
+ # 执行分析
+ report = analyzer.analyze(show_progress=True)
- # 打印报告
- ReportGenerator.print_report(report, detailed=True)
+ # 打印报告
+ ReportGenerator.print_report(report, detailed=True)
- # 保存选项
- save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
- if save_choice == 'y':
- output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
- output_path = output_path or "analysis_report.json"
- report.save(output_path)
- print(f"✅ 报告已保存到: {output_path}")
+ # 保存选项
+ save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
+ if save_choice == 'y':
+ output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
+ output_path = output_path or "analysis_report.json"
+ report.save(output_path)
+ print(f"✅ 报告已保存到: {output_path}")
- # CSV导出
- csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
- if csv_choice == 'y':
- csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
- csv_path = csv_path or "analysis_details.csv"
- ReportGenerator.export_to_csv(report, csv_path)
+ # CSV导出
+ csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
+ if csv_choice == 'y':
+ csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
+ csv_path = csv_path or "analysis_details.csv"
+ ReportGenerator.export_to_csv(report, csv_path)
- print("\n✅ 分析完成!")
+ # 生成最终数据库
+ process_choice = input("\n是否生成最终可用的数据库(扩胞+添加化合价)? [Y/n]: ").strip().lower()
+ if process_choice == 'n':
+ print("\n已跳过扩胞处理,可稍后继续")
+ return params
+
+ # 输出目录设置
+ print("\n输出目录设置:")
+ flat_dir = input(" 原始格式输出目录 [默认: workspace/processed]: ").strip()
+ flat_dir = flat_dir or "workspace/processed"
+
+ analysis_dir = input(" 分析格式输出目录 [默认: workspace/data]: ").strip()
+ analysis_dir = analysis_dir or "workspace/data"
+
+ # 扩胞保存数量
+ keep_input = input("\n扩胞结构保存数量 [默认: 1]: ").strip()
+ keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1
+
+ # 扩胞精度选择
+ print("\n扩胞计算精度:")
+ print(" 1. 高精度 (精确分数)")
+ print(" 2. 普通精度 (分母≤100)")
+ print(" 3. 低精度 (分母≤10) [默认]")
+ print(" 4. 极低精度 (分母≤5)")
+ precision_choice = input("请选择 [1/2/3/4]: ").strip()
+ calculate_type = {
+ '1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low'
+ }.get(precision_choice, 'low')
+
+ # 获取可处理文件
+ processable = report.get_processable_files(include_needs_expansion=True)
+
+ if not processable:
+ print("⚠️ 没有可处理的文件")
+ return params
+
+ print(f"\n发现 {len(processable)} 个可处理的文件")
+
+ # 准备文件列表和扩胞标记
+ input_files = [info.file_path for info in processable]
+ needs_expansion_flags = [info.needs_expansion for info in processable]
+ anion_types_list = [info.anion_types for info in processable]
+
+ direct_count = sum(1 for f in needs_expansion_flags if not f)
+ expansion_count = sum(1 for f in needs_expansion_flags if f)
+ print(f" - 可直接处理: {direct_count}")
+ print(f" - 需要扩胞: {expansion_count}")
+ print(f" - 扩胞保存数: {keep_number}")
+
+ confirm = input("\n确认开始处理? [Y/n]: ").strip().lower()
+ if confirm == 'n':
+ print("\n已取消处理")
+ return params
+
+ print("\n" + "═" * 60)
+ print("开始处理结构文件...")
+ print("═" * 60)
+
+ # 创建处理器
+ processor = StructureProcessor(
+ calculate_type=calculate_type,
+ keep_number=keep_number,
+ target_cation=params['target_cation']
+ )
+
+ # 创建输出目录
+ os.makedirs(flat_dir, exist_ok=True)
+ os.makedirs(analysis_dir, exist_ok=True)
+
+ results = []
+ total = len(input_files)
+
+ import shutil
+ for i, (input_file, needs_exp, anion_types) in enumerate(
+ zip(input_files, needs_expansion_flags, anion_types_list)
+ ):
+ print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
+
+ # 处理文件到原始格式目录
+ result = processor.process_file(input_file, flat_dir, needs_exp)
+ results.append(result)
+
+ if result.success:
+ # 同时保存到分析格式目录
+ # 按阴离子类型创建子目录
+ anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other'
+ anion_dir = os.path.join(analysis_dir, anion_key)
+ os.makedirs(anion_dir, exist_ok=True)
+
+ # 获取基础文件名
+ base_name = os.path.splitext(os.path.basename(input_file))[0]
+
+ # 创建以文件名命名的子目录
+ file_dir = os.path.join(anion_dir, base_name)
+ os.makedirs(file_dir, exist_ok=True)
+
+ # 复制生成的文件到分析格式目录
+ for output_file in result.output_files:
+ dst_path = os.path.join(file_dir, os.path.basename(output_file))
+ shutil.copy2(output_file, dst_path)
+
+ print()
+
+ # 统计结果
+ success_count = sum(1 for r in results if r.success)
+ fail_count = sum(1 for r in results if not r.success)
+ total_output = sum(len(r.output_files) for r in results if r.success)
+
+ print("\n" + "-" * 60)
+ print("【处理结果统计】")
+ print("-" * 60)
+ print(f" 成功处理: {success_count}")
+ print(f" 处理失败: {fail_count}")
+ print(f" 生成文件: {total_output}")
+ print(f"\n 原始格式目录: {flat_dir}")
+ print(f" 分析格式目录: {analysis_dir}")
+ print(f" └── 结构: data/阴离子类型/文件名/文件名.cif")
+
+ # 显示失败的文件
+ if fail_count > 0:
+ print("\n失败的文件:")
+ for r in results:
+ if not r.success:
+ print(f" - {os.path.basename(r.input_file)}: {r.error_message}")
+
+ print("\n✅ Step 1 完成!")
+ return params
+
+
+def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info = None):
+ """
+ Step 2-4: Zeo++ Voronoi 分析
+
+ Args:
+ params: 参数字典,包含 target_cation 等
+ ws_manager: 工作区管理器(可选,如果已有则直接使用)
+ workspace_info: 工作区信息(可选,如果已有则直接使用)
+ """
+
+ print("\n" + "═" * 60)
+ print("【Step 2-4: Zeo++ Voronoi 分析】")
+ print("═" * 60)
+
+ # 如果没有传入 ws_manager,则创建新的
+ if ws_manager is None:
+ # 工作区路径
+ workspace_path = input("\n工作区路径 [默认: workspace]: ").strip() or "workspace"
+ tool_dir = input("工具目录路径 [默认: tool]: ").strip() or "tool"
+
+ # 创建工作区管理器
+ ws_manager = WorkspaceManager(
+ workspace_path=workspace_path,
+ tool_dir=tool_dir,
+ target_cation=params.get('target_cation', 'Li')
+ )
+
+ # 如果没有传入 workspace_info,则检查现有工作区
+ if workspace_info is None:
+ existing = ws_manager.check_existing_workspace()
+
+ if existing and existing.total_structures > 0:
+ ws_manager.print_workspace_summary(existing)
+ workspace_info = existing
+ else:
+ print("⚠️ 工作区数据目录不存在或为空")
+ print(f" 请先运行 Step 1 生成处理后的数据到: {ws_manager.data_dir}")
+ return
+
+ # 检查是否需要创建软链接
+ if workspace_info.linked_structures < workspace_info.total_structures:
+ print("\n正在创建配置文件软链接...")
+ workspace_info = ws_manager.setup_workspace(force_relink=False)
+
+ # 获取计算任务
+ tasks = ws_manager.get_computation_tasks(workspace_info)
+
+ if not tasks:
+ print("⚠️ 没有找到可计算的任务")
+ return
+
+ print(f"\n发现 {len(tasks)} 个计算任务")
+
+ # Zeo++ 环境配置
+ print("\n" + "-" * 50)
+ print("【Zeo++ 计算配置】")
+
+ default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo"
+ zeo_env = input(f"Zeo++ Conda环境 [默认: {default_zeo_env}]: ").strip()
+ zeo_env = zeo_env or default_zeo_env
+
+ # SLURM 配置
+ partition = input("SLURM分区 [默认: cpu]: ").strip() or "cpu"
+
+ max_concurrent_input = input("最大并发任务数 [默认: 50]: ").strip()
+ max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50
+
+ time_limit = input("单任务时间限制 [默认: 2:00:00]: ").strip() or "2:00:00"
+
+ # 创建配置
+ zeo_config = ZeoConfig(
+ conda_env=zeo_env,
+ partition=partition,
+ max_concurrent=max_concurrent,
+ time_limit=time_limit
+ )
+
+ # 确认执行
+ print("\n" + "-" * 50)
+ print("【计算任务确认】")
+ print(f" 总任务数: {len(tasks)}")
+ print(f" Conda环境: {zeo_config.conda_env}")
+ print(f" SLURM分区: {zeo_config.partition}")
+ print(f" 最大并发: {zeo_config.max_concurrent}")
+ print(f" 时间限制: {zeo_config.time_limit}")
+
+ confirm = input("\n确认提交计算任务? [Y/n]: ").strip().lower()
+ if confirm == 'n':
+ print("已取消")
+ return
+
+ # 创建执行器并运行
+ executor = ZeoExecutor(zeo_config)
+
+ log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs")
+ results = executor.run_batch(tasks, output_dir=log_dir)
+
+ # 打印结果摘要
+ executor.print_results_summary(results)
+
+ # 保存结果
+ results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json")
+ results_data = [
+ {
+ 'task_id': r.task_id,
+ 'structure_name': r.structure_name,
+ 'cif_path': r.cif_path,
+ 'success': r.success,
+ 'output_files': r.output_files,
+ 'error_message': r.error_message
+ }
+ for r in results
+ ]
+ with open(results_file, 'w') as f:
+ json.dump(results_data, f, indent=2)
+ print(f"\n结果已保存到: {results_file}")
+
+ print("\n✅ Step 2-4 完成!")
+
+
+def run_step5_result_processing(workspace_path: str = "workspace"):
+ """
+ Step 5: 结果处理与筛选
+
+ Args:
+ workspace_path: 工作区路径
+ """
+ print("\n" + "═" * 60)
+ print("【Step 5: 结果处理与筛选】")
+ print("═" * 60)
+
+ # 创建结果处理器
+ processor = ResultProcessor(workspace_path=workspace_path)
+
+ # 获取筛选条件
+ print("\n设置筛选条件:")
+ print(" (直接回车使用默认值,输入 0 表示不限制)")
+
+ # 最小渗透直径(默认改为 1.0)
+ perc_input = input(" 最小渗透直径 (Å) [默认: 1.0]: ").strip()
+ min_percolation = float(perc_input) if perc_input else 1.0
+
+ # 最小 d 值
+ d_input = input(" 最小 d 值 [默认: 2.0]: ").strip()
+ min_d = float(d_input) if d_input else 2.0
+
+ # 最大节点长度
+ node_input = input(" 最大节点长度 (Å) [默认: 不限制]: ").strip()
+    max_node = float(node_input) if node_input and float(node_input) > 0 else float('inf')
+
+ # 创建筛选条件
+ criteria = FilterCriteria(
+ min_percolation_diameter=min_percolation,
+ min_d_value=min_d,
+ max_node_length=max_node
+ )
+
+ # 执行处理
+ results, stats = processor.process_and_filter(
+ criteria=criteria,
+ save_csv=True,
+ copy_passed=True
+ )
+
+ # 打印摘要
+ processor.print_summary(results, stats)
+
+ # 显示输出位置
+ print("\n输出文件位置:")
+ print(f" 汇总 CSV: {workspace_path}/results/summary.csv")
+ print(f" 分类 CSV: {workspace_path}/results/阴离子类型/阴离子类型.csv")
+ print(f" 通过筛选: {workspace_path}/passed/阴离子类型/结构名/")
+
+ print("\n✅ Step 5 完成!")
+
+
+def main():
+ """主函数"""
+ print_banner()
+
+ # 环境检测
+ env = detect_and_show_environment()
+
+ # 询问目标阳离子
+ cation = input("\n🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
+
+ # 检测工作流程状态
+ status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
+ print_workflow_status(status)
+
+ # 获取用户选择
+ choice = get_user_choice(status)
+
+ if choice == 'exit':
+ print("\n👋 再见!")
+ return
+
+ # 根据选择执行相应步骤
+ params = {'target_cation': cation}
+
+ if choice == 'step1':
+ # 从头开始
+ params = run_step1_database_analysis(env, cation)
+ if params is None:
+ return
+
+ # 询问是否继续
+ continue_choice = input("\n是否继续进行 Zeo++ 计算? [Y/n]: ").strip().lower()
+ if continue_choice == 'n':
+ print("\n可稍后运行程序继续 Zeo++ 计算")
+ return
+
+ # 重新检测状态
+ status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
+ run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
+
+ # 询问是否继续筛选
+ continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
+ if continue_choice == 'n':
+ print("\n可稍后运行程序继续结果筛选")
+ return
+
+ run_step5_result_processing("workspace")
+
+ elif choice == 'step2':
+ # 从 Zeo++ 计算开始
+ run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
+
+ # 询问是否继续筛选
+ continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
+ if continue_choice == 'n':
+ print("\n可稍后运行程序继续结果筛选")
+ return
+
+ run_step5_result_processing("workspace")
+
+ elif choice == 'step5':
+ # 直接进行结果筛选
+ run_step5_result_processing("workspace")
+
+ print("\n✅ 全部完成!")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/readme.md b/readme.md
index 2f177ff..0c35ce3 100644
--- a/readme.md
+++ b/readme.md
@@ -1,4 +1,4 @@
-# 高通量筛选与扩胞项目
+# 高通量筛选与扩胞项目 v2.3
## 环境配置需求
@@ -12,9 +12,29 @@
### 2. screen 环境 (用于逻辑筛选与数据处理)
* **Python**: 3.11.4
* **核心库**: `pymatgen==2024.11.13`, `pandas` (新增,用于处理CSV)
+* **路径**: `/cluster/home/koko125/anaconda3/envs/screen`
## 快速开始
+### 方式一:使用新版主程序(推荐)
+
+```bash
+# 激活 screen 环境
+conda activate /cluster/home/koko125/anaconda3/envs/screen
+
+# 运行主程序
+python main.py
+```
+
+主程序提供交互式界面,支持:
+- 数据库分析与筛选
+- 本地多进程并行
+- SLURM 直接提交(无需生成脚本文件)
+- 实时进度条显示
+- 扩胞处理与化合价添加
+
+### 方式二:传统方式
+
1. **数据准备**:
* 如果数据来源为 **Materials Project (MP)**,请将 CIF 文件放入 `data/input_pre`。
* 如果数据来源为 **ICSD**,请直接将 CIF 文件放入 `data/input`。
@@ -25,6 +45,52 @@
bash main.sh
```
+## 新版功能特性 (v2.1)
+
+### 执行模式
+
+1. **本地多进程模式** (`local`)
+ - 使用 Python multiprocessing 在本地并行执行
+ - 适合小规模任务或测试
+
+2. **SLURM 直接提交模式** (`slurm`)
+ - 直接在 Python 中提交 SLURM 作业
+ - 无需生成脚本文件
+ - 实时监控作业状态和进度
+ - 适合大规模高通量计算
+
+### 进度显示
+
+```
+分析CIF文件: |████████████████████░░░░░░░░░░░░░░░░░░░░| 500/1000 (50.0%) [0:05:23<0:05:20, 1.6it/s] ✓480 ✗20
+```
+
+### 输出格式
+
+处理后的文件支持两种输出格式:
+
+1. **原始格式**(平铺)
+ ```
+ workspace/processed/
+ ├── 1514027.cif
+ ├── 1514072.cif
+ └── ...
+ ```
+
+2. **分析格式**(按阴离子分类)
+ ```
+ workspace/data/
+ ├── O/
+ │ ├── 1514027/
+ │ │ └── 1514027.cif
+ │ └── 1514072/
+ │ └── 1514072.cif
+ ├── S/
+ │ └── ...
+ └── Cl+O/
+ └── ...
+ ```
+
## 处理流程详解
### Stage 1: 预处理与基础筛选 (Step 1)
@@ -51,9 +117,12 @@
---
-## 扩胞逻辑 (Step 5 - 待后续执行)
+## 扩胞逻辑 (Step 5)
-目前扩胞逻辑维持原状,基于筛选后的结构进行处理。
+扩胞处理已集成到新版主程序中,支持:
+- 自动计算扩胞因子
+- 可选保存数量(当为1时不加后缀)
+- 自动添加化合价信息
### 算法分解
1. **读取结构**: 解析 CIF 文件。
@@ -71,4 +140,276 @@
### 假设条件
* 只考虑两个原子在同一位置上的共占位情况。
-* 不考虑 Li 原子的共占位情况,对 Li 原子不做处理。
\ No newline at end of file
+* 不考虑 Li 原子的共占位情况,对 Li 原子不做处理。
+
+## 项目结构
+
+```
+screen/
+├── main.py # 主入口(新版)
+├── main.sh # 传统脚本入口
+├── readme.md # 本文档
+├── config/ # 配置文件
+│ ├── settings.yaml
+│ └── valence_states.yaml
+├── src/ # 源代码
+│ ├── analysis/ # 分析模块
+│ │ ├── database_analyzer.py
+│ │ ├── report_generator.py
+│ │ ├── structure_inspector.py
+│ │ └── worker.py
+│ ├── core/ # 核心模块
+│ │ ├── executor.py # 任务执行器(新)
+│ │ ├── scheduler.py # 调度器
+│ │ └── progress.py # 进度管理
+│ ├── preprocessing/ # 预处理模块
+│ │ ├── processor.py # 结构处理器
+│ │ └── ...
+│ └── utils/ # 工具函数
+├── py/ # 传统脚本
+├── tool/ # 工具和配置
+│ ├── analyze_voronoi_nodes.py
+│ └── Li/ # 化合价配置
+└── workspace/ # 工作区
+ ├── data/ # 分析格式输出
+ └── processed/ # 原始格式输出
+```
+
+## API 使用示例
+
+### 使用执行器
+
+```python
+from src.core.executor import create_executor, TaskExecutor
+
+# 创建执行器
+executor = create_executor(
+ mode="slurm", # 或 "local"
+ max_workers=32,
+ conda_env="/cluster/home/koko125/anaconda3/envs/screen"
+)
+
+# 定义任务
+tasks = [
+ (file_path, "Li", {"O", "S"})
+ for file_path in cif_files
+]
+
+# 执行
+from src.analysis.worker import analyze_single_file
+results = executor.run(tasks, analyze_single_file, desc="分析CIF文件")
+```
+
+### 使用数据库分析器
+
+```python
+from src.analysis.database_analyzer import DatabaseAnalyzer
+
+analyzer = DatabaseAnalyzer(
+ database_path="/path/to/cif/files",
+ target_cation="Li",
+ target_anions={"O", "S", "Cl", "Br"},
+ anion_mode="all"
+)
+
+report = analyzer.analyze(show_progress=True)
+report.save("analysis_report.json")
+```
+
+### 使用 Zeo++ 执行器
+
+```python
+from src.computation.workspace_manager import WorkspaceManager
+from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
+
+# 设置工作区
+ws_manager = WorkspaceManager(
+ workspace_path="workspace",
+ tool_dir="tool",
+ target_cation="Li"
+)
+
+# 创建软链接
+workspace_info = ws_manager.setup_workspace()
+
+# 获取计算任务
+tasks = ws_manager.get_computation_tasks(workspace_info)
+
+# 配置 Zeo++ 执行器
+config = ZeoConfig(
+ conda_env="/cluster/home/koko125/anaconda3/envs/zeo",
+ partition="cpu",
+ max_concurrent=50,
+ time_limit="2:00:00"
+)
+
+# 执行计算
+executor = ZeoExecutor(config)
+results = executor.run_batch(tasks, output_dir="slurm_logs")
+executor.print_results_summary(results)
+```
+
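+### 使用结果处理器
+
+以下是结果处理与筛选(Step 5)的调用示意,与 main.py 中的用法一致;`workspace` 为默认工作区路径:
+
+```python
+from src.computation.result_processor import ResultProcessor, FilterCriteria
+
+# 设置筛选条件(与交互式 Step 5 的默认值一致)
+criteria = FilterCriteria(
+    min_percolation_diameter=1.0,   # 最小渗透直径 (Å)
+    min_d_value=2.0,                # 最小 d 值
+    max_node_length=float('inf')    # 最大节点长度,不限制
+)
+
+# 提取各结构 log.txt 中的参数、筛选、导出 CSV 并复制通过的结构
+processor = ResultProcessor(workspace_path="workspace")
+results, stats = processor.process_and_filter(
+    criteria=criteria,
+    save_csv=True,
+    copy_passed=True
+)
+processor.print_summary(results, stats)
+```
+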
+## 新版功能特性 (v2.2)
+
+### Zeo++ Voronoi 分析
+
+新增 SLURM 作业数组支持,可高效调度大量 Zeo++ 计算任务:
+
+1. **自动工作区管理**
+ - 检测现有工作区数据
+ - 自动创建配置文件软链接
+ - 按阴离子类型组织目录结构
+
+2. **SLURM 作业数组**
+ - 使用 `--array` 参数批量提交任务
+ - 支持最大并发数限制(如 `%50`)
+ - 自动分批处理超大任务集
+
+3. **实时进度监控**
+ - 通过状态文件跟踪任务完成情况
+ - 支持 Ctrl+C 中断监控(作业继续运行)
+ - 自动收集输出文件
+
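+上面第 2 点提到的 `--array` 范围与并发限制,由 zeo_executor.py 按任务总数拼接,简化示意如下:
+
+```python
+# 作业数组范围:0 到 total-1,并以 %max_concurrent 限制同时运行的任务数
+total = 1200
+max_concurrent = 50
+array_range = f"0-{total - 1}%{max_concurrent}"   # -> "0-1199%50"
+# 写入脚本后对应的 SLURM 指令为:#SBATCH --array=0-1199%50
+```
+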
+### 工作流程
+
+```
+Step 1: 数据库分析
+ ↓
+Step 1.5: 扩胞处理 + 化合价添加
+ ↓
+Step 2-4: Zeo++ Voronoi 分析
+ ├── 创建软链接 (yaml 配置 + 计算脚本)
+ ├── 提交 SLURM 作业数组
+ └── 监控进度并收集结果
+ ↓
+Step 5: 结果处理与筛选
+ ├── 从 log.txt 提取关键参数
+ ├── 汇总到 CSV 文件
+ ├── 应用筛选条件
+ └── 复制通过筛选的结构到 passed/ 目录
+```
+
+### 筛选条件
+
+Step 5 支持以下筛选条件:
+- **最小渗透直径** (Percolation Diameter): 默认 1.0 Å
+- **最小 d 值** (Minimum of d): 默认 2.0
+- **最大节点长度** (Maximum Node Length): 默认不限制
+
+### 日志输出
+
+Zeo++ 计算的输出会重定向到每个结构目录下的 `log.txt` 文件:
+```bash
+python analyze_voronoi_nodes.py *.cif -i O.yaml > log.txt 2>&1
+```
+
+日志中包含的关键信息:
+- `Percolation diameter (A): X.XX` - 渗透直径
+- `the minium of d\nX.XX` - 最小 d 值
+- `Maximum node length detected: X.XX A` - 最大节点长度
+
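+这些参数由 `ResultProcessor.extract_from_log` 用正则表达式提取,核心逻辑示意如下(`minium` 为上游工具日志中的原始拼写,需按原样匹配):
+
+```python
+import re
+
+def extract_key_values(log_text: str):
+    """从 log.txt 文本中提取三个关键参数,缺失时返回 None"""
+    perc = re.search(r"Percolation diameter \(A\):\s*([\d\.]+)", log_text)
+    min_d = re.search(r"the minium of d\s*\n\s*([\d\.]+)", log_text)
+    node = re.search(r"Maximum node length detected:\s*([\d\.]+)\s*A", log_text)
+    to_float = lambda m: float(m.group(1)) if m else None
+    return to_float(perc), to_float(min_d), to_float(node)
+```
+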
+### 目录结构
+
+```
+workspace/
+├── data/ # 分析格式数据
+│ ├── O/ # 氧化物
+│ │ ├── O.yaml -> tool/Li/O.yaml
+│ │ ├── analyze_voronoi_nodes.py -> tool/analyze_voronoi_nodes.py
+│ │ ├── 1514027/
+│ │ │ ├── 1514027.cif
+│ │ │ ├── 1514027_all_accessed_node.cif # Zeo++ 输出
+│ │ │ ├── 1514027_bond_valence_filtered.cif
+│ │ │ └── 1514027_bv_info.csv
+│ │ └── ...
+│ ├── S/ # 硫化物
+│ └── Cl+O/ # 复合阴离子
+├── processed/ # 原始格式数据
+├── slurm_logs/ # SLURM 日志
+│ ├── tasks.json
+│ ├── submit_array.sh
+│ ├── task_0.out
+│ ├── task_0.err
+│ ├── status_0.txt
+│ └── ...
+└── zeo_results.json # 计算结果汇总
+```
+
+### 配置文件说明
+
+`tool/Li/O.yaml` 示例:
+```yaml
+SPECIE: Li+
+ANION: O
+PERCO_R: 0.5
+NEIGHBOR: 1.8
+LONG: 2.2
+```
+
+参数说明:
+- `SPECIE`: 目标扩散离子(带电荷)
+- `ANION`: 阴离子类型
+- `PERCO_R`: 渗透半径阈值
+- `NEIGHBOR`: 邻近距离阈值
+- `LONG`: 长节点判定阈值
+
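+如需单独读取该配置,可用 PyYAML 做最小演示(仅作示意,`analyze_voronoi_nodes.py` 的实际解析方式以其源码为准):
+
+```python
+import yaml
+
+with open("tool/Li/O.yaml", "r", encoding="utf-8") as f:
+    cfg = yaml.safe_load(f)
+
+print(cfg["SPECIE"], cfg["PERCO_R"])   # 例如输出: Li+ 0.5
+```
+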
+## 新版功能特性 (v2.3)
+
+### 断点续做功能
+
+v2.3 新增智能断点续做功能,支持从任意步骤继续执行:
+
+1. **自动状态检测**
+ - 启动时自动检测工作流程状态
+ - 检测 `workspace/data/` 是否存在且有结构 → 判断 Step 1 是否完成
+ - 检测结构目录下是否有 `log.txt` → 判断 Zeo++ 计算是否完成
+ - 如果 50% 以上结构有 log.txt,认为 Zeo++ 计算已完成
+
+2. **智能流程跳转**
+ - 如果已完成 Zeo++ 计算 → 可直接进行筛选
+ - 如果已完成扩胞处理 → 可直接进行 Zeo++ 计算
+ - 从后往前检测,自动跳过已完成的步骤
+
+3. **分步执行与中断**
+ - 每个大步骤完成后询问是否继续
+ - 支持中途退出,下次运行时自动检测进度
+ - 三大步骤:扩胞与加化合价 → Zeo++ 计算 → 筛选
+
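+状态检测的核心判断逻辑简化示意如下(与 main.py 中 `detect_workflow_status` 的做法一致,阈值为 50%;实际实现还会校验结构目录中是否存在 CIF 文件):
+
+```python
+import os
+
+def count_logs(data_dir: str = "workspace/data"):
+    """统计结构总数与已生成 log.txt 的结构数"""
+    total, with_log = 0, 0
+    for anion in os.listdir(data_dir):
+        anion_dir = os.path.join(data_dir, anion)
+        if not os.path.isdir(anion_dir):
+            continue
+        for name in os.listdir(anion_dir):
+            struct_dir = os.path.join(anion_dir, name)
+            if os.path.isdir(struct_dir):
+                total += 1
+                if os.path.exists(os.path.join(struct_dir, "log.txt")):
+                    with_log += 1
+    return total, with_log
+
+total, with_log = count_logs()
+has_processed_data = total > 0
+has_zeo_results = with_log > 0 and with_log >= total * 0.5   # 50% 阈值
+```
+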
+### 工作流程状态示例
+
+```
+──────────────────────────────────────────────────
+【工作流程状态检测】
+──────────────────────────────────────────────────
+ Step 1 (扩胞+化合价): ✅ 已完成 (1234 个结构)
+ Step 2-4 (Zeo++ 计算): ✅ 已完成 (1200/1234 有日志)
+ Step 5 (结果筛选): ⏳ 可执行
+──────────────────────────────────────────────────
+
+请选择操作:
+  5. 直接进行结果筛选 (Step 5)
+  2. 重新运行 Zeo++ 计算 (Step 2-4)
+  1. 从头开始 (数据库分析 + 扩胞)
+  q. 退出
+
+请选择 [默认: 5]:
+```
+
+### 使用场景
+
+1. **首次运行**
+ - 从 Step 1 开始完整执行
+ - 每步完成后可选择继续或退出
+
+2. **中断后继续**
+ - 自动检测已完成的步骤
+ - 提供从当前进度继续的选项
+
+3. **重新筛选**
+ - 修改筛选条件后
+ - 可直接运行 Step 5 而无需重新计算
+
+4. **部分重算**
+ - 如需重新计算部分结构
+ - 可选择重新运行 Zeo++ 计算
diff --git a/src/__init__.py b/src/__init__.py
index e69de29..3a1c37e 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -0,0 +1,12 @@
+"""
+高通量筛选与扩胞项目 - 源代码包
+"""
+
+from . import analysis
+from . import core
+from . import preprocessing
+from . import computation
+from . import utils
+
+__version__ = "2.2.0"
+__all__ = ['analysis', 'core', 'preprocessing', 'computation', 'utils']
diff --git a/src/analysis/database_analyzer.py b/src/analysis/database_analyzer.py
index 9c0cb5a..7e9e179 100644
--- a/src/analysis/database_analyzer.py
+++ b/src/analysis/database_analyzer.py
@@ -71,6 +71,151 @@ class DatabaseReport:
})
expansion_factor_distribution: Dict[int, int] = field(default_factory=dict)
+ def to_dict(self) -> dict:
+ """转换为可序列化的字典"""
+        from dataclasses import asdict, fields as dataclass_fields
+
+ def convert_value(val):
+ """递归转换值为可序列化类型"""
+ if isinstance(val, set):
+ return list(val)
+ elif isinstance(val, dict):
+ return {k: convert_value(v) for k, v in val.items()}
+ elif isinstance(val, list):
+ return [convert_value(item) for item in val]
+ elif hasattr(val, '__dataclass_fields__'):
+ # 处理 dataclass 对象
+ return {k: convert_value(v) for k, v in asdict(val).items()}
+ else:
+ return val
+
+ result = {}
+ for f in dataclass_fields(self):
+ value = getattr(self, f.name)
+ result[f.name] = convert_value(value)
+
+ return result
+
+ def save(self, path: str):
+ """保存报告到JSON文件"""
+        import json
+        with open(path, 'w', encoding='utf-8') as f:
+ json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
+ print(f"✅ 报告已保存到: {path}")
+
+ @classmethod
+ def load(cls, path: str) -> 'DatabaseReport':
+ """从JSON文件加载报告"""
+        import json
+        with open(path, 'r', encoding='utf-8') as f:
+ d = json.load(f)
+
+ # 处理 set 类型
+ if 'target_anions' in d:
+ d['target_anions'] = set(d['target_anions'])
+ # 处理 StructureInfo 列表(简化处理,不恢复完整对象)
+ if 'all_structures' in d:
+ d['all_structures'] = []
+
+ return cls(**d)
+
+ def get_processable_files(self, include_needs_expansion: bool = True) -> List[StructureInfo]:
+ """
+ 获取可处理的文件列表
+
+ Args:
+ include_needs_expansion: 是否包含需要扩胞的文件
+
+ Returns:
+ 可处理的 StructureInfo 列表
+ """
+ result = []
+ for info in self.all_structures:
+ if info is None or not info.is_valid:
+ continue
+ if not info.contains_target_cation:
+ continue
+ if not info.can_process:
+ continue
+ if not include_needs_expansion and info.needs_expansion:
+ continue
+ result.append(info)
+ return result
+
+ def copy_processable_files(
+ self,
+ output_dir: str,
+ include_needs_expansion: bool = True,
+ organize_by_anion: bool = True
+ ) -> Dict[str, int]:
+ """
+ 将可处理的CIF文件复制到工作区
+
+ Args:
+ output_dir: 输出目录(如 workspace/data)
+ include_needs_expansion: 是否包含需要扩胞的文件
+ organize_by_anion: 是否按阴离子类型组织子目录
+
+ Returns:
+ 复制统计信息 {类别: 数量}
+ """
+ import shutil
+
+ # 创建输出目录
+ os.makedirs(output_dir, exist_ok=True)
+
+ # 获取可处理文件
+ processable = self.get_processable_files(include_needs_expansion)
+
+ stats = {
+ 'direct': 0, # 可直接处理
+ 'needs_expansion': 0, # 需要扩胞
+ 'total': 0
+ }
+
+ # 按类型创建子目录
+ if organize_by_anion:
+ anion_dirs = {}
+
+ for info in processable:
+ # 确定目标目录
+ if organize_by_anion and info.anion_types:
+ # 使用主要阴离子作为目录名
+ anion_key = '+'.join(sorted(info.anion_types))
+ if anion_key not in anion_dirs:
+ anion_dir = os.path.join(output_dir, anion_key)
+ os.makedirs(anion_dir, exist_ok=True)
+ anion_dirs[anion_key] = anion_dir
+ target_dir = anion_dirs[anion_key]
+ else:
+ target_dir = output_dir
+
+ # 进一步按处理类型分类
+ if info.needs_expansion:
+ sub_dir = os.path.join(target_dir, 'needs_expansion')
+ stats['needs_expansion'] += 1
+ else:
+ sub_dir = os.path.join(target_dir, 'direct')
+ stats['direct'] += 1
+
+ os.makedirs(sub_dir, exist_ok=True)
+
+ # 复制文件
+ src_path = info.file_path
+ dst_path = os.path.join(sub_dir, info.file_name)
+
+ try:
+ shutil.copy2(src_path, dst_path)
+ stats['total'] += 1
+ except Exception as e:
+ print(f"⚠️ 复制失败 {info.file_name}: {e}")
+
+ # 打印统计
+ print(f"\n📁 文件已复制到: {output_dir}")
+ print(f" 可直接处理: {stats['direct']}")
+ print(f" 需要扩胞: {stats['needs_expansion']}")
+ print(f" 总计: {stats['total']}")
+
+ return stats
+
class DatabaseAnalyzer:
"""数据库分析器 - 支持高性能并行"""
@@ -232,6 +377,10 @@ class DatabaseAnalyzer:
report.invalid_files += 1
continue # 无效文件不继续统计
+ # 关键修复:只有当结构确实含有目标阳离子时才计入统计
+ if not info.contains_target_cation:
+ continue # 不含目标阳离子的文件不继续统计
+
report.cation_containing_count += 1
for anion in info.anion_types:
@@ -317,45 +466,3 @@ class DatabaseAnalyzer:
for anion, count in report.anion_distribution.items():
report.anion_ratios[anion] = \
count / report.cation_containing_count
-
- def to_dict(self) -> dict:
- """转换为可序列化的字典"""
- import json
- from dataclasses import asdict, fields
-
- result = {}
- for field in fields(self):
- value = getattr(self, field.name)
-
- # 处理 set 类型
- if isinstance(value, set):
- result[field.name] = list(value)
- # 处理 StructureInfo 列表
- elif field.name == 'all_structures':
- result[field.name] = [] # 不保存详细结构信息,太大
- else:
- result[field.name] = value
-
- return result
-
- def save(self, path: str):
- """保存报告到JSON文件"""
- import json
- with open(path, 'w', encoding='utf-8') as f:
- json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
- print(f"✅ 报告已保存到: {path}")
-
- @classmethod
- def load(cls, path: str) -> 'DatabaseReport':
- """从JSON文件加载报告"""
- import json
- with open(path, 'r', encoding='utf-8') as f:
- d = json.load(f)
-
- # 处理 set 类型
- if 'target_anions' in d:
- d['target_anions'] = set(d['target_anions'])
- if 'all_structures' not in d:
- d['all_structures'] = []
-
- return cls(**d)
\ No newline at end of file
diff --git a/src/analysis/report_generator.py b/src/analysis/report_generator.py
index 0651409..172f6d1 100644
--- a/src/analysis/report_generator.py
+++ b/src/analysis/report_generator.py
@@ -145,7 +145,7 @@ class ReportGenerator:
info.anion_mode,
info.has_oxidation_states,
info.has_partial_occupancy,
- info.cation_has_partial_occupancy,
+ info.cation_with_other_cation, # 修复:使用正确的属性名
info.anion_has_partial_occupancy,
info.needs_expansion,
info.is_binary_compound,
@@ -156,4 +156,4 @@ class ReportGenerator:
]
writer.writerow(row)
- print(f"详细结果已导出到: {output_path}")
\ No newline at end of file
+ print(f"详细结果已导出到: {output_path}")
diff --git a/src/analysis/structure_inspector.py b/src/analysis/structure_inspector.py
index 4b171f9..fccf918 100644
--- a/src/analysis/structure_inspector.py
+++ b/src/analysis/structure_inspector.py
@@ -441,43 +441,3 @@ class StructureInspector:
return False
except:
return False
-
- def _evaluate_processability(self, info: StructureInfo):
- """评估可处理性"""
- skip_reasons = []
-
- if not info.is_valid:
- skip_reasons.append("无法解析CIF文件")
-
- if not info.contains_target_cation:
- skip_reasons.append(f"不含{self.target_cation}")
-
- if info.anion_mode == "none":
- skip_reasons.append("不含目标阴离子")
-
- if info.is_binary_compound:
- skip_reasons.append("二元化合物")
-
- if info.has_radioactive_elements:
- skip_reasons.append("含放射性元素")
-
- # 关键:目标阳离子共占位是不可处理的
- if info.cation_has_partial_occupancy:
- skip_reasons.append(f"{self.target_cation}存在共占位")
-
- # 阴离子共占位通常也不处理
- if info.anion_has_partial_occupancy:
- skip_reasons.append("阴离子存在共占位")
-
- if info.has_water_molecule:
- skip_reasons.append("含水分子")
-
- # 扩胞因子过大
- if info.expansion_info.needs_expansion and not info.expansion_info.can_expand:
- skip_reasons.append(info.expansion_info.skip_reason)
-
- if skip_reasons:
- info.can_process = False
- info.skip_reason = "; ".join(skip_reasons)
- else:
- info.can_process = True
\ No newline at end of file
diff --git a/src/computation/__init__.py b/src/computation/__init__.py
new file mode 100644
index 0000000..665a4c2
--- /dev/null
+++ b/src/computation/__init__.py
@@ -0,0 +1,15 @@
+"""
+计算模块:Zeo++ Voronoi 分析
+"""
+from .workspace_manager import WorkspaceManager
+from .zeo_executor import ZeoExecutor, ZeoConfig
+from .result_processor import ResultProcessor, FilterCriteria, StructureResult
+
+__all__ = [
+ 'WorkspaceManager',
+ 'ZeoExecutor',
+ 'ZeoConfig',
+ 'ResultProcessor',
+ 'FilterCriteria',
+ 'StructureResult'
+]
diff --git a/src/computation/result_processor.py b/src/computation/result_processor.py
new file mode 100644
index 0000000..14cb4cf
--- /dev/null
+++ b/src/computation/result_processor.py
@@ -0,0 +1,426 @@
+"""
+Zeo++ 计算结果处理器:提取数据、筛选结构
+"""
+import os
+import re
+import shutil
+from typing import Dict, List, Optional, Tuple
+from dataclasses import dataclass, field
+import pandas as pd
+
+
+@dataclass
+class FilterCriteria:
+ """筛选条件"""
+ min_percolation_diameter: float = 1.0 # 最小渗透直径 (Å),默认 1.0
+ min_d_value: float = 2.0 # 最小 d 值,默认 2.0
+ max_node_length: float = float('inf') # 最大节点长度 (Å)
+
+
+@dataclass
+class StructureResult:
+ """单个结构的计算结果"""
+ structure_name: str
+ anion_type: str
+ work_dir: str
+
+ # 提取的参数
+ percolation_diameter: Optional[float] = None
+ min_d: Optional[float] = None
+ max_node_length: Optional[float] = None
+
+ # 筛选结果
+ passed_filter: bool = False
+ filter_reason: str = ""
+
+
+class ResultProcessor:
+ """
+ Zeo++ 计算结果处理器
+
+ 功能:
+ 1. 从每个结构目录的 log.txt 提取关键参数
+ 2. 汇总所有结果到 CSV 文件
+ 3. 根据筛选条件筛选结构
+ 4. 将通过筛选的结构复制到新文件夹
+ """
+
+ def __init__(
+ self,
+ workspace_path: str = "workspace",
+ data_dir: str = None,
+ output_dir: str = None
+ ):
+ """
+ 初始化结果处理器
+
+ Args:
+ workspace_path: 工作区根目录
+ data_dir: 数据目录(默认 workspace/data)
+ output_dir: 输出目录(默认 workspace/results)
+ """
+ self.workspace_path = os.path.abspath(workspace_path)
+ self.data_dir = data_dir or os.path.join(self.workspace_path, "data")
+ self.output_dir = output_dir or os.path.join(self.workspace_path, "results")
+
+ def extract_from_log(self, log_path: str) -> Tuple[Optional[float], Optional[float], Optional[float]]:
+ """
+ 从 log.txt 中提取三个关键参数
+
+ Args:
+ log_path: log.txt 文件路径
+
+ Returns:
+ (percolation_diameter, min_d, max_node_length)
+ """
+ if not os.path.exists(log_path):
+ return None, None, None
+
+ try:
+ with open(log_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ except Exception:
+ return None, None, None
+
+ # 正则表达式 - 与 py/extract_data.py 保持一致
+
+ # 1. Percolation diameter: "# Percolation diameter (A): 1.06"
+ re_percolation = r"Percolation diameter \(A\):\s*([\d\.]+)"
+
+ # 2. Minimum of d: "the minium of d\n3.862140561244235"
+ # 注意:这是 Topological_Analysis 库输出的格式
+ re_min_d = r"the minium of d\s*\n\s*([\d\.]+)"
+
+ # 3. Maximum node length: "# Maximum node length detected: 1.332 A"
+ re_max_node = r"Maximum node length detected:\s*([\d\.]+)\s*A"
+
+ # 提取数据
+ match_perc = re.search(re_percolation, content)
+ match_d = re.search(re_min_d, content)
+ match_node = re.search(re_max_node, content)
+
+ val_perc = float(match_perc.group(1)) if match_perc else None
+ val_d = float(match_d.group(1)) if match_d else None
+ val_node = float(match_node.group(1)) if match_node else None
+
+ return val_perc, val_d, val_node
+
+ def process_all_structures(self) -> List[StructureResult]:
+ """
+ 处理所有结构,提取计算结果
+
+ Returns:
+ StructureResult 列表
+ """
+ results = []
+
+ if not os.path.exists(self.data_dir):
+ print(f"⚠️ 数据目录不存在: {self.data_dir}")
+ return results
+
+ print("\n正在提取计算结果...")
+
+ # 遍历阴离子目录
+ for anion_key in os.listdir(self.data_dir):
+ anion_dir = os.path.join(self.data_dir, anion_key)
+ if not os.path.isdir(anion_dir):
+ continue
+
+ # 遍历结构目录
+ for struct_name in os.listdir(anion_dir):
+ struct_dir = os.path.join(anion_dir, struct_name)
+ if not os.path.isdir(struct_dir):
+ continue
+
+ # 查找 log.txt
+ log_path = os.path.join(struct_dir, "log.txt")
+
+ # 提取参数
+ perc, min_d, max_node = self.extract_from_log(log_path)
+
+ result = StructureResult(
+ structure_name=struct_name,
+ anion_type=anion_key,
+ work_dir=struct_dir,
+ percolation_diameter=perc,
+ min_d=min_d,
+ max_node_length=max_node
+ )
+
+ results.append(result)
+
+ print(f" 共处理 {len(results)} 个结构")
+ return results
+
+ def apply_filter(
+ self,
+ results: List[StructureResult],
+ criteria: FilterCriteria
+ ) -> List[StructureResult]:
+ """
+ 应用筛选条件
+
+ Args:
+ results: 结构结果列表
+ criteria: 筛选条件
+
+ Returns:
+ 更新后的结果列表(包含筛选状态)
+ """
+ print("\n应用筛选条件...")
+ print(f" 最小渗透直径: {criteria.min_percolation_diameter} Å")
+ print(f" 最小 d 值: {criteria.min_d_value}")
+ print(f" 最大节点长度: {criteria.max_node_length} Å")
+
+ passed_count = 0
+
+ for result in results:
+ # 检查是否有有效数据
+ if result.percolation_diameter is None or result.min_d is None:
+ result.passed_filter = False
+ result.filter_reason = "数据缺失"
+ continue
+
+ # 检查渗透直径
+ if result.percolation_diameter < criteria.min_percolation_diameter:
+ result.passed_filter = False
+ result.filter_reason = f"渗透直径 {result.percolation_diameter:.3f} < {criteria.min_percolation_diameter}"
+ continue
+
+ # 检查 d 值
+ if result.min_d < criteria.min_d_value:
+ result.passed_filter = False
+ result.filter_reason = f"d 值 {result.min_d:.3f} < {criteria.min_d_value}"
+ continue
+
+ # 检查节点长度(如果有数据)
+ if result.max_node_length is not None:
+ if result.max_node_length > criteria.max_node_length:
+ result.passed_filter = False
+ result.filter_reason = f"节点长度 {result.max_node_length:.3f} > {criteria.max_node_length}"
+ continue
+
+ # 通过所有筛选
+ result.passed_filter = True
+ result.filter_reason = "通过"
+ passed_count += 1
+
+ print(f" 通过筛选: {passed_count}/{len(results)}")
+ return results
+
+ def save_summary_csv(
+ self,
+ results: List[StructureResult],
+ output_path: str = None
+ ) -> str:
+ """
+ 保存汇总 CSV 文件
+
+ Args:
+ results: 结构结果列表
+ output_path: 输出路径(默认 workspace/results/summary.csv)
+
+ Returns:
+ CSV 文件路径
+ """
+ if output_path is None:
+ os.makedirs(self.output_dir, exist_ok=True)
+ output_path = os.path.join(self.output_dir, "summary.csv")
+
+ # 构建数据
+ data = []
+ for r in results:
+ data.append({
+ 'Structure': r.structure_name,
+ 'Anion_Type': r.anion_type,
+ 'Percolation_Diameter_A': r.percolation_diameter,
+ 'Min_d': r.min_d,
+ 'Max_Node_Length_A': r.max_node_length,
+ 'Passed_Filter': r.passed_filter,
+ 'Filter_Reason': r.filter_reason
+ })
+
+ df = pd.DataFrame(data)
+
+ # 按阴离子类型和结构名排序
+ df = df.sort_values(['Anion_Type', 'Structure'])
+
+ # 保存
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ df.to_csv(output_path, index=False)
+
+ print(f"\n汇总 CSV 已保存: {output_path}")
+ return output_path
+
+ def save_anion_csv(
+ self,
+ results: List[StructureResult],
+ output_dir: str = None
+ ) -> List[str]:
+ """
+ 按阴离子类型分别保存 CSV 文件
+
+ Args:
+ results: 结构结果列表
+ output_dir: 输出目录
+
+ Returns:
+ 生成的 CSV 文件路径列表
+ """
+ if output_dir is None:
+ output_dir = self.output_dir
+
+ # 按阴离子类型分组
+ anion_groups: Dict[str, List[StructureResult]] = {}
+ for r in results:
+ if r.anion_type not in anion_groups:
+ anion_groups[r.anion_type] = []
+ anion_groups[r.anion_type].append(r)
+
+ csv_files = []
+
+ for anion_type, group_results in anion_groups.items():
+ # 构建数据
+ data = []
+ for r in group_results:
+ data.append({
+ 'Structure': r.structure_name,
+ 'Percolation_Diameter_A': r.percolation_diameter,
+ 'Min_d': r.min_d,
+ 'Max_Node_Length_A': r.max_node_length,
+ 'Passed_Filter': r.passed_filter,
+ 'Filter_Reason': r.filter_reason
+ })
+
+ df = pd.DataFrame(data)
+ df = df.sort_values('Structure')
+
+ # 保存到对应目录
+ anion_output_dir = os.path.join(output_dir, anion_type)
+ os.makedirs(anion_output_dir, exist_ok=True)
+
+ csv_path = os.path.join(anion_output_dir, f"{anion_type}.csv")
+ df.to_csv(csv_path, index=False)
+ csv_files.append(csv_path)
+
+ print(f" {anion_type}: {len(group_results)} 个结构 -> {csv_path}")
+
+ return csv_files
+
+ def copy_passed_structures(
+ self,
+ results: List[StructureResult],
+ output_dir: str = None
+ ) -> int:
+ """
+ 将通过筛选的结构复制到新文件夹
+
+ Args:
+ results: 结构结果列表
+ output_dir: 输出目录(默认 workspace/passed)
+
+ Returns:
+ 复制的结构数量
+ """
+ if output_dir is None:
+ output_dir = os.path.join(self.workspace_path, "passed")
+
+ passed_results = [r for r in results if r.passed_filter]
+
+ if not passed_results:
+ print("\n没有通过筛选的结构")
+ return 0
+
+ print(f"\n正在复制 {len(passed_results)} 个通过筛选的结构...")
+
+ copied = 0
+ for r in passed_results:
+ # 目标目录:passed/阴离子类型/结构名/
+ dst_dir = os.path.join(output_dir, r.anion_type, r.structure_name)
+
+ try:
+ # 如果目标已存在,先删除
+ if os.path.exists(dst_dir):
+ shutil.rmtree(dst_dir)
+
+ # 复制整个目录
+ shutil.copytree(r.work_dir, dst_dir)
+ copied += 1
+
+ except Exception as e:
+ print(f" ⚠️ 复制失败 {r.structure_name}: {e}")
+
+ print(f" 已复制 {copied} 个结构到: {output_dir}")
+ return copied
+
+ def process_and_filter(
+ self,
+ criteria: FilterCriteria = None,
+ save_csv: bool = True,
+ copy_passed: bool = True
+ ) -> Tuple[List[StructureResult], Dict]:
+ """
+ 完整的处理流程:提取数据 -> 筛选 -> 保存 CSV -> 复制通过的结构
+
+ Args:
+ criteria: 筛选条件(如果为 None,则不筛选)
+ save_csv: 是否保存 CSV
+ copy_passed: 是否复制通过筛选的结构
+
+ Returns:
+ (结果列表, 统计信息字典)
+ """
+ # 1. 提取所有结构的计算结果
+ results = self.process_all_structures()
+
+ if not results:
+ return results, {'total': 0, 'passed': 0, 'failed': 0}
+
+ # 2. 应用筛选条件
+ if criteria is not None:
+ results = self.apply_filter(results, criteria)
+
+ # 3. 保存 CSV
+ if save_csv:
+ print("\n保存结果 CSV...")
+ self.save_summary_csv(results)
+ self.save_anion_csv(results)
+
+ # 4. 复制通过筛选的结构
+ if copy_passed and criteria is not None:
+ self.copy_passed_structures(results)
+
+ # 统计
+ stats = {
+ 'total': len(results),
+ 'passed': sum(1 for r in results if r.passed_filter),
+ 'failed': sum(1 for r in results if not r.passed_filter),
+ 'missing_data': sum(1 for r in results if r.filter_reason == "数据缺失")
+ }
+
+ return results, stats
+
+ def print_summary(self, results: List[StructureResult], stats: Dict):
+ """打印结果摘要"""
+ print("\n" + "=" * 60)
+ print("【计算结果摘要】")
+ print("=" * 60)
+ print(f" 总结构数: {stats['total']}")
+ print(f" 通过筛选: {stats['passed']}")
+ print(f" 未通过筛选: {stats['failed']}")
+ print(f" 数据缺失: {stats.get('missing_data', 0)}")
+
+ # 按阴离子类型统计
+ anion_stats: Dict[str, Dict] = {}
+ for r in results:
+ if r.anion_type not in anion_stats:
+ anion_stats[r.anion_type] = {'total': 0, 'passed': 0}
+ anion_stats[r.anion_type]['total'] += 1
+ if r.passed_filter:
+ anion_stats[r.anion_type]['passed'] += 1
+
+ print("\n 按阴离子类型:")
+ for anion, s in sorted(anion_stats.items()):
+ print(f" {anion}: {s['passed']}/{s['total']} 通过")
+
+ print("=" * 60)
diff --git a/src/computation/workspace_manager.py b/src/computation/workspace_manager.py
new file mode 100644
index 0000000..7c8d887
--- /dev/null
+++ b/src/computation/workspace_manager.py
@@ -0,0 +1,288 @@
+"""
+工作区管理器:管理计算工作区的创建和软链接
+"""
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass, field
+
+
+@dataclass
+class WorkspaceInfo:
+ """工作区信息"""
+ workspace_path: str
+ data_dir: str # workspace/data
+ tool_dir: str # tool 目录
+ target_cation: str
+ target_anions: Set[str]
+
+ # 统计信息
+ total_structures: int = 0
+ anion_counts: Dict[str, int] = field(default_factory=dict)
+ linked_structures: int = 0 # 已创建软链接的结构数
+
+
+class WorkspaceManager:
+ """
+ 工作区管理器
+
+ 负责:
+ 1. 检测现有工作区
+ 2. 创建软链接(yaml 配置文件和计算脚本放在每个结构目录下)
+ 3. 准备计算任务
+ """
+
+ # 支持的阴离子及其配置文件
+ SUPPORTED_ANIONS = {'O', 'S', 'Cl', 'Br'}
+
+ def __init__(
+ self,
+ workspace_path: str = "workspace",
+ tool_dir: str = "tool",
+ target_cation: str = "Li"
+ ):
+ """
+ 初始化工作区管理器
+
+ Args:
+ workspace_path: 工作区根目录
+ tool_dir: 工具目录(包含 yaml 配置和计算脚本)
+ target_cation: 目标阳离子
+ """
+ self.workspace_path = os.path.abspath(workspace_path)
+ self.tool_dir = os.path.abspath(tool_dir)
+ self.target_cation = target_cation
+
+ # 数据目录
+ self.data_dir = os.path.join(self.workspace_path, "data")
+
+ def check_existing_workspace(self) -> Optional[WorkspaceInfo]:
+ """
+ 检查现有工作区
+
+ Returns:
+ WorkspaceInfo 如果存在,否则 None
+ """
+ if not os.path.exists(self.data_dir):
+ return None
+
+ # 扫描数据目录
+ anion_counts = {}
+ total = 0
+ linked = 0
+
+ for item in os.listdir(self.data_dir):
+ item_path = os.path.join(self.data_dir, item)
+ if os.path.isdir(item_path):
+ # 可能是阴离子目录(如 O, S, O+S)
+ # 统计其中的结构数量
+ count = 0
+ for sub_item in os.listdir(item_path):
+ sub_path = os.path.join(item_path, sub_item)
+ if os.path.isdir(sub_path):
+ # 检查是否包含 CIF 文件
+ cif_files = [f for f in os.listdir(sub_path) if f.endswith('.cif')]
+ if cif_files:
+ count += 1
+ # 检查是否已有软链接
+ yaml_files = [f for f in os.listdir(sub_path) if f.endswith('.yaml')]
+ if yaml_files:
+ linked += 1
+
+ if count > 0:
+ anion_counts[item] = count
+ total += count
+
+ if total == 0:
+ return None
+
+ return WorkspaceInfo(
+ workspace_path=self.workspace_path,
+ data_dir=self.data_dir,
+ tool_dir=self.tool_dir,
+ target_cation=self.target_cation,
+ target_anions=set(anion_counts.keys()),
+ total_structures=total,
+ anion_counts=anion_counts,
+ linked_structures=linked
+ )
+
+ def setup_workspace(
+ self,
+ target_anions: Set[str] = None,
+ force_relink: bool = False
+ ) -> WorkspaceInfo:
+ """
+ 设置工作区:在每个结构目录下创建软链接
+
+ 软链接规则:
+ - yaml 文件:使用与阴离子目录同名的 yaml(如 O 目录用 O.yaml,Cl+O 目录用 Cl+O.yaml)
+ - python 脚本:analyze_voronoi_nodes.py
+
+ Args:
+ target_anions: 目标阴离子集合
+ force_relink: 是否强制重新创建软链接
+
+ Returns:
+ WorkspaceInfo
+ """
+ if target_anions is None:
+ target_anions = self.SUPPORTED_ANIONS
+
+ # 确保数据目录存在
+ if not os.path.exists(self.data_dir):
+ raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
+
+ # 获取计算脚本路径
+ analyze_script = os.path.join(self.tool_dir, "analyze_voronoi_nodes.py")
+ if not os.path.exists(analyze_script):
+ raise FileNotFoundError(f"计算脚本不存在: {analyze_script}")
+
+ anion_counts = {}
+ total = 0
+ linked = 0
+
+ print("\n正在设置工作区软链接...")
+
+ # 遍历数据目录中的阴离子子目录
+ for anion_key in os.listdir(self.data_dir):
+ anion_dir = os.path.join(self.data_dir, anion_key)
+ if not os.path.isdir(anion_dir):
+ continue
+
+ # 确定使用哪个 yaml 配置文件
+ # 使用与阴离子目录同名的 yaml 文件(如 O.yaml, Cl+O.yaml)
+ yaml_name = f"{anion_key}.yaml"
+ yaml_source = os.path.join(self.tool_dir, self.target_cation, yaml_name)
+
+ if not os.path.exists(yaml_source):
+ print(f" ⚠️ 配置文件不存在: {yaml_source}")
+ continue
+
+ # 统计并处理该阴离子目录下的所有结构
+ count = 0
+ for struct_name in os.listdir(anion_dir):
+ struct_dir = os.path.join(anion_dir, struct_name)
+
+ if not os.path.isdir(struct_dir):
+ continue
+
+ # 检查是否包含 CIF 文件
+ cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
+ if not cif_files:
+ continue
+
+ count += 1
+
+ # 在结构目录下创建软链接
+ yaml_link = os.path.join(struct_dir, yaml_name)
+ script_link = os.path.join(struct_dir, "analyze_voronoi_nodes.py")
+
+ # 创建 yaml 软链接
+ if os.path.exists(yaml_link) or os.path.islink(yaml_link):
+ if force_relink:
+ os.remove(yaml_link)
+ os.symlink(yaml_source, yaml_link)
+ linked += 1
+ else:
+ os.symlink(yaml_source, yaml_link)
+ linked += 1
+
+ # 创建计算脚本软链接
+ if os.path.exists(script_link) or os.path.islink(script_link):
+ if force_relink:
+ os.remove(script_link)
+ os.symlink(analyze_script, script_link)
+ else:
+ os.symlink(analyze_script, script_link)
+
+ if count > 0:
+ anion_counts[anion_key] = count
+ total += count
+ print(f" ✓ {anion_key}: {count} 个结构, 配置 -> {yaml_name}")
+
+ print(f"\n 总计: {total} 个结构, 新建软链接: {linked}")
+
+ return WorkspaceInfo(
+ workspace_path=self.workspace_path,
+ data_dir=self.data_dir,
+ tool_dir=self.tool_dir,
+ target_cation=self.target_cation,
+ target_anions=set(anion_counts.keys()),
+ total_structures=total,
+ anion_counts=anion_counts,
+ linked_structures=linked
+ )
+
+ def get_computation_tasks(
+ self,
+ workspace_info: WorkspaceInfo = None
+ ) -> List[Dict]:
+ """
+ 获取所有计算任务
+
+ Returns:
+ 任务列表,每个任务包含:
+ - cif_path: CIF 文件路径
+ - yaml_name: YAML 配置文件名(如 O.yaml)
+ - work_dir: 工作目录(结构目录)
+ - anion_type: 阴离子类型
+ - structure_name: 结构名称
+ """
+ if workspace_info is None:
+ workspace_info = self.check_existing_workspace()
+
+ if workspace_info is None:
+ return []
+
+ tasks = []
+
+ for anion_key in workspace_info.anion_counts.keys():
+ anion_dir = os.path.join(self.data_dir, anion_key)
+ yaml_name = f"{anion_key}.yaml"
+
+ # 遍历该阴离子目录下的所有结构
+ for struct_name in os.listdir(anion_dir):
+ struct_dir = os.path.join(anion_dir, struct_name)
+
+ if not os.path.isdir(struct_dir):
+ continue
+
+ # 查找 CIF 文件
+ cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
+
+ # 检查是否有 yaml 软链接
+ yaml_path = os.path.join(struct_dir, yaml_name)
+ if not os.path.exists(yaml_path):
+ continue
+
+ for cif_file in cif_files:
+ cif_path = os.path.join(struct_dir, cif_file)
+
+ tasks.append({
+ 'cif_path': cif_path,
+ 'yaml_name': yaml_name,
+ 'work_dir': struct_dir,
+ 'anion_type': anion_key,
+ 'structure_name': struct_name,
+ 'cif_name': cif_file
+ })
+
+ return tasks
+
+ def print_workspace_summary(self, workspace_info: WorkspaceInfo):
+ """打印工作区摘要"""
+ print("\n" + "=" * 60)
+ print("【工作区摘要】")
+ print("=" * 60)
+ print(f" 工作区路径: {workspace_info.workspace_path}")
+ print(f" 数据目录: {workspace_info.data_dir}")
+ print(f" 目标阳离子: {workspace_info.target_cation}")
+ print(f" 总结构数: {workspace_info.total_structures}")
+ print(f" 已配置软链接: {workspace_info.linked_structures}")
+ print()
+ print(" 阴离子分布:")
+ for anion, count in sorted(workspace_info.anion_counts.items()):
+ print(f" - {anion}: {count} 个结构")
+ print("=" * 60)
diff --git a/src/computation/zeo_executor.py b/src/computation/zeo_executor.py
new file mode 100644
index 0000000..48c2061
--- /dev/null
+++ b/src/computation/zeo_executor.py
@@ -0,0 +1,446 @@
+"""
+Zeo++ 计算执行器:使用 SLURM 作业数组高效调度大量计算任务
+"""
+import os
+import subprocess
+import time
+import json
+import tempfile
+from typing import List, Dict, Optional, Callable, Any
+from dataclasses import dataclass, field
+from enum import Enum
+import threading
+
+from ..core.progress import ProgressManager
+
+
+@dataclass
+class ZeoConfig:
+ """Zeo++ 计算配置"""
+ # 环境配置
+ conda_env: str = "/cluster/home/koko125/anaconda3/envs/zeo"
+
+ # SLURM 配置
+ partition: str = "cpu"
+ time_limit: str = "2:00:00" # 单个任务时间限制
+ memory_per_task: str = "4G"
+
+ # 作业数组配置
+ max_array_size: int = 1000 # SLURM 作业数组最大大小
+ max_concurrent: int = 50 # 最大并发任务数
+
+ # 轮询配置
+ poll_interval: float = 5.0 # 状态检查间隔(秒)
+
+ # 过滤器配置
+ filters: List[str] = field(default_factory=lambda: [
+ "Ordered", "PropOxi", "VoroPerco", "Coulomb", "VoroBV", "VoroInfo", "MergeSite"
+ ])
+
+
+@dataclass
+class ZeoTaskResult:
+ """单个任务结果"""
+ task_id: int
+ structure_name: str
+ cif_path: str
+ success: bool
+ output_files: List[str] = field(default_factory=list)
+ error_message: str = ""
+ duration: float = 0.0
+
+
+class ZeoExecutor:
+ """
+ Zeo++ 计算执行器
+
+ 使用 SLURM 作业数组高效调度大量 Voronoi 分析任务
+ """
+
+ def __init__(self, config: ZeoConfig = None):
+ self.config = config or ZeoConfig()
+ self.progress_manager = None
+ self._stop_event = threading.Event()
+
+ def run_batch(
+ self,
+ tasks: List[Dict],
+ output_dir: str = None,
+ desc: str = "Zeo++ 计算"
+ ) -> List[ZeoTaskResult]:
+ """
+ 批量执行 Zeo++ 计算
+
+ Args:
+ tasks: 任务列表,每个任务包含 cif_path, yaml_path, work_dir 等
+ output_dir: SLURM 日志输出目录
+ desc: 进度条描述
+
+ Returns:
+ ZeoTaskResult 列表
+ """
+ if not tasks:
+ print("⚠️ 没有任务需要执行")
+ return []
+
+ total = len(tasks)
+
+ # 创建输出目录
+ if output_dir is None:
+ output_dir = os.path.join(os.getcwd(), "slurm_logs")
+ os.makedirs(output_dir, exist_ok=True)
+
+ print(f"\n{'='*60}")
+ print(f"【Zeo++ 批量计算】")
+ print(f"{'='*60}")
+ print(f" 总任务数: {total}")
+ print(f" Conda环境: {self.config.conda_env}")
+ print(f" SLURM分区: {self.config.partition}")
+ print(f" 最大并发: {self.config.max_concurrent}")
+ print(f" 日志目录: {output_dir}")
+ print(f"{'='*60}\n")
+
+ # 保存任务列表到文件
+ tasks_file = os.path.join(output_dir, "tasks.json")
+ with open(tasks_file, 'w') as f:
+ json.dump(tasks, f, indent=2)
+
+ # 生成并提交作业数组
+ if total <= self.config.max_array_size:
+ # 单个作业数组
+ return self._submit_array_job(tasks, output_dir, desc)
+ else:
+ # 分批提交多个作业数组
+ return self._submit_batched_arrays(tasks, output_dir, desc)
+
+ def _submit_array_job(
+ self,
+ tasks: List[Dict],
+ output_dir: str,
+ desc: str
+ ) -> List[ZeoTaskResult]:
+ """提交单个作业数组"""
+ total = len(tasks)
+
+ # 保存任务列表
+ tasks_file = os.path.join(output_dir, "tasks.json")
+ with open(tasks_file, 'w') as f:
+ json.dump(tasks, f, indent=2)
+
+ # 生成作业脚本
+ script_content = self._generate_array_script(
+ tasks_file=tasks_file,
+ output_dir=output_dir,
+ array_range=f"0-{total-1}%{self.config.max_concurrent}"
+ )
+
+ script_path = os.path.join(output_dir, "submit_array.sh")
+ with open(script_path, 'w') as f:
+ f.write(script_content)
+ os.chmod(script_path, 0o755)
+
+ print(f"生成作业脚本: {script_path}")
+
+ # 提交作业
+ result = subprocess.run(
+ ['sbatch', script_path],
+ capture_output=True,
+ text=True
+ )
+
+ if result.returncode != 0:
+ print(f"❌ 作业提交失败: {result.stderr}")
+ return [ZeoTaskResult(
+ task_id=i,
+ structure_name=t.get('structure_name', ''),
+ cif_path=t.get('cif_path', ''),
+ success=False,
+ error_message=f"提交失败: {result.stderr}"
+ ) for i, t in enumerate(tasks)]
+
+ # 提取作业 ID
+ job_id = result.stdout.strip().split()[-1]
+ print(f"✓ 作业已提交: {job_id}")
+ print(f" 作业数组范围: 0-{total-1}")
+ print(f" 最大并发: {self.config.max_concurrent}")
+
+ # 监控作业进度
+ return self._monitor_array_job(job_id, tasks, output_dir, desc)
+
+ def _submit_batched_arrays(
+ self,
+ tasks: List[Dict],
+ output_dir: str,
+ desc: str
+ ) -> List[ZeoTaskResult]:
+ """分批提交多个作业数组"""
+ total = len(tasks)
+ batch_size = self.config.max_array_size
+ num_batches = (total + batch_size - 1) // batch_size
+
+ print(f"任务数超过作业数组限制,分 {num_batches} 批提交")
+
+ all_results = []
+
+ for batch_idx in range(num_batches):
+ start_idx = batch_idx * batch_size
+ end_idx = min(start_idx + batch_size, total)
+ batch_tasks = tasks[start_idx:end_idx]
+
+ batch_output_dir = os.path.join(output_dir, f"batch_{batch_idx}")
+ os.makedirs(batch_output_dir, exist_ok=True)
+
+ print(f"\n--- 批次 {batch_idx + 1}/{num_batches} ---")
+ print(f"任务范围: {start_idx} - {end_idx - 1}")
+
+ batch_results = self._submit_array_job(
+ batch_tasks,
+ batch_output_dir,
+ f"{desc} (批次 {batch_idx + 1}/{num_batches})"
+ )
+
+ # 调整任务 ID
+ for r in batch_results:
+ r.task_id += start_idx
+
+ all_results.extend(batch_results)
+
+ return all_results
+
+ def _generate_array_script(
+ self,
+ tasks_file: str,
+ output_dir: str,
+ array_range: str
+ ) -> str:
+ """生成 SLURM 作业数组脚本"""
+
+ script = f"""#!/bin/bash
+#SBATCH --job-name=zeo_array
+#SBATCH --partition={self.config.partition}
+#SBATCH --array={array_range}
+#SBATCH --ntasks=1
+#SBATCH --cpus-per-task=1
+#SBATCH --mem={self.config.memory_per_task}
+#SBATCH --time={self.config.time_limit}
+#SBATCH --output={output_dir}/task_%a.out
+#SBATCH --error={output_dir}/task_%a.err
+
+# ============================================
+# Zeo++ Voronoi 分析 - 作业数组
+# ============================================
+
+echo "===== 任务信息 ====="
+echo "作业ID: $SLURM_JOB_ID"
+echo "数组任务ID: $SLURM_ARRAY_TASK_ID"
+echo "节点: $SLURM_NODELIST"
+echo "开始时间: $(date)"
+echo "===================="
+
+# ============ 环境初始化 ============
+# 加载 bashrc
+if [ -f ~/.bashrc ]; then
+ source ~/.bashrc
+fi
+
+# 初始化 Conda
+if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
+ source ~/anaconda3/etc/profile.d/conda.sh
+elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
+ source /opt/anaconda3/etc/profile.d/conda.sh
+fi
+
+# 激活 Zeo++ 环境
+conda activate {self.config.conda_env}
+
+echo ""
+echo "===== 环境检查 ====="
+echo "Conda环境: $CONDA_DEFAULT_ENV"
+echo "Python路径: $(which python)"
+echo "===================="
+echo ""
+
+# ============ 读取任务信息 ============
+TASKS_FILE="{tasks_file}"
+TASK_ID=$SLURM_ARRAY_TASK_ID
+
+# 使用 Python 解析任务
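+# tasks.json is a JSON array; each entry must provide at least the keys
+# "work_dir" and "yaml_name" (other keys such as "cif_path" are not used by this script)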
+TASK_INFO=$(python3 -c "
+import json
+with open('$TASKS_FILE', 'r') as f:
+ tasks = json.load(f)
+if $TASK_ID < len(tasks):
+ task = tasks[$TASK_ID]
+ print(task['work_dir'])
+ print(task['yaml_name'])
+else:
+ print('ERROR')
+")
+
+WORK_DIR=$(echo "$TASK_INFO" | sed -n '1p')
+YAML_NAME=$(echo "$TASK_INFO" | sed -n '2p')
+
+if [ "$WORK_DIR" == "ERROR" ]; then
+ echo "错误: 任务ID $TASK_ID 超出范围"
+ exit 1
+fi
+
+echo "工作目录: $WORK_DIR"
+echo "配置文件: $YAML_NAME"
+echo ""
+
+# ============ 执行计算 ============
+cd "$WORK_DIR"
+
+echo "开始 Voronoi 分析..."
+# 软链接已在工作目录下,直接使用相对路径
+# 将输出重定向到 log.txt 以便后续提取结果
+python analyze_voronoi_nodes.py *.cif -i "$YAML_NAME" > log.txt 2>&1
+
+EXIT_CODE=$?
+
+# 显示日志内容(用于调试)
+echo ""
+echo "===== 计算日志 ====="
+cat log.txt
+echo "===================="
+
+# ============ 完成 ============
+echo ""
+echo "===== 任务完成 ====="
+echo "结束时间: $(date)"
+echo "退出代码: $EXIT_CODE"
+
+# 写入状态文件
+if [ $EXIT_CODE -eq 0 ]; then
+ echo "SUCCESS" > "{output_dir}/status_$TASK_ID.txt"
+else
+ echo "FAILED" > "{output_dir}/status_$TASK_ID.txt"
+fi
+
+echo "===================="
+exit $EXIT_CODE
+"""
+ return script
+
+ def _monitor_array_job(
+ self,
+ job_id: str,
+ tasks: List[Dict],
+ output_dir: str,
+ desc: str
+ ) -> List[ZeoTaskResult]:
+ """监控作业数组进度"""
+ total = len(tasks)
+
+ self.progress_manager = ProgressManager(total, desc)
+ self.progress_manager.start()
+
+ results = [None] * total
+ completed = set()
+
+ print(f"\n监控作业进度 (每 {self.config.poll_interval} 秒检查一次)...")
+ print("按 Ctrl+C 可中断监控(作业将继续在后台运行)\n")
+
+ try:
+ while len(completed) < total:
+ time.sleep(self.config.poll_interval)
+
+ # 检查状态文件
+ for i in range(total):
+ if i in completed:
+ continue
+
+ status_file = os.path.join(output_dir, f"status_{i}.txt")
+ if os.path.exists(status_file):
+ with open(status_file, 'r') as f:
+ status = f.read().strip()
+
+ task = tasks[i]
+ success = (status == "SUCCESS")
+
+ # 收集输出文件
+ output_files = []
+ if success:
+ work_dir = task['work_dir']
+                            for fname in os.listdir(work_dir):
+                                if fname.endswith(('.cif', '.csv')) and fname != task.get('cif_name'):
+                                    output_files.append(os.path.join(work_dir, fname))
+
+ results[i] = ZeoTaskResult(
+ task_id=i,
+ structure_name=task.get('structure_name', ''),
+ cif_path=task.get('cif_path', ''),
+ success=success,
+ output_files=output_files
+ )
+
+ completed.add(i)
+ self.progress_manager.update(success=success)
+ self.progress_manager.display()
+
+ # 检查作业是否还在运行
+ if not self._is_job_running(job_id) and len(completed) < total:
+ # 作业已结束但有任务未完成
+ print(f"\n⚠️ 作业已结束,但有 {total - len(completed)} 个任务未完成")
+ break
+
+ except KeyboardInterrupt:
+ print("\n\n⚠️ 监控已中断,作业将继续在后台运行")
+ print(f" 可使用 'squeue -j {job_id}' 查看作业状态")
+ print(f" 可使用 'scancel {job_id}' 取消作业")
+
+ self.progress_manager.finish()
+
+ # 填充未完成的任务
+ for i in range(total):
+ if results[i] is None:
+ task = tasks[i]
+ results[i] = ZeoTaskResult(
+ task_id=i,
+ structure_name=task.get('structure_name', ''),
+ cif_path=task.get('cif_path', ''),
+ success=False,
+ error_message="任务未完成或状态未知"
+ )
+
+ return results
+
+ def _is_job_running(self, job_id: str) -> bool:
+ """检查作业是否还在运行"""
+ try:
+ result = subprocess.run(
+ ['squeue', '-j', job_id, '-h'],
+ capture_output=True,
+ text=True,
+ timeout=10
+ )
+ return bool(result.stdout.strip())
+ except Exception:
+ return False
+
+ def print_results_summary(self, results: List[ZeoTaskResult]):
+ """打印结果摘要"""
+        total = len(results)
+        if total == 0:
+            print("\n⚠️ 没有结果可供汇总")
+            return
+        success = sum(1 for r in results if r.success)
+        failed = total - success
+
+ print("\n" + "=" * 60)
+ print("【计算结果摘要】")
+ print("=" * 60)
+ print(f" 总任务数: {total}")
+ print(f" 成功: {success} ({100*success/total:.1f}%)")
+ print(f" 失败: {failed} ({100*failed/total:.1f}%)")
+
+ if failed > 0 and failed <= 10:
+ print("\n 失败的任务:")
+ for r in results:
+ if not r.success:
+ print(f" - {r.structure_name}: {r.error_message}")
+ elif failed > 10:
+ print(f"\n 失败任务过多,请检查日志文件")
+
+ print("=" * 60)
diff --git a/src/core/__init__.py b/src/core/__init__.py
index e69de29..45e21bd 100644
--- a/src/core/__init__.py
+++ b/src/core/__init__.py
@@ -0,0 +1,18 @@
+"""
+核心模块:调度器、执行器和进度管理
+"""
+from .scheduler import ParallelScheduler, ResourceConfig, ExecutionMode as SchedulerMode
+from .executor import TaskExecutor, ExecutorConfig, ExecutionMode, TaskResult, create_executor
+from .progress import ProgressManager
+
+__all__ = [
+ 'ParallelScheduler',
+ 'ResourceConfig',
+ 'SchedulerMode',
+ 'TaskExecutor',
+ 'ExecutorConfig',
+ 'ExecutionMode',
+ 'TaskResult',
+ 'create_executor',
+ 'ProgressManager',
+]
diff --git a/src/core/executor.py b/src/core/executor.py
new file mode 100644
index 0000000..e56d6a2
--- /dev/null
+++ b/src/core/executor.py
@@ -0,0 +1,431 @@
+"""
+任务执行器:支持本地执行和 SLURM 直接提交
+不生成脚本文件,直接在 Python 中管理任务
+"""
+import os
+import subprocess
+import time
+import json
+from typing import List, Callable, Any, Optional, Dict, Tuple
+from multiprocessing import Pool, cpu_count
+from dataclasses import dataclass, field
+from enum import Enum
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import threading
+
+from .progress import ProgressManager
+
+
+class ExecutionMode(Enum):
+ """执行模式"""
+ LOCAL = "local" # 本地多进程
+ SLURM_DIRECT = "slurm" # SLURM 直接提交(不生成脚本)
+
+
+@dataclass
+class ExecutorConfig:
+ """执行器配置"""
+ mode: ExecutionMode = ExecutionMode.LOCAL
+ max_workers: int = 4
+ conda_env: str = "/cluster/home/koko125/anaconda3/envs/screen"
+ partition: str = "cpu"
+ time_limit: str = "7-00:00:00"
+ memory_per_task: str = "4G"
+
+ # SLURM 相关
+ poll_interval: float = 2.0 # 轮询间隔(秒)
+ max_concurrent_jobs: int = 50 # 最大并发作业数
+
+
+@dataclass
+class TaskResult:
+ """任务结果"""
+ task_id: Any
+ success: bool
+ result: Any = None
+ error: str = None
+ duration: float = 0.0
+
+
+class TaskExecutor:
+ """
+ 任务执行器
+
+ 支持两种模式:
+ 1. LOCAL: 本地多进程执行
+ 2. SLURM_DIRECT: 直接提交 SLURM 作业,实时监控进度
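+
+    Example (illustrative; "my_worker" is a placeholder for any picklable function):
+        executor = create_executor(mode="local", max_workers=8)
+        results = executor.run(tasks, worker_func=my_worker, desc="analysis")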
+ """
+
+ def __init__(self, config: ExecutorConfig = None):
+ self.config = config or ExecutorConfig()
+ self.progress_manager = None
+ self._stop_event = threading.Event()
+
+ @staticmethod
+ def detect_environment() -> Dict[str, Any]:
+ """检测运行环境"""
+ env_info = {
+ 'hostname': os.uname().nodename,
+ 'total_cores': cpu_count(),
+ 'has_slurm': False,
+ 'slurm_partitions': [],
+ 'conda_env': os.environ.get('CONDA_PREFIX', ''),
+ }
+
+ # 检测 SLURM
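+        # `sinfo -h -o '%P %a %c %D'` prints: partition name (default marked with '*'),
+        # availability, CPUs per node, node count; only partitions reported as 'up' are kept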
+ try:
+ result = subprocess.run(
+ ['sinfo', '-h', '-o', '%P %a %c %D'],
+ capture_output=True, text=True, timeout=5
+ )
+ if result.returncode == 0:
+ env_info['has_slurm'] = True
+ lines = result.stdout.strip().split('\n')
+ for line in lines:
+ parts = line.split()
+ if len(parts) >= 4:
+ partition = parts[0].rstrip('*')
+ avail = parts[1]
+ if avail == 'up':
+ env_info['slurm_partitions'].append(partition)
+ except Exception:
+ pass
+
+ return env_info
+
+ def run(
+ self,
+ tasks: List[Any],
+ worker_func: Callable,
+ desc: str = "Processing"
+ ) -> List[TaskResult]:
+ """
+ 执行任务
+
+ Args:
+ tasks: 任务列表
+ worker_func: 工作函数,接收单个任务,返回结果
+ desc: 进度条描述
+
+ Returns:
+ TaskResult 列表
+ """
+ if self.config.mode == ExecutionMode.LOCAL:
+ return self._run_local(tasks, worker_func, desc)
+ elif self.config.mode == ExecutionMode.SLURM_DIRECT:
+ return self._run_slurm_direct(tasks, worker_func, desc)
+ else:
+ raise ValueError(f"不支持的执行模式: {self.config.mode}")
+
+ def _run_local(
+ self,
+ tasks: List[Any],
+ worker_func: Callable,
+ desc: str
+ ) -> List[TaskResult]:
+ """本地多进程执行"""
+ total = len(tasks)
+ num_workers = min(self.config.max_workers, total)
+
+ print(f"\n{'='*60}")
+ print(f"本地执行配置:")
+ print(f" 总任务数: {total}")
+ print(f" Worker数: {num_workers}")
+ print(f"{'='*60}\n")
+
+ self.progress_manager = ProgressManager(total, desc)
+ self.progress_manager.start()
+
+ results = []
+
+ if num_workers == 1:
+ # 单进程执行
+ for i, task in enumerate(tasks):
+ start_time = time.time()
+ try:
+ result = worker_func(task)
+ duration = time.time() - start_time
+ results.append(TaskResult(
+ task_id=i,
+ success=True,
+ result=result,
+ duration=duration
+ ))
+ self.progress_manager.update(success=True)
+ except Exception as e:
+ duration = time.time() - start_time
+ results.append(TaskResult(
+ task_id=i,
+ success=False,
+ error=str(e),
+ duration=duration
+ ))
+ self.progress_manager.update(success=False)
+ self.progress_manager.display()
+        else:
+            # 多进程执行:使用 imap(保持顺序)使 task_id 与任务索引一致
+            # 注意:worker_func 应自行捕获异常并在失败时返回 None,否则异常会中断整个批次
+            with Pool(processes=num_workers) as pool:
+                for i, result in enumerate(pool.imap(worker_func, tasks)):
+ if result is not None:
+ results.append(TaskResult(
+ task_id=i,
+ success=True,
+ result=result
+ ))
+ self.progress_manager.update(success=True)
+ else:
+ results.append(TaskResult(
+ task_id=i,
+ success=False,
+ error="Worker returned None"
+ ))
+ self.progress_manager.update(success=False)
+ self.progress_manager.display()
+
+ self.progress_manager.finish()
+ return results
+
+ def _run_slurm_direct(
+ self,
+ tasks: List[Any],
+ worker_func: Callable,
+ desc: str
+ ) -> List[TaskResult]:
+ """
+ SLURM 直接提交模式
+
+ 注意:对于数据库分析这类快速任务,建议使用本地多进程模式
+ SLURM 模式更适合耗时的计算任务(如 Zeo++ 分析)
+
+ 这里回退到本地模式,因为 srun 在登录节点直接调用效率不高
+ """
+ print("\n⚠️ 注意:数据库分析阶段自动使用本地多进程模式")
+ print(" SLURM 模式将在后续耗时计算步骤中使用")
+
+ # 回退到本地模式
+ return self._run_local(tasks, worker_func, desc)
+
+
+class SlurmJobManager:
+ """
+ SLURM 作业管理器
+
+ 用于批量提交和监控 SLURM 作业
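+
+    Example (illustrative; the path and ion choices are placeholders):
+        mgr = SlurmJobManager(ExecutorConfig(partition="cpu"))
+        results = mgr.submit_batch(
+            tasks=[("structures/a.cif", "Li", {"O"})],
+            output_dir="slurm_logs",
+        )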
+ """
+
+ def __init__(self, config: ExecutorConfig):
+ self.config = config
+ self.active_jobs = {} # job_id -> task_info
+
+ def submit_batch(
+ self,
+ tasks: List[Tuple[str, str, set]], # (file_path, target_cation, target_anions)
+ output_dir: str,
+ desc: str = "Processing"
+ ) -> List[TaskResult]:
+ """
+ 批量提交任务到 SLURM
+
+ 使用 sbatch --wrap 直接提交,不生成脚本文件
+ """
+ total = len(tasks)
+ os.makedirs(output_dir, exist_ok=True)
+
+ print(f"\n{'='*60}")
+ print(f"SLURM 批量提交:")
+ print(f" 总任务数: {total}")
+ print(f" 输出目录: {output_dir}")
+ print(f" Conda环境: {self.config.conda_env}")
+ print(f"{'='*60}\n")
+
+ progress = ProgressManager(total, desc)
+ progress.start()
+
+ results = []
+ job_ids = []
+
+ # 提交所有任务
+ for i, task in enumerate(tasks):
+ file_path, target_cation, target_anions = task
+
+ # 构建 Python 命令
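+            # Note: file paths are interpolated directly into a quoted shell command,
+            # so this assumes they contain no spaces or quote characters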
+ anions_str = ','.join(target_anions)
+ python_cmd = (
+ f"python -c \""
+ f"import sys; sys.path.insert(0, '{os.getcwd()}'); "
+ f"from src.analysis.worker import analyze_single_file; "
+ f"result = analyze_single_file(('{file_path}', '{target_cation}', set('{anions_str}'.split(',')))); "
+ f"print('SUCCESS' if result and result.is_valid else 'FAILED')"
+ f"\""
+ )
+
+ # 构建完整的 bash 命令
+ bash_cmd = (
+ f"source {os.path.dirname(self.config.conda_env)}/../../etc/profile.d/conda.sh && "
+ f"conda activate {self.config.conda_env} && "
+ f"{python_cmd}"
+ )
+
+ # 使用 sbatch --wrap 提交
+ sbatch_cmd = [
+ 'sbatch',
+ '--partition', self.config.partition,
+ '--ntasks', '1',
+ '--cpus-per-task', '1',
+ '--mem', self.config.memory_per_task,
+ '--time', '01:00:00',
+ '--output', os.path.join(output_dir, f'task_{i}.out'),
+ '--error', os.path.join(output_dir, f'task_{i}.err'),
+ '--wrap', bash_cmd
+ ]
+
+ try:
+ result = subprocess.run(
+ sbatch_cmd,
+ capture_output=True,
+ text=True
+ )
+
+ if result.returncode == 0:
+ # 提取 job_id
+ job_id = result.stdout.strip().split()[-1]
+ job_ids.append((i, job_id, file_path))
+ self.active_jobs[job_id] = {
+ 'task_index': i,
+ 'file_path': file_path,
+ 'status': 'PENDING'
+ }
+ else:
+ results.append(TaskResult(
+ task_id=i,
+ success=False,
+ error=f"提交失败: {result.stderr}"
+ ))
+ progress.update(success=False)
+ progress.display()
+
+ except Exception as e:
+ results.append(TaskResult(
+ task_id=i,
+ success=False,
+ error=str(e)
+ ))
+ progress.update(success=False)
+ progress.display()
+
+ print(f"\n已提交 {len(job_ids)} 个作业,等待完成...")
+
+ # 监控作业状态
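+        # Note: completion is detected via `sacct`, which assumes SLURM accounting is enabled;
+        # without it no job would ever be reported as finished and this loop would not exit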
+ while self.active_jobs:
+ time.sleep(self.config.poll_interval)
+
+ # 检查作业状态
+ completed_jobs = self._check_job_status()
+
+ for job_id, status in completed_jobs:
+ job_info = self.active_jobs.pop(job_id, None)
+ if job_info:
+ task_idx = job_info['task_index']
+
+ if status == 'COMPLETED':
+ # 检查输出文件
+ out_file = os.path.join(output_dir, f'task_{task_idx}.out')
+ success = False
+ if os.path.exists(out_file):
+ with open(out_file, 'r') as f:
+ content = f.read()
+ success = 'SUCCESS' in content
+
+ results.append(TaskResult(
+ task_id=task_idx,
+ success=success,
+ result=job_info['file_path']
+ ))
+ progress.update(success=success)
+ else:
+ # 作业失败
+ err_file = os.path.join(output_dir, f'task_{task_idx}.err')
+ error_msg = status
+ if os.path.exists(err_file):
+ with open(err_file, 'r') as f:
+ error_msg = f.read()[:500] # 只取前500字符
+
+ results.append(TaskResult(
+ task_id=task_idx,
+ success=False,
+ error=error_msg
+ ))
+ progress.update(success=False)
+
+ progress.display()
+
+ progress.finish()
+ return results
+
+ def _check_job_status(self) -> List[Tuple[str, str]]:
+ """检查作业状态,返回已完成的作业列表"""
+ if not self.active_jobs:
+ return []
+
+ job_ids = list(self.active_jobs.keys())
+
+ try:
+ result = subprocess.run(
+ ['sacct', '-j', ','.join(job_ids), '--format=JobID,State', '--noheader', '--parsable2'],
+ capture_output=True,
+ text=True,
+ timeout=30
+ )
+
+ completed = []
+ if result.returncode == 0:
+ for line in result.stdout.strip().split('\n'):
+ if line:
+ parts = line.split('|')
+ if len(parts) >= 2:
+ job_id = parts[0].split('.')[0] # 去掉 .batch 后缀
+ status = parts[1]
+
+ if job_id in self.active_jobs:
+ if status in ['COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL']:
+ completed.append((job_id, status))
+
+ return completed
+
+ except Exception:
+ return []
+
+
+def create_executor(
+ mode: str = "local",
+ max_workers: int = None,
+ conda_env: str = None,
+ **kwargs
+) -> TaskExecutor:
+ """
+ 创建任务执行器的便捷函数
+
+ Args:
+ mode: "local" 或 "slurm"
+ max_workers: 最大工作进程数
+ conda_env: Conda 环境路径
+ **kwargs: 其他配置参数
+ """
+ env = TaskExecutor.detect_environment()
+
+ if max_workers is None:
+ max_workers = min(env['total_cores'], 32)
+
+ if conda_env is None:
+ conda_env = env.get('conda_env') or "/cluster/home/koko125/anaconda3/envs/screen"
+
+ exec_mode = ExecutionMode.SLURM_DIRECT if mode.lower() == "slurm" else ExecutionMode.LOCAL
+
+ config = ExecutorConfig(
+ mode=exec_mode,
+ max_workers=max_workers,
+ conda_env=conda_env,
+ **kwargs
+ )
+
+ return TaskExecutor(config)
diff --git a/src/preprocessing/processor.py b/src/preprocessing/processor.py
new file mode 100644
index 0000000..06c6709
--- /dev/null
+++ b/src/preprocessing/processor.py
@@ -0,0 +1,562 @@
+"""
+结构预处理器:扩胞和添加化合价
+"""
+import os
+import re
+import yaml
+from typing import List, Dict, Optional, Tuple
+from dataclasses import dataclass, field
+from pymatgen.core.structure import Structure
+from pymatgen.core.periodic_table import Specie
+from pymatgen.core import Lattice, Species, PeriodicSite
+from collections import defaultdict
+from fractions import Fraction
+from functools import reduce
+import math
+import random
+import spglib
+import numpy as np
+
+
+@dataclass
+class ProcessingResult:
+ """处理结果"""
+ input_file: str
+ output_files: List[str] = field(default_factory=list)
+ success: bool = False
+ needs_expansion: bool = False
+ expansion_factor: int = 1
+ error_message: str = ""
+
+
+class StructureProcessor:
+ """结构预处理器"""
+
+ # 默认化合价配置
+ DEFAULT_VALENCE_PATH = os.path.join(
+ os.path.dirname(__file__), '..', '..', 'tool', 'valence_states.yaml'
+ )
+
+ def __init__(
+ self,
+ valence_yaml_path: str = None,
+ calculate_type: str = 'low',
+ max_expansion_factor: int = 64,
+ keep_number: int = 3,
+ target_cation: str = "Li"
+ ):
+ """
+ 初始化处理器
+
+ Args:
+ valence_yaml_path: 化合价配置文件路径
+ calculate_type: 扩胞计算精度 ('high', 'normal', 'low', 'very_low')
+ max_expansion_factor: 最大扩胞因子
+ keep_number: 保留的扩胞结构数量
+ target_cation: 目标阳离子
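+
+        Example (illustrative; file names are placeholders):
+            proc = StructureProcessor(calculate_type='low', target_cation='Li')
+            result = proc.process_file('raw/LiCoO2.cif', 'processed', needs_expansion=True)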
+ """
+ self.valence_yaml_path = valence_yaml_path or self.DEFAULT_VALENCE_PATH
+ self.calculate_type = calculate_type
+ self.max_expansion_factor = max_expansion_factor
+ self.keep_number = keep_number
+ self.target_cation = target_cation
+        self.explicit_element = [target_cation, f"{target_cation}+"]
+
+ # 加载化合价配置
+ self.valences = self._load_valences()
+
+ def _load_valences(self) -> Dict[str, int]:
+ """加载化合价配置"""
+ if os.path.exists(self.valence_yaml_path):
+ with open(self.valence_yaml_path, 'r') as f:
+ return yaml.safe_load(f)
+ return {}
+
+ def process_file(
+ self,
+ input_path: str,
+ output_dir: str,
+ needs_expansion: bool = False
+ ) -> ProcessingResult:
+ """
+ 处理单个CIF文件
+
+ Args:
+ input_path: 输入文件路径
+ output_dir: 输出目录
+ needs_expansion: 是否需要扩胞
+
+ Returns:
+ ProcessingResult: 处理结果
+ """
+ result = ProcessingResult(input_file=input_path)
+
+ try:
+ # 读取结构
+ structure = Structure.from_file(input_path)
+ base_name = os.path.splitext(os.path.basename(input_path))[0]
+
+ # 检查是否需要扩胞
+ occupation_list = self._process_cif_file(structure)
+
+ if occupation_list and needs_expansion:
+ # 需要扩胞处理
+ result.needs_expansion = True
+ output_files = self._expand_and_save(
+ structure, occupation_list, base_name, output_dir
+ )
+ result.output_files = output_files
+ result.expansion_factor = occupation_list[0].get('denominator', 1) if occupation_list else 1
+ else:
+ # 不需要扩胞,直接添加化合价
+ output_path = os.path.join(output_dir, f"{base_name}.cif")
+ self._add_oxidation_states(structure)
+ structure.to(filename=output_path)
+ result.output_files = [output_path]
+
+ result.success = True
+
+ except Exception as e:
+ result.success = False
+ result.error_message = str(e)
+
+ return result
+
+ def _process_cif_file(self, structure: Structure) -> List[Dict]:
+ """
+ 统计结构中各原子的occupation情况
+ """
+ occupation_dict = defaultdict(list)
+ split_dict = {}
+
+ for i, site in enumerate(structure):
+ occu = self._get_occu(site.species_string)
+
+ if occu != 1.0:
+                if site.species.chemical_system not in self.explicit_element:
+ occupation_dict[occu].append(i + 1)
+
+ # 提取元素名称列表
+ elements = []
+ if ':' in site.species_string:
+ parts = site.species_string.split(',')
+ for part in parts:
+ element_with_valence = part.strip().split(':')[0].strip()
+ element_match = re.match(r'([A-Z][a-z]?)', element_with_valence)
+ if element_match:
+ elements.append(element_match.group(1))
+ else:
+ element_match = re.match(r'([A-Z][a-z]?)', site.species_string)
+ if element_match:
+ elements = [element_match.group(1)]
+
+ split_dict[occu] = elements
+
+ # 转换为列表格式
+ occupation_list = [
+ {
+ "occupation": occu,
+ "atom_serial": serials,
+ "numerator": None,
+ "denominator": None,
+ "split": split_dict.get(occu, [])
+ }
+ for occu, serials in occupation_dict.items()
+ ]
+
+ return occupation_list
+
+ def _get_occu(self, s_str: str) -> float:
+ """从物种字符串获取占据率"""
+ if not s_str.strip():
+ return 1.0
+
+ pattern = r'([A-Za-z0-9+-]+):([0-9.]+)'
+ matches = re.findall(pattern, s_str)
+
+ for species, occu in matches:
+            if species not in self.explicit_element:
+ try:
+ return float(occu)
+ except ValueError:
+ continue
+
+ return 1.0
+
+ def _calculate_expansion_factor(self, occupation_list: List[Dict]) -> Tuple[int, List[Dict]]:
+ """计算扩胞因子"""
+ if not occupation_list:
+ return 1, []
+
+ precision_limits = {
+ 'high': None,
+ 'normal': 100,
+ 'low': 10,
+ 'very_low': 5
+ }
+
+ limit = precision_limits.get(self.calculate_type)
+
+ for entry in occupation_list:
+ occu = entry["occupation"]
+ if limit:
+ fraction = Fraction(occu).limit_denominator(limit)
+ else:
+ fraction = Fraction(occu).limit_denominator()
+
+ entry["numerator"] = fraction.numerator
+ entry["denominator"] = fraction.denominator
+
+ # 计算最小公倍数
+ denominators = [entry["denominator"] for entry in occupation_list]
+ lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
+
+ # 统一分母
+ for entry in occupation_list:
+ denominator = entry["denominator"]
+ entry["numerator"] = entry["numerator"] * (lcm // denominator)
+ entry["denominator"] = lcm
+
+ return lcm, occupation_list
+
+ def _expand_and_save(
+ self,
+ structure: Structure,
+ occupation_list: List[Dict],
+ base_name: str,
+ output_dir: str
+ ) -> List[str]:
+ """扩胞并保存"""
+ lcm, oc_list = self._calculate_expansion_factor(occupation_list)
+
+ if lcm > self.max_expansion_factor:
+ raise ValueError(f"扩胞因子 {lcm} 超过最大限制 {self.max_expansion_factor}")
+
+ # 获取扩胞策略
+ strategies = self._strategy_divide(structure, lcm)
+
+ if not strategies:
+ raise ValueError("无法找到合适的扩胞策略")
+
+ # 生成结构列表
+ st_list = self._generate_structure_list(structure, oc_list)
+
+ output_files = []
+ keep_number = min(self.keep_number, len(strategies))
+
+ for index in range(keep_number):
+ merged = self._merge_structures(st_list, strategies[index])
+
+ # 添加化合价
+ self._add_oxidation_states(merged)
+
+ # 当只保存1个时不加后缀
+ if keep_number == 1:
+ output_filename = f"{base_name}.cif"
+ else:
+ suffix = "x{}y{}z{}".format(
+ strategies[index]["x"],
+ strategies[index]["y"],
+ strategies[index]["z"]
+ )
+ output_filename = f"{base_name}-{suffix}.cif"
+
+ output_path = os.path.join(output_dir, output_filename)
+ merged.to(filename=output_path, fmt="cif")
+ output_files.append(output_path)
+
+ return output_files
+
+ def _add_oxidation_states(self, structure: Structure):
+ """添加化合价"""
+ # 检查是否已有化合价
+ has_oxidation = all(
+ all(isinstance(sp, Specie) for sp in site.species.keys())
+ for site in structure.sites
+ )
+
+ if not has_oxidation and self.valences:
+ structure.add_oxidation_state_by_element(self.valences)
+
+ def _strategy_divide(self, structure: Structure, total: int) -> List[Dict]:
+ """根据晶体类型确定扩胞策略"""
+ try:
+ space_group_info = structure.get_space_group_info()
+ space_group_symbol = space_group_info[0]
+
+ # 获取空间群类型
+ all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)]
+ symbol = all_spacegroup_symbols[0]
+ for symbol_i in all_spacegroup_symbols:
+ if space_group_symbol == symbol_i.international_short:
+ symbol = symbol_i
+ break
+
+ space_type = self._typejudge(symbol.number)
+
+ if space_type == "Cubic":
+ return self._factorize_to_three_factors(total, "xyz")
+ else:
+ return self._factorize_to_three_factors(total)
+        except Exception:
+ return self._factorize_to_three_factors(total)
+
+ def _typejudge(self, number: int) -> str:
+ """判断晶体类型"""
+ if number in [1, 2]:
+ return "Triclinic"
+ elif 3 <= number <= 15:
+ return "Monoclinic"
+ elif 16 <= number <= 74:
+ return "Orthorhombic"
+ elif 75 <= number <= 142:
+ return "Tetragonal"
+ elif 143 <= number <= 167:
+ return "Trigonal"
+ elif 168 <= number <= 194:
+ return "Hexagonal"
+ elif 195 <= number <= 230:
+ return "Cubic"
+ else:
+ return "Unknown"
+
+ def _factorize_to_three_factors(self, n: int, type_sym: str = None) -> List[Dict]:
+ """分解为三个因子"""
+ factors = []
+
+ if type_sym == "xyz":
+ for x in range(1, n + 1):
+ if n % x == 0:
+ remaining_n = n // x
+ for y in range(1, remaining_n + 1):
+ if remaining_n % y == 0 and y <= x:
+ z = remaining_n // y
+ if z <= y:
+ factors.append({'x': x, 'y': y, 'z': z})
+ else:
+ for x in range(1, n + 1):
+ if n % x == 0:
+ remaining_n = n // x
+ for y in range(1, remaining_n + 1):
+ if remaining_n % y == 0:
+ z = remaining_n // y
+ factors.append({'x': x, 'y': y, 'z': z})
+
+ # 排序
+ def sort_key(item):
+ return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x'])
+
+ return sorted(factors, key=sort_key)
+
+ def _generate_structure_list(
+ self,
+ base_structure: Structure,
+ occupation_list: List[Dict]
+ ) -> List[Structure]:
+ """生成结构列表"""
+ if not occupation_list:
+ return [base_structure.copy()]
+
+ lcm = occupation_list[0]["denominator"]
+ structure_list = [base_structure.copy() for _ in range(lcm)]
+
+ for entry in occupation_list:
+ numerator = entry["numerator"]
+ denominator = entry["denominator"]
+ atom_indices = entry["atom_serial"]
+
+ for atom_idx in atom_indices:
+ occupancy_dict = self._mark_atoms_randomly(numerator, denominator)
+ original_site = base_structure.sites[atom_idx - 1]
+ element = self._get_first_non_explicit_element(original_site.species_string)
+
+ for copy_idx, occupy in occupancy_dict.items():
+ structure_list[copy_idx].remove_sites([atom_idx - 1])
+ oxi_state = self._extract_oxi_state(original_site.species_string, element)
+
+ if len(entry["split"]) == 1:
+ if occupy:
+ new_site = PeriodicSite(
+ species=Species(element, oxi_state),
+ coords=original_site.frac_coords,
+ lattice=structure_list[copy_idx].lattice,
+ to_unit_cell=True,
+ label=original_site.label
+ )
+ else:
+ species_dict = {Species(self.target_cation, 1.0): 0.0}
+ new_site = PeriodicSite(
+ species=species_dict,
+ coords=original_site.frac_coords,
+ lattice=structure_list[copy_idx].lattice,
+ to_unit_cell=True,
+ label=original_site.label
+ )
+ else:
+ if occupy:
+ new_site = PeriodicSite(
+ species=Species(element, oxi_state),
+ coords=original_site.frac_coords,
+ lattice=structure_list[copy_idx].lattice,
+ to_unit_cell=True,
+ label=original_site.label
+ )
+ else:
+ new_site = PeriodicSite(
+ species=Species(entry['split'][1], oxi_state),
+ coords=original_site.frac_coords,
+ lattice=structure_list[copy_idx].lattice,
+ to_unit_cell=True,
+ label=original_site.label
+ )
+
+ structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
+
+ return structure_list
+
+ def _mark_atoms_randomly(self, numerator: int, denominator: int) -> Dict[int, int]:
+ """随机标记原子"""
+ if numerator > denominator:
+ raise ValueError(f"numerator ({numerator}) 不能超过 denominator ({denominator})")
+
+ atom_dice = list(range(denominator))
+ selected_atoms = random.sample(atom_dice, numerator)
+
+ return {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice}
+
+ def _get_first_non_explicit_element(self, species_str: str) -> str:
+ """获取第一个非目标元素"""
+ if not species_str.strip():
+ return ""
+
+ species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
+
+ for part in species_parts:
+ element_with_charge = part.split(":")[0].strip()
+ pure_element = ''.join([c for c in element_with_charge if c.isalpha()])
+
+            if pure_element not in self.explicit_element:
+ return pure_element
+
+ return ""
+
+ def _extract_oxi_state(self, species_str: str, element: str) -> int:
+ """提取氧化态"""
+ species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
+
+ for part in species_parts:
+ element_with_charge = part.split(":")[0].strip()
+
+            if element and re.match(rf'^{re.escape(element)}(?![a-z])', element_with_charge):
+ charge_part = element_with_charge[len(element):]
+
+ if not any(c.isdigit() for c in charge_part):
+ if "+" in charge_part:
+ return 1
+ elif "-" in charge_part:
+ return -1
+ else:
+ return 0
+
+ sign = 1
+ if "-" in charge_part:
+ sign = -1
+
+ digits = ""
+ for c in charge_part:
+ if c.isdigit():
+ digits += c
+
+ if digits:
+ return sign * int(digits)
+
+ return 0
+
+ def _merge_structures(self, structure_list: List[Structure], merge_dict: Dict) -> Structure:
+ """合并结构"""
+ if not structure_list:
+ raise ValueError("结构列表不能为空")
+
+ ref_lattice = structure_list[0].lattice
+
+ total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1)
+ if len(structure_list) != total_merge:
+ raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配")
+
+ a, b, c = ref_lattice.abc
+ alpha, beta, gamma = ref_lattice.angles
+
+ new_a = a * merge_dict.get("x", 1)
+ new_b = b * merge_dict.get("y", 1)
+ new_c = c * merge_dict.get("z", 1)
+ new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma)
+
+ all_sites = []
+ for i, structure in enumerate(structure_list):
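+            # Map the linear copy index i to (x, y, z) cell offsets in row-major order,
+            # e.g. for x=2, y=1, z=2: i=0->(0,0,0), i=1->(0,0,1), i=2->(1,0,0), i=3->(1,0,1)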
+ x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1)
+ y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1)
+ z_offset = i % merge_dict.get("z", 1)
+
+ for site in structure:
+ coords = site.frac_coords.copy()
+ coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1)
+ coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1)
+ coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1)
+ all_sites.append({"species": site.species, "coords": coords})
+
+ return Structure(
+ new_lattice,
+ [site["species"] for site in all_sites],
+ [site["coords"] for site in all_sites]
+ )
+
+
+def process_batch(
+ input_files: List[str],
+ output_dir: str,
+ needs_expansion_flags: List[bool] = None,
+ valence_yaml_path: str = None,
+ calculate_type: str = 'low',
+ target_cation: str = "Li",
+ show_progress: bool = True
+) -> List[ProcessingResult]:
+ """
+ 批量处理CIF文件
+
+ Args:
+ input_files: 输入文件列表
+ output_dir: 输出目录
+ needs_expansion_flags: 是否需要扩胞的标记列表
+ valence_yaml_path: 化合价配置文件路径
+ calculate_type: 扩胞计算精度
+ target_cation: 目标阳离子
+ show_progress: 是否显示进度
+
+ Returns:
+ 处理结果列表
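+
+    Example (illustrative; paths are placeholders):
+        results = process_batch(["raw/a.cif", "raw/b.cif"], "processed",
+                                needs_expansion_flags=[True, False])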
+ """
+ os.makedirs(output_dir, exist_ok=True)
+
+ processor = StructureProcessor(
+ valence_yaml_path=valence_yaml_path,
+ calculate_type=calculate_type,
+ target_cation=target_cation
+ )
+
+ if needs_expansion_flags is None:
+ needs_expansion_flags = [False] * len(input_files)
+
+ results = []
+ total = len(input_files)
+
+ for i, (input_file, needs_exp) in enumerate(zip(input_files, needs_expansion_flags)):
+ if show_progress:
+ print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
+
+ result = processor.process_file(input_file, output_dir, needs_exp)
+ results.append(result)
+
+ if show_progress:
+ print()
+
+ return results
diff --git a/tool/Li/Br+O.yaml b/tool/Li/Br+O.yaml
new file mode 100644
index 0000000..02f3d68
--- /dev/null
+++ b/tool/Li/Br+O.yaml
@@ -0,0 +1,5 @@
+SPECIE: Li+
+ANION: Br
+PERCO_R: 0.45
+NEIGHBOR: 1.8
+LONG: 2.2
diff --git a/tool/Li/Cl+O.yaml b/tool/Li/Cl+O.yaml
new file mode 100644
index 0000000..a7cb541
--- /dev/null
+++ b/tool/Li/Cl+O.yaml
@@ -0,0 +1,5 @@
+SPECIE: Li+
+ANION: Cl
+PERCO_R: 0.45
+NEIGHBOR: 1.8
+LONG: 2.2