From f78298e80312f0329cc866ba64c06d4b74ada6b5 Mon Sep 17 00:00:00 2001 From: koko Date: Tue, 16 Dec 2025 11:36:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E9=98=B6=E6=AE=B5=E9=AB=98=E7=AD=9B?= =?UTF-8?q?=E5=88=B6=E4=BD=9C=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 21 + .idea/.gitignore | 8 - .idea/Screen.iml | 8 - .idea/inspectionProfiles/Project_Default.xml | 23 - .../inspectionProfiles/profiles_settings.xml | 6 - .idea/modules.xml | 8 - .idea/vcs.xml | 6 - main.py | 627 +++++++++++++++--- readme.md | 349 +++++++++- src/__init__.py | 12 + src/analysis/database_analyzer.py | 191 ++++-- src/analysis/report_generator.py | 4 +- src/analysis/structure_inspector.py | 40 -- src/computation/__init__.py | 15 + src/computation/result_processor.py | 426 ++++++++++++ src/computation/workspace_manager.py | 288 ++++++++ src/computation/zeo_executor.py | 446 +++++++++++++ src/core/__init__.py | 18 + src/core/executor.py | 431 ++++++++++++ src/preprocessing/processor.py | 562 ++++++++++++++++ tool/Li/Br+O.yaml | 5 + tool/Li/Cl+O.yaml | 5 + 22 files changed, 3276 insertions(+), 223 deletions(-) create mode 100644 .gitignore delete mode 100644 .idea/.gitignore delete mode 100644 .idea/Screen.iml delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/vcs.xml create mode 100644 src/computation/__init__.py create mode 100644 src/computation/result_processor.py create mode 100644 src/computation/workspace_manager.py create mode 100644 src/computation/zeo_executor.py create mode 100644 src/core/executor.py create mode 100644 src/preprocessing/processor.py create mode 100644 tool/Li/Br+O.yaml create mode 100644 tool/Li/Cl+O.yaml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5f987aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# --- Python 通用忽略 --- +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +env/ +venv/ +.env +.venv/ + +# --- VS Code 配置 (可选,建议忽略) --- +.vscode/ + +# --- JetBrains (你之前的配置) --- +.idea/ +/shelf/ +/workspace.xml +/httpRequests/ +/dataSources/ +/dataSources.local.xml \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 35410ca..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# 默认忽略的文件 -/shelf/ -/workspace.xml -# 基于编辑器的 HTTP 客户端请求 -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/.idea/Screen.iml b/.idea/Screen.iml deleted file mode 100644 index 909438d..0000000 --- a/.idea/Screen.iml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index bebfaed..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,23 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 287dd8e..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml 
b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/main.py b/main.py index 4fce51d..9379b05 100644 --- a/main.py +++ b/main.py @@ -1,45 +1,188 @@ """ -高通量筛选与扩胞项目 - 主入口(支持并行) +高通量筛选与扩胞项目 - 主入口(支持断点续做) """ import os import sys +import json sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) from src.analysis.database_analyzer import DatabaseAnalyzer from src.analysis.report_generator import ReportGenerator -from src.core.scheduler import ParallelScheduler +from src.core.executor import TaskExecutor +from src.preprocessing.processor import StructureProcessor +from src.computation.workspace_manager import WorkspaceManager +from src.computation.zeo_executor import ZeoExecutor, ZeoConfig +from src.computation.result_processor import ResultProcessor, FilterCriteria def print_banner(): print(""" ╔═══════════════════════════════════════════════════════════════════╗ -║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.0 ║ -║ 支持高性能并行计算 ║ +║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.2 ║ +║ 支持断点续做与高性能并行计算 ║ ╚═══════════════════════════════════════════════════════════════════╝ """) def detect_and_show_environment(): """检测并显示环境信息""" - env = ParallelScheduler.detect_environment() + env = TaskExecutor.detect_environment() print("【运行环境检测】") print(f" 主机名: {env['hostname']}") print(f" 本地CPU核数: {env['total_cores']}") print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}") - + if env['has_slurm'] and env['slurm_partitions']: - print(f" 可用分区:") - for p in env['slurm_partitions']: - print(f" - {p['name']}: {p['nodes']}节点, {p['cores_per_node']}核/节点") + print(f" 可用分区: {', '.join(env['slurm_partitions'])}") + + if env['conda_env']: + print(f" 当前Conda: {env['conda_env']}") return env -def get_user_input(env: dict): - """获取用户输入""" +def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict: + """ + 检测工作流程状态,确定可以从哪一步继续 + + Returns: + dict: { + 'has_processed_data': bool, # 是否有扩胞处理后的数据 + 'has_zeo_results': bool, # 是否有 Zeo++ 计算结果 + 'total_structures': int, # 总结构数 + 'structures_with_log': int, # 有 log.txt 的结构数 + 'workspace_info': object, # 工作区信息 + 'ws_manager': object # 工作区管理器 + } + """ + status = { + 'has_processed_data': False, + 'has_zeo_results': False, + 'total_structures': 0, + 'structures_with_log': 0, + 'workspace_info': None, + 'ws_manager': None + } + + # 创建工作区管理器 + ws_manager = WorkspaceManager( + workspace_path=workspace_path, + tool_dir="tool", + target_cation=target_cation + ) + status['ws_manager'] = ws_manager + + # 检查是否有处理后的数据 + data_dir = os.path.join(workspace_path, "data") + if os.path.exists(data_dir): + existing = ws_manager.check_existing_workspace() + if existing and existing.total_structures > 0: + status['has_processed_data'] = True + status['total_structures'] = existing.total_structures + status['workspace_info'] = existing + + # 检查有多少结构有 log.txt(即已完成 Zeo++ 计算) + log_count = 0 + for anion_key in os.listdir(data_dir): + anion_dir = os.path.join(data_dir, anion_key) + if not os.path.isdir(anion_dir): + continue + for struct_name in os.listdir(anion_dir): + struct_dir = os.path.join(anion_dir, struct_name) + if os.path.isdir(struct_dir): + log_path = os.path.join(struct_dir, "log.txt") + if os.path.exists(log_path): + log_count += 1 + + status['structures_with_log'] = log_count + + # 如果大部分结构都有 log.txt,认为 Zeo++ 计算已完成 + if log_count > 0 and log_count >= status['total_structures'] * 0.5: + status['has_zeo_results'] = True + + return status + +def 
print_workflow_status(status: dict): + """打印工作流程状态""" + print("\n" + "─" * 50) + print("【工作流程状态检测】") + print("─" * 50) + + if not status['has_processed_data']: + print(" Step 1 (扩胞+化合价): ❌ 未完成") + print(" Step 2-4 (Zeo++ 计算): ❌ 未完成") + print(" Step 5 (结果筛选): ❌ 未完成") + else: + print(f" Step 1 (扩胞+化合价): ✅ 已完成 ({status['total_structures']} 个结构)") + + if status['has_zeo_results']: + print(f" Step 2-4 (Zeo++ 计算): ✅ 已完成 ({status['structures_with_log']}/{status['total_structures']} 有日志)") + print(" Step 5 (结果筛选): ⏳ 可执行") + else: + print(f" Step 2-4 (Zeo++ 计算): ⏳ 可执行 ({status['structures_with_log']}/{status['total_structures']} 有日志)") + print(" Step 5 (结果筛选): ❌ 需先完成 Zeo++ 计算") + + print("─" * 50) + + +def get_user_choice(status: dict) -> str: + """ + 根据工作流程状态获取用户选择 + + Returns: + 'step1': 从头开始(数据库分析 + 扩胞) + 'step2': 从 Zeo++ 计算开始 + 'step5': 直接进行结果筛选 + 'exit': 退出 + """ + print("\n请选择操作:") + + options = [] + + if status['has_zeo_results']: + options.append(('5', '直接进行结果筛选 (Step 5)')) + options.append(('2', '重新运行 Zeo++ 计算 (Step 2-4)')) + options.append(('1', '从头开始 (数据库分析 + 扩胞)')) + elif status['has_processed_data']: + options.append(('2', '运行 Zeo++ 计算 (Step 2-4)')) + options.append(('1', '从头开始 (数据库分析 + 扩胞)')) + else: + options.append(('1', '从头开始 (数据库分析 + 扩胞)')) + + options.append(('q', '退出')) + + for key, desc in options: + print(f" {key}. {desc}") + + choice = input("\n请选择 [默认: " + options[0][0] + "]: ").strip().lower() + + if not choice: + choice = options[0][0] + + if choice == 'q': + return 'exit' + elif choice == '5': + return 'step5' + elif choice == '2': + return 'step2' + else: + return 'step1' + + +def run_step1_database_analysis(env: dict, cation: str) -> dict: + """ + Step 1: 数据库分析与扩胞处理 + + Returns: + 处理参数字典,如果用户取消则返回 None + """ + print("\n" + "═" * 60) + print("【Step 1: 数据库分析与扩胞处理】") + print("═" * 60) + # 数据库路径 while True: db_path = input("\n📂 请输入数据库路径: ").strip() @@ -48,9 +191,8 @@ def get_user_input(env: dict): print(f"❌ 路径不存在: {db_path}") # 检测当前Conda环境路径 - conda_env_path = env.get('conda_prefix', '') - if not conda_env_path: - conda_env_path = os.path.expanduser("~/anaconda3/envs/screen") + default_conda = "/cluster/home/koko125/anaconda3/envs/screen" + conda_env_path = env.get('conda_env', '') or default_conda print(f"\n检测到Conda环境: {conda_env_path}") custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip() @@ -59,9 +201,6 @@ def get_user_input(env: dict): elif custom_env and custom_env.lower() != 'y': conda_env_path = custom_env - # 目标阳离子 - cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li" - # 目标阴离子 anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip() anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'} @@ -79,43 +218,18 @@ def get_user_input(env: dict): print("【并行计算配置】") default_cores = min(env['total_cores'], 32) - cores_input = input(f"💻 最大可用核数 [默认: {default_cores}]: ").strip() - max_cores = int(cores_input) if cores_input.isdigit() else default_cores + cores_input = input(f"💻 最大可用核数/Worker数 [默认: {default_cores}]: ").strip() + max_workers = int(cores_input) if cores_input.isdigit() else default_cores - print("\n任务复杂度 (影响每个Worker分配的核数):") - print(" 1. 低 (1核/Worker) - 简单IO操作") - print(" 2. 中 (2核/Worker) - 结构解析+检查 [默认]") - print(" 3. 高 (4核/Worker) - 复杂计算") - complexity_choice = input("请选择 [1/2/3]: ").strip() - complexity = {'1': 'low', '2': 'medium', '3': 'high', '': 'medium'}.get(complexity_choice, 'medium') - - # 执行模式 - use_slurm = False - if env['has_slurm']: - slurm_choice = input("\n是否使用SLURM提交作业? 
[y/N]: ").strip().lower() - use_slurm = slurm_choice == 'y' - - return { + params = { 'database_path': db_path, 'target_cation': cation, 'target_anions': anions, 'anion_mode': anion_mode, - 'max_cores': max_cores, - 'task_complexity': complexity, - 'use_slurm': use_slurm + 'max_workers': max_workers, + 'conda_env': conda_env_path, } - -def main(): - """主函数""" - print_banner() - - # 环境检测 - env = detect_and_show_environment() - - # 获取用户输入 - params = get_user_input(env) - print("\n" + "═" * 60) print("开始数据库分析...") print("═" * 60) @@ -126,43 +240,404 @@ def main(): target_cation=params['target_cation'], target_anions=params['target_anions'], anion_mode=params['anion_mode'], - max_cores=params['max_cores'], - task_complexity=params['task_complexity'] + max_cores=params['max_workers'], + task_complexity='medium' ) print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件") - if params['use_slurm']: - # SLURM模式 - output_dir = input("输出目录 [默认: ./slurm_output]: ").strip() or "./slurm_output" - job_id = analyzer.analyze_slurm(output_dir=output_dir) - print(f"\n✅ SLURM作业已提交: {job_id}") - print(f" 使用 'squeue -j {job_id}' 查看状态") - print(f" 结果将保存到: {output_dir}") - else: - # 本地模式 - report = analyzer.analyze(show_progress=True) + # 执行分析 + report = analyzer.analyze(show_progress=True) - # 打印报告 - ReportGenerator.print_report(report, detailed=True) + # 打印报告 + ReportGenerator.print_report(report, detailed=True) - # 保存选项 - save_choice = input("\n是否保存报告? [y/N]: ").strip().lower() - if save_choice == 'y': - output_path = input("报告路径 [默认: analysis_report.json]: ").strip() - output_path = output_path or "analysis_report.json" - report.save(output_path) - print(f"✅ 报告已保存到: {output_path}") + # 保存选项 + save_choice = input("\n是否保存报告? [y/N]: ").strip().lower() + if save_choice == 'y': + output_path = input("报告路径 [默认: analysis_report.json]: ").strip() + output_path = output_path or "analysis_report.json" + report.save(output_path) + print(f"✅ 报告已保存到: {output_path}") - # CSV导出 - csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower() - if csv_choice == 'y': - csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip() - csv_path = csv_path or "analysis_details.csv" - ReportGenerator.export_to_csv(report, csv_path) + # CSV导出 + csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower() + if csv_choice == 'y': + csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip() + csv_path = csv_path or "analysis_details.csv" + ReportGenerator.export_to_csv(report, csv_path) - print("\n✅ 分析完成!") + # 生成最终数据库 + process_choice = input("\n是否生成最终可用的数据库(扩胞+添加化合价)? [Y/n]: ").strip().lower() + if process_choice == 'n': + print("\n已跳过扩胞处理,可稍后继续") + return params + + # 输出目录设置 + print("\n输出目录设置:") + flat_dir = input(" 原始格式输出目录 [默认: workspace/processed]: ").strip() + flat_dir = flat_dir or "workspace/processed" + + analysis_dir = input(" 分析格式输出目录 [默认: workspace/data]: ").strip() + analysis_dir = analysis_dir or "workspace/data" + + # 扩胞保存数量 + keep_input = input("\n扩胞结构保存数量 [默认: 1]: ").strip() + keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1 + + # 扩胞精度选择 + print("\n扩胞计算精度:") + print(" 1. 高精度 (精确分数)") + print(" 2. 普通精度 (分母≤100)") + print(" 3. 低精度 (分母≤10) [默认]") + print(" 4. 
极低精度 (分母≤5)") + precision_choice = input("请选择 [1/2/3/4]: ").strip() + calculate_type = { + '1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low' + }.get(precision_choice, 'low') + + # 获取可处理文件 + processable = report.get_processable_files(include_needs_expansion=True) + + if not processable: + print("⚠️ 没有可处理的文件") + return params + + print(f"\n发现 {len(processable)} 个可处理的文件") + + # 准备文件列表和扩胞标记 + input_files = [info.file_path for info in processable] + needs_expansion_flags = [info.needs_expansion for info in processable] + anion_types_list = [info.anion_types for info in processable] + + direct_count = sum(1 for f in needs_expansion_flags if not f) + expansion_count = sum(1 for f in needs_expansion_flags if f) + print(f" - 可直接处理: {direct_count}") + print(f" - 需要扩胞: {expansion_count}") + print(f" - 扩胞保存数: {keep_number}") + + confirm = input("\n确认开始处理? [Y/n]: ").strip().lower() + if confirm == 'n': + print("\n已取消处理") + return params + + print("\n" + "═" * 60) + print("开始处理结构文件...") + print("═" * 60) + + # 创建处理器 + processor = StructureProcessor( + calculate_type=calculate_type, + keep_number=keep_number, + target_cation=params['target_cation'] + ) + + # 创建输出目录 + os.makedirs(flat_dir, exist_ok=True) + os.makedirs(analysis_dir, exist_ok=True) + + results = [] + total = len(input_files) + + import shutil + for i, (input_file, needs_exp, anion_types) in enumerate( + zip(input_files, needs_expansion_flags, anion_types_list) + ): + print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="") + + # 处理文件到原始格式目录 + result = processor.process_file(input_file, flat_dir, needs_exp) + results.append(result) + + if result.success: + # 同时保存到分析格式目录 + # 按阴离子类型创建子目录 + anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other' + anion_dir = os.path.join(analysis_dir, anion_key) + os.makedirs(anion_dir, exist_ok=True) + + # 获取基础文件名 + base_name = os.path.splitext(os.path.basename(input_file))[0] + + # 创建以文件名命名的子目录 + file_dir = os.path.join(anion_dir, base_name) + os.makedirs(file_dir, exist_ok=True) + + # 复制生成的文件到分析格式目录 + for output_file in result.output_files: + dst_path = os.path.join(file_dir, os.path.basename(output_file)) + shutil.copy2(output_file, dst_path) + + print() + + # 统计结果 + success_count = sum(1 for r in results if r.success) + fail_count = sum(1 for r in results if not r.success) + total_output = sum(len(r.output_files) for r in results if r.success) + + print("\n" + "-" * 60) + print("【处理结果统计】") + print("-" * 60) + print(f" 成功处理: {success_count}") + print(f" 处理失败: {fail_count}") + print(f" 生成文件: {total_output}") + print(f"\n 原始格式目录: {flat_dir}") + print(f" 分析格式目录: {analysis_dir}") + print(f" └── 结构: data/阴离子类型/文件名/文件名.cif") + + # 显示失败的文件 + if fail_count > 0: + print("\n失败的文件:") + for r in results: + if not r.success: + print(f" - {os.path.basename(r.input_file)}: {r.error_message}") + + print("\n✅ Step 1 完成!") + return params + + +def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info = None): + """ + Step 2-4: Zeo++ Voronoi 分析 + + Args: + params: 参数字典,包含 target_cation 等 + ws_manager: 工作区管理器(可选,如果已有则直接使用) + workspace_info: 工作区信息(可选,如果已有则直接使用) + """ + + print("\n" + "═" * 60) + print("【Step 2-4: Zeo++ Voronoi 分析】") + print("═" * 60) + + # 如果没有传入 ws_manager,则创建新的 + if ws_manager is None: + # 工作区路径 + workspace_path = input("\n工作区路径 [默认: workspace]: ").strip() or "workspace" + tool_dir = input("工具目录路径 [默认: tool]: ").strip() or "tool" + + # 创建工作区管理器 + ws_manager = WorkspaceManager( + workspace_path=workspace_path, + 
tool_dir=tool_dir, + target_cation=params.get('target_cation', 'Li') + ) + + # 如果没有传入 workspace_info,则检查现有工作区 + if workspace_info is None: + existing = ws_manager.check_existing_workspace() + + if existing and existing.total_structures > 0: + ws_manager.print_workspace_summary(existing) + workspace_info = existing + else: + print("⚠️ 工作区数据目录不存在或为空") + print(f" 请先运行 Step 1 生成处理后的数据到: {ws_manager.data_dir}") + return + + # 检查是否需要创建软链接 + if workspace_info.linked_structures < workspace_info.total_structures: + print("\n正在创建配置文件软链接...") + workspace_info = ws_manager.setup_workspace(force_relink=False) + + # 获取计算任务 + tasks = ws_manager.get_computation_tasks(workspace_info) + + if not tasks: + print("⚠️ 没有找到可计算的任务") + return + + print(f"\n发现 {len(tasks)} 个计算任务") + + # Zeo++ 环境配置 + print("\n" + "-" * 50) + print("【Zeo++ 计算配置】") + + default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo" + zeo_env = input(f"Zeo++ Conda环境 [默认: {default_zeo_env}]: ").strip() + zeo_env = zeo_env or default_zeo_env + + # SLURM 配置 + partition = input("SLURM分区 [默认: cpu]: ").strip() or "cpu" + + max_concurrent_input = input("最大并发任务数 [默认: 50]: ").strip() + max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50 + + time_limit = input("单任务时间限制 [默认: 2:00:00]: ").strip() or "2:00:00" + + # 创建配置 + zeo_config = ZeoConfig( + conda_env=zeo_env, + partition=partition, + max_concurrent=max_concurrent, + time_limit=time_limit + ) + + # 确认执行 + print("\n" + "-" * 50) + print("【计算任务确认】") + print(f" 总任务数: {len(tasks)}") + print(f" Conda环境: {zeo_config.conda_env}") + print(f" SLURM分区: {zeo_config.partition}") + print(f" 最大并发: {zeo_config.max_concurrent}") + print(f" 时间限制: {zeo_config.time_limit}") + + confirm = input("\n确认提交计算任务? [Y/n]: ").strip().lower() + if confirm == 'n': + print("已取消") + return + + # 创建执行器并运行 + executor = ZeoExecutor(zeo_config) + + log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs") + results = executor.run_batch(tasks, output_dir=log_dir) + + # 打印结果摘要 + executor.print_results_summary(results) + + # 保存结果 + results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json") + results_data = [ + { + 'task_id': r.task_id, + 'structure_name': r.structure_name, + 'cif_path': r.cif_path, + 'success': r.success, + 'output_files': r.output_files, + 'error_message': r.error_message + } + for r in results + ] + with open(results_file, 'w') as f: + json.dump(results_data, f, indent=2) + print(f"\n结果已保存到: {results_file}") + + print("\n✅ Step 2-4 完成!") + + +def run_step5_result_processing(workspace_path: str = "workspace"): + """ + Step 5: 结果处理与筛选 + + Args: + workspace_path: 工作区路径 + """ + print("\n" + "═" * 60) + print("【Step 5: 结果处理与筛选】") + print("═" * 60) + + # 创建结果处理器 + processor = ResultProcessor(workspace_path=workspace_path) + + # 获取筛选条件 + print("\n设置筛选条件:") + print(" (直接回车使用默认值,输入 0 表示不限制)") + + # 最小渗透直径(默认改为 1.0) + perc_input = input(" 最小渗透直径 (Å) [默认: 1.0]: ").strip() + min_percolation = float(perc_input) if perc_input else 1.0 + + # 最小 d 值 + d_input = input(" 最小 d 值 [默认: 2.0]: ").strip() + min_d = float(d_input) if d_input else 2.0 + + # 最大节点长度 + node_input = input(" 最大节点长度 (Å) [默认: 不限制]: ").strip() + max_node = float(node_input) if node_input else float('inf') + + # 创建筛选条件 + criteria = FilterCriteria( + min_percolation_diameter=min_percolation, + min_d_value=min_d, + max_node_length=max_node + ) + + # 执行处理 + results, stats = processor.process_and_filter( + criteria=criteria, + save_csv=True, + copy_passed=True + ) + + # 打印摘要 + processor.print_summary(results, stats) 
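+    # stats 为 process_and_filter 返回的计数字典,包含 total / passed / failed / missing_data 四个键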
+ + # 显示输出位置 + print("\n输出文件位置:") + print(f" 汇总 CSV: {workspace_path}/results/summary.csv") + print(f" 分类 CSV: {workspace_path}/results/阴离子类型/阴离子类型.csv") + print(f" 通过筛选: {workspace_path}/passed/阴离子类型/结构名/") + + print("\n✅ Step 5 完成!") + + +def main(): + """主函数""" + print_banner() + + # 环境检测 + env = detect_and_show_environment() + + # 询问目标阳离子 + cation = input("\n🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li" + + # 检测工作流程状态 + status = detect_workflow_status(workspace_path="workspace", target_cation=cation) + print_workflow_status(status) + + # 获取用户选择 + choice = get_user_choice(status) + + if choice == 'exit': + print("\n👋 再见!") + return + + # 根据选择执行相应步骤 + params = {'target_cation': cation} + + if choice == 'step1': + # 从头开始 + params = run_step1_database_analysis(env, cation) + if params is None: + return + + # 询问是否继续 + continue_choice = input("\n是否继续进行 Zeo++ 计算? [Y/n]: ").strip().lower() + if continue_choice == 'n': + print("\n可稍后运行程序继续 Zeo++ 计算") + return + + # 重新检测状态 + status = detect_workflow_status(workspace_path="workspace", target_cation=cation) + run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info']) + + # 询问是否继续筛选 + continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower() + if continue_choice == 'n': + print("\n可稍后运行程序继续结果筛选") + return + + run_step5_result_processing("workspace") + + elif choice == 'step2': + # 从 Zeo++ 计算开始 + run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info']) + + # 询问是否继续筛选 + continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower() + if continue_choice == 'n': + print("\n可稍后运行程序继续结果筛选") + return + + run_step5_result_processing("workspace") + + elif choice == 'step5': + # 直接进行结果筛选 + run_step5_result_processing("workspace") + + print("\n✅ 全部完成!") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/readme.md b/readme.md index 2f177ff..0c35ce3 100644 --- a/readme.md +++ b/readme.md @@ -1,4 +1,4 @@ -# 高通量筛选与扩胞项目 +# 高通量筛选与扩胞项目 v2.3 ## 环境配置需求 @@ -12,9 +12,29 @@ ### 2. screen 环境 (用于逻辑筛选与数据处理) * **Python**: 3.11.4 * **核心库**: `pymatgen==2024.11.13`, `pandas` (新增,用于处理CSV) +* **路径**: `/cluster/home/koko125/anaconda3/envs/screen` ## 快速开始 +### 方式一:使用新版主程序(推荐) + +```bash +# 激活 screen 环境 +conda activate /cluster/home/koko125/anaconda3/envs/screen + +# 运行主程序 +python main.py +``` + +主程序提供交互式界面,支持: +- 数据库分析与筛选 +- 本地多进程并行 +- SLURM 直接提交(无需生成脚本文件) +- 实时进度条显示 +- 扩胞处理与化合价添加 + +### 方式二:传统方式 + 1. **数据准备**: * 如果数据来源为 **Materials Project (MP)**,请将 CIF 文件放入 `data/input_pre`。 * 如果数据来源为 **ICSD**,请直接将 CIF 文件放入 `data/input`。 @@ -25,6 +45,52 @@ bash main.sh ``` +## 新版功能特性 (v2.1) + +### 执行模式 + +1. **本地多进程模式** (`local`) + - 使用 Python multiprocessing 在本地并行执行 + - 适合小规模任务或测试 + +2. **SLURM 直接提交模式** (`slurm`) + - 直接在 Python 中提交 SLURM 作业 + - 无需生成脚本文件 + - 实时监控作业状态和进度 + - 适合大规模高通量计算 + +### 进度显示 + +``` +分析CIF文件: |████████████████████░░░░░░░░░░░░░░░░░░░░| 500/1000 (50.0%) [0:05:23<0:05:20, 1.6it/s] ✓480 ✗20 +``` + +### 输出格式 + +处理后的文件支持两种输出格式: + +1. **原始格式**(平铺) + ``` + workspace/processed/ + ├── 1514027.cif + ├── 1514072.cif + └── ... + ``` + +2. **分析格式**(按阴离子分类) + ``` + workspace/data/ + ├── O/ + │ ├── 1514027/ + │ │ └── 1514027.cif + │ └── 1514072/ + │ └── 1514072.cif + ├── S/ + │ └── ... + └── Cl+O/ + └── ... + ``` + ## 处理流程详解 ### Stage 1: 预处理与基础筛选 (Step 1) @@ -51,9 +117,12 @@ --- -## 扩胞逻辑 (Step 5 - 待后续执行) +## 扩胞逻辑 (Step 5) -目前扩胞逻辑维持原状,基于筛选后的结构进行处理。 +扩胞处理已集成到新版主程序中,支持: +- 自动计算扩胞因子 +- 可选保存数量(当为1时不加后缀) +- 自动添加化合价信息 ### 算法分解 1. 
**读取结构**: 解析 CIF 文件。 @@ -71,4 +140,276 @@ ### 假设条件 * 只考虑两个原子在同一位置上的共占位情况。 -* 不考虑 Li 原子的共占位情况,对 Li 原子不做处理。 \ No newline at end of file +* 不考虑 Li 原子的共占位情况,对 Li 原子不做处理。 + +## 项目结构 + +``` +screen/ +├── main.py # 主入口(新版) +├── main.sh # 传统脚本入口 +├── readme.md # 本文档 +├── config/ # 配置文件 +│ ├── settings.yaml +│ └── valence_states.yaml +├── src/ # 源代码 +│ ├── analysis/ # 分析模块 +│ │ ├── database_analyzer.py +│ │ ├── report_generator.py +│ │ ├── structure_inspector.py +│ │ └── worker.py +│ ├── core/ # 核心模块 +│ │ ├── executor.py # 任务执行器(新) +│ │ ├── scheduler.py # 调度器 +│ │ └── progress.py # 进度管理 +│ ├── preprocessing/ # 预处理模块 +│ │ ├── processor.py # 结构处理器 +│ │ └── ... +│ └── utils/ # 工具函数 +├── py/ # 传统脚本 +├── tool/ # 工具和配置 +│ ├── analyze_voronoi_nodes.py +│ └── Li/ # 化合价配置 +└── workspace/ # 工作区 + ├── data/ # 分析格式输出 + └── processed/ # 原始格式输出 +``` + +## API 使用示例 + +### 使用执行器 + +```python +from src.core.executor import create_executor, TaskExecutor + +# 创建执行器 +executor = create_executor( + mode="slurm", # 或 "local" + max_workers=32, + conda_env="/cluster/home/koko125/anaconda3/envs/screen" +) + +# 定义任务 +tasks = [ + (file_path, "Li", {"O", "S"}) + for file_path in cif_files +] + +# 执行 +from src.analysis.worker import analyze_single_file +results = executor.run(tasks, analyze_single_file, desc="分析CIF文件") +``` + +### 使用数据库分析器 + +```python +from src.analysis.database_analyzer import DatabaseAnalyzer + +analyzer = DatabaseAnalyzer( + database_path="/path/to/cif/files", + target_cation="Li", + target_anions={"O", "S", "Cl", "Br"}, + anion_mode="all" +) + +report = analyzer.analyze(show_progress=True) +report.save("analysis_report.json") +``` + +### 使用 Zeo++ 执行器 + +```python +from src.computation.workspace_manager import WorkspaceManager +from src.computation.zeo_executor import ZeoExecutor, ZeoConfig + +# 设置工作区 +ws_manager = WorkspaceManager( + workspace_path="workspace", + tool_dir="tool", + target_cation="Li" +) + +# 创建软链接 +workspace_info = ws_manager.setup_workspace() + +# 获取计算任务 +tasks = ws_manager.get_computation_tasks(workspace_info) + +# 配置 Zeo++ 执行器 +config = ZeoConfig( + conda_env="/cluster/home/koko125/anaconda3/envs/zeo", + partition="cpu", + max_concurrent=50, + time_limit="2:00:00" +) + +# 执行计算 +executor = ZeoExecutor(config) +results = executor.run_batch(tasks, output_dir="slurm_logs") +executor.print_results_summary(results) +``` + +## 新版功能特性 (v2.2) + +### Zeo++ Voronoi 分析 + +新增 SLURM 作业数组支持,可高效调度大量 Zeo++ 计算任务: + +1. **自动工作区管理** + - 检测现有工作区数据 + - 自动创建配置文件软链接 + - 按阴离子类型组织目录结构 + +2. **SLURM 作业数组** + - 使用 `--array` 参数批量提交任务 + - 支持最大并发数限制(如 `%50`) + - 自动分批处理超大任务集 + +3. 
**实时进度监控** + - 通过状态文件跟踪任务完成情况 + - 支持 Ctrl+C 中断监控(作业继续运行) + - 自动收集输出文件 + +### 工作流程 + +``` +Step 1: 数据库分析 + ↓ +Step 1.5: 扩胞处理 + 化合价添加 + ↓ +Step 2-4: Zeo++ Voronoi 分析 + ├── 创建软链接 (yaml 配置 + 计算脚本) + ├── 提交 SLURM 作业数组 + └── 监控进度并收集结果 + ↓ +Step 5: 结果处理与筛选 + ├── 从 log.txt 提取关键参数 + ├── 汇总到 CSV 文件 + ├── 应用筛选条件 + └── 复制通过筛选的结构到 passed/ 目录 +``` + +### 筛选条件 + +Step 5 支持以下筛选条件: +- **最小渗透直径** (Percolation Diameter): 默认 1.0 Å +- **最小 d 值** (Minimum of d): 默认 2.0 +- **最大节点长度** (Maximum Node Length): 默认不限制 + +### 日志输出 + +Zeo++ 计算的输出会重定向到每个结构目录下的 `log.txt` 文件: +```bash +python analyze_voronoi_nodes.py *.cif -i O.yaml > log.txt 2>&1 +``` + +日志中包含的关键信息: +- `Percolation diameter (A): X.XX` - 渗透直径 +- `the minium of d\nX.XX` - 最小 d 值 +- `Maximum node length detected: X.XX A` - 最大节点长度 + +### 目录结构 + +``` +workspace/ +├── data/ # 分析格式数据 +│ ├── O/ # 氧化物 +│ │ ├── O.yaml -> tool/Li/O.yaml +│ │ ├── analyze_voronoi_nodes.py -> tool/analyze_voronoi_nodes.py +│ │ ├── 1514027/ +│ │ │ ├── 1514027.cif +│ │ │ ├── 1514027_all_accessed_node.cif # Zeo++ 输出 +│ │ │ ├── 1514027_bond_valence_filtered.cif +│ │ │ └── 1514027_bv_info.csv +│ │ └── ... +│ ├── S/ # 硫化物 +│ └── Cl+O/ # 复合阴离子 +├── processed/ # 原始格式数据 +├── slurm_logs/ # SLURM 日志 +│ ├── tasks.json +│ ├── submit_array.sh +│ ├── task_0.out +│ ├── task_0.err +│ ├── status_0.txt +│ └── ... +└── zeo_results.json # 计算结果汇总 +``` + +### 配置文件说明 + +`tool/Li/O.yaml` 示例: +```yaml +SPECIE: Li+ +ANION: O +PERCO_R: 0.5 +NEIGHBOR: 1.8 +LONG: 2.2 +``` + +参数说明: +- `SPECIE`: 目标扩散离子(带电荷) +- `ANION`: 阴离子类型 +- `PERCO_R`: 渗透半径阈值 +- `NEIGHBOR`: 邻近距离阈值 +- `LONG`: 长节点判定阈值 + +## 新版功能特性 (v2.3) + +### 断点续做功能 + +v2.3 新增智能断点续做功能,支持从任意步骤继续执行: + +1. **自动状态检测** + - 启动时自动检测工作流程状态 + - 检测 `workspace/data/` 是否存在且有结构 → 判断 Step 1 是否完成 + - 检测结构目录下是否有 `log.txt` → 判断 Zeo++ 计算是否完成 + - 如果 50% 以上结构有 log.txt,认为 Zeo++ 计算已完成 + +2. **智能流程跳转** + - 如果已完成 Zeo++ 计算 → 可直接进行筛选 + - 如果已完成扩胞处理 → 可直接进行 Zeo++ 计算 + - 从后往前检测,自动跳过已完成的步骤 + +3. **分步执行与中断** + - 每个大步骤完成后询问是否继续 + - 支持中途退出,下次运行时自动检测进度 + - 三大步骤:扩胞与加化合价 → Zeo++ 计算 → 筛选 + +### 工作流程状态示例 + +``` + 工作流程状态检测 + +检测到现有工作区数据: + - 结构总数: 1234 + - 已完成 Zeo++ 计算: 1200 (97.2%) + - 未完成 Zeo++ 计算: 34 + +当前状态: Zeo++ 计算已完成 + +可选操作: + [1] 直接进行结果筛选 (Step 5) + [2] 重新运行 Zeo++ 计算 (Step 2-4) + [3] 从头开始 (Step 1) + [0] 退出 + +请选择 [1]: +``` + +### 使用场景 + +1. **首次运行** + - 从 Step 1 开始完整执行 + - 每步完成后可选择继续或退出 + +2. **中断后继续** + - 自动检测已完成的步骤 + - 提供从当前进度继续的选项 + +3. **重新筛选** + - 修改筛选条件后 + - 可直接运行 Step 5 而无需重新计算 + +4. **部分重算** + - 如需重新计算部分结构 + - 可选择重新运行 Zeo++ 计算 diff --git a/src/__init__.py b/src/__init__.py index e69de29..3a1c37e 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1,12 @@ +""" +高通量筛选与扩胞项目 - 源代码包 +""" + +from . import analysis +from . import core +from . import preprocessing +from . import computation +from . 
import utils + +__version__ = "2.2.0" +__all__ = ['analysis', 'core', 'preprocessing', 'computation', 'utils'] diff --git a/src/analysis/database_analyzer.py b/src/analysis/database_analyzer.py index 9c0cb5a..7e9e179 100644 --- a/src/analysis/database_analyzer.py +++ b/src/analysis/database_analyzer.py @@ -71,6 +71,151 @@ class DatabaseReport: }) expansion_factor_distribution: Dict[int, int] = field(default_factory=dict) + def to_dict(self) -> dict: + """转换为可序列化的字典""" + from dataclasses import fields as dataclass_fields + + def convert_value(val): + """递归转换值为可序列化类型""" + if isinstance(val, set): + return list(val) + elif isinstance(val, dict): + return {k: convert_value(v) for k, v in val.items()} + elif isinstance(val, list): + return [convert_value(item) for item in val] + elif hasattr(val, '__dataclass_fields__'): + # 处理 dataclass 对象 + return {k: convert_value(v) for k, v in asdict(val).items()} + else: + return val + + result = {} + for f in dataclass_fields(self): + value = getattr(self, f.name) + result[f.name] = convert_value(value) + + return result + + def save(self, path: str): + """保存报告到JSON文件""" + with open(path, 'w', encoding='utf-8') as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) + print(f"✅ 报告已保存到: {path}") + + @classmethod + def load(cls, path: str) -> 'DatabaseReport': + """从JSON文件加载报告""" + with open(path, 'r', encoding='utf-8') as f: + d = json.load(f) + + # 处理 set 类型 + if 'target_anions' in d: + d['target_anions'] = set(d['target_anions']) + # 处理 StructureInfo 列表(简化处理,不恢复完整对象) + if 'all_structures' in d: + d['all_structures'] = [] + + return cls(**d) + + def get_processable_files(self, include_needs_expansion: bool = True) -> List[StructureInfo]: + """ + 获取可处理的文件列表 + + Args: + include_needs_expansion: 是否包含需要扩胞的文件 + + Returns: + 可处理的 StructureInfo 列表 + """ + result = [] + for info in self.all_structures: + if info is None or not info.is_valid: + continue + if not info.contains_target_cation: + continue + if not info.can_process: + continue + if not include_needs_expansion and info.needs_expansion: + continue + result.append(info) + return result + + def copy_processable_files( + self, + output_dir: str, + include_needs_expansion: bool = True, + organize_by_anion: bool = True + ) -> Dict[str, int]: + """ + 将可处理的CIF文件复制到工作区 + + Args: + output_dir: 输出目录(如 workspace/data) + include_needs_expansion: 是否包含需要扩胞的文件 + organize_by_anion: 是否按阴离子类型组织子目录 + + Returns: + 复制统计信息 {类别: 数量} + """ + import shutil + + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + + # 获取可处理文件 + processable = self.get_processable_files(include_needs_expansion) + + stats = { + 'direct': 0, # 可直接处理 + 'needs_expansion': 0, # 需要扩胞 + 'total': 0 + } + + # 按类型创建子目录 + if organize_by_anion: + anion_dirs = {} + + for info in processable: + # 确定目标目录 + if organize_by_anion and info.anion_types: + # 使用主要阴离子作为目录名 + anion_key = '+'.join(sorted(info.anion_types)) + if anion_key not in anion_dirs: + anion_dir = os.path.join(output_dir, anion_key) + os.makedirs(anion_dir, exist_ok=True) + anion_dirs[anion_key] = anion_dir + target_dir = anion_dirs[anion_key] + else: + target_dir = output_dir + + # 进一步按处理类型分类 + if info.needs_expansion: + sub_dir = os.path.join(target_dir, 'needs_expansion') + stats['needs_expansion'] += 1 + else: + sub_dir = os.path.join(target_dir, 'direct') + stats['direct'] += 1 + + os.makedirs(sub_dir, exist_ok=True) + + # 复制文件 + src_path = info.file_path + dst_path = os.path.join(sub_dir, info.file_name) + + try: + shutil.copy2(src_path, dst_path) + stats['total'] += 1 + except 
Exception as e: + print(f"⚠️ 复制失败 {info.file_name}: {e}") + + # 打印统计 + print(f"\n📁 文件已复制到: {output_dir}") + print(f" 可直接处理: {stats['direct']}") + print(f" 需要扩胞: {stats['needs_expansion']}") + print(f" 总计: {stats['total']}") + + return stats + class DatabaseAnalyzer: """数据库分析器 - 支持高性能并行""" @@ -232,6 +377,10 @@ class DatabaseAnalyzer: report.invalid_files += 1 continue # 无效文件不继续统计 + # 关键修复:只有当结构确实含有目标阳离子时才计入统计 + if not info.contains_target_cation: + continue # 不含目标阳离子的文件不继续统计 + report.cation_containing_count += 1 for anion in info.anion_types: @@ -317,45 +466,3 @@ class DatabaseAnalyzer: for anion, count in report.anion_distribution.items(): report.anion_ratios[anion] = \ count / report.cation_containing_count - - def to_dict(self) -> dict: - """转换为可序列化的字典""" - import json - from dataclasses import asdict, fields - - result = {} - for field in fields(self): - value = getattr(self, field.name) - - # 处理 set 类型 - if isinstance(value, set): - result[field.name] = list(value) - # 处理 StructureInfo 列表 - elif field.name == 'all_structures': - result[field.name] = [] # 不保存详细结构信息,太大 - else: - result[field.name] = value - - return result - - def save(self, path: str): - """保存报告到JSON文件""" - import json - with open(path, 'w', encoding='utf-8') as f: - json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) - print(f"✅ 报告已保存到: {path}") - - @classmethod - def load(cls, path: str) -> 'DatabaseReport': - """从JSON文件加载报告""" - import json - with open(path, 'r', encoding='utf-8') as f: - d = json.load(f) - - # 处理 set 类型 - if 'target_anions' in d: - d['target_anions'] = set(d['target_anions']) - if 'all_structures' not in d: - d['all_structures'] = [] - - return cls(**d) \ No newline at end of file diff --git a/src/analysis/report_generator.py b/src/analysis/report_generator.py index 0651409..172f6d1 100644 --- a/src/analysis/report_generator.py +++ b/src/analysis/report_generator.py @@ -145,7 +145,7 @@ class ReportGenerator: info.anion_mode, info.has_oxidation_states, info.has_partial_occupancy, - info.cation_has_partial_occupancy, + info.cation_with_other_cation, # 修复:使用正确的属性名 info.anion_has_partial_occupancy, info.needs_expansion, info.is_binary_compound, @@ -156,4 +156,4 @@ class ReportGenerator: ] writer.writerow(row) - print(f"详细结果已导出到: {output_path}") \ No newline at end of file + print(f"详细结果已导出到: {output_path}") diff --git a/src/analysis/structure_inspector.py b/src/analysis/structure_inspector.py index 4b171f9..fccf918 100644 --- a/src/analysis/structure_inspector.py +++ b/src/analysis/structure_inspector.py @@ -441,43 +441,3 @@ class StructureInspector: return False except: return False - - def _evaluate_processability(self, info: StructureInfo): - """评估可处理性""" - skip_reasons = [] - - if not info.is_valid: - skip_reasons.append("无法解析CIF文件") - - if not info.contains_target_cation: - skip_reasons.append(f"不含{self.target_cation}") - - if info.anion_mode == "none": - skip_reasons.append("不含目标阴离子") - - if info.is_binary_compound: - skip_reasons.append("二元化合物") - - if info.has_radioactive_elements: - skip_reasons.append("含放射性元素") - - # 关键:目标阳离子共占位是不可处理的 - if info.cation_has_partial_occupancy: - skip_reasons.append(f"{self.target_cation}存在共占位") - - # 阴离子共占位通常也不处理 - if info.anion_has_partial_occupancy: - skip_reasons.append("阴离子存在共占位") - - if info.has_water_molecule: - skip_reasons.append("含水分子") - - # 扩胞因子过大 - if info.expansion_info.needs_expansion and not info.expansion_info.can_expand: - skip_reasons.append(info.expansion_info.skip_reason) - - if skip_reasons: - info.can_process = False - 
info.skip_reason = "; ".join(skip_reasons) - else: - info.can_process = True \ No newline at end of file diff --git a/src/computation/__init__.py b/src/computation/__init__.py new file mode 100644 index 0000000..665a4c2 --- /dev/null +++ b/src/computation/__init__.py @@ -0,0 +1,15 @@ +""" +计算模块:Zeo++ Voronoi 分析 +""" +from .workspace_manager import WorkspaceManager +from .zeo_executor import ZeoExecutor, ZeoConfig +from .result_processor import ResultProcessor, FilterCriteria, StructureResult + +__all__ = [ + 'WorkspaceManager', + 'ZeoExecutor', + 'ZeoConfig', + 'ResultProcessor', + 'FilterCriteria', + 'StructureResult' +] diff --git a/src/computation/result_processor.py b/src/computation/result_processor.py new file mode 100644 index 0000000..14cb4cf --- /dev/null +++ b/src/computation/result_processor.py @@ -0,0 +1,426 @@ +""" +Zeo++ 计算结果处理器:提取数据、筛选结构 +""" +import os +import re +import shutil +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, field +import pandas as pd + + +@dataclass +class FilterCriteria: + """筛选条件""" + min_percolation_diameter: float = 1.0 # 最小渗透直径 (Å),默认 1.0 + min_d_value: float = 2.0 # 最小 d 值,默认 2.0 + max_node_length: float = float('inf') # 最大节点长度 (Å) + + +@dataclass +class StructureResult: + """单个结构的计算结果""" + structure_name: str + anion_type: str + work_dir: str + + # 提取的参数 + percolation_diameter: Optional[float] = None + min_d: Optional[float] = None + max_node_length: Optional[float] = None + + # 筛选结果 + passed_filter: bool = False + filter_reason: str = "" + + +class ResultProcessor: + """ + Zeo++ 计算结果处理器 + + 功能: + 1. 从每个结构目录的 log.txt 提取关键参数 + 2. 汇总所有结果到 CSV 文件 + 3. 根据筛选条件筛选结构 + 4. 将通过筛选的结构复制到新文件夹 + """ + + def __init__( + self, + workspace_path: str = "workspace", + data_dir: str = None, + output_dir: str = None + ): + """ + 初始化结果处理器 + + Args: + workspace_path: 工作区根目录 + data_dir: 数据目录(默认 workspace/data) + output_dir: 输出目录(默认 workspace/results) + """ + self.workspace_path = os.path.abspath(workspace_path) + self.data_dir = data_dir or os.path.join(self.workspace_path, "data") + self.output_dir = output_dir or os.path.join(self.workspace_path, "results") + + def extract_from_log(self, log_path: str) -> Tuple[Optional[float], Optional[float], Optional[float]]: + """ + 从 log.txt 中提取三个关键参数 + + Args: + log_path: log.txt 文件路径 + + Returns: + (percolation_diameter, min_d, max_node_length) + """ + if not os.path.exists(log_path): + return None, None, None + + try: + with open(log_path, 'r', encoding='utf-8') as f: + content = f.read() + except Exception: + return None, None, None + + # 正则表达式 - 与 py/extract_data.py 保持一致 + + # 1. Percolation diameter: "# Percolation diameter (A): 1.06" + re_percolation = r"Percolation diameter \(A\):\s*([\d\.]+)" + + # 2. Minimum of d: "the minium of d\n3.862140561244235" + # 注意:这是 Topological_Analysis 库输出的格式 + re_min_d = r"the minium of d\s*\n\s*([\d\.]+)" + + # 3. 
Maximum node length: "# Maximum node length detected: 1.332 A" + re_max_node = r"Maximum node length detected:\s*([\d\.]+)\s*A" + + # 提取数据 + match_perc = re.search(re_percolation, content) + match_d = re.search(re_min_d, content) + match_node = re.search(re_max_node, content) + + val_perc = float(match_perc.group(1)) if match_perc else None + val_d = float(match_d.group(1)) if match_d else None + val_node = float(match_node.group(1)) if match_node else None + + return val_perc, val_d, val_node + + def process_all_structures(self) -> List[StructureResult]: + """ + 处理所有结构,提取计算结果 + + Returns: + StructureResult 列表 + """ + results = [] + + if not os.path.exists(self.data_dir): + print(f"⚠️ 数据目录不存在: {self.data_dir}") + return results + + print("\n正在提取计算结果...") + + # 遍历阴离子目录 + for anion_key in os.listdir(self.data_dir): + anion_dir = os.path.join(self.data_dir, anion_key) + if not os.path.isdir(anion_dir): + continue + + # 遍历结构目录 + for struct_name in os.listdir(anion_dir): + struct_dir = os.path.join(anion_dir, struct_name) + if not os.path.isdir(struct_dir): + continue + + # 查找 log.txt + log_path = os.path.join(struct_dir, "log.txt") + + # 提取参数 + perc, min_d, max_node = self.extract_from_log(log_path) + + result = StructureResult( + structure_name=struct_name, + anion_type=anion_key, + work_dir=struct_dir, + percolation_diameter=perc, + min_d=min_d, + max_node_length=max_node + ) + + results.append(result) + + print(f" 共处理 {len(results)} 个结构") + return results + + def apply_filter( + self, + results: List[StructureResult], + criteria: FilterCriteria + ) -> List[StructureResult]: + """ + 应用筛选条件 + + Args: + results: 结构结果列表 + criteria: 筛选条件 + + Returns: + 更新后的结果列表(包含筛选状态) + """ + print("\n应用筛选条件...") + print(f" 最小渗透直径: {criteria.min_percolation_diameter} Å") + print(f" 最小 d 值: {criteria.min_d_value}") + print(f" 最大节点长度: {criteria.max_node_length} Å") + + passed_count = 0 + + for result in results: + # 检查是否有有效数据 + if result.percolation_diameter is None or result.min_d is None: + result.passed_filter = False + result.filter_reason = "数据缺失" + continue + + # 检查渗透直径 + if result.percolation_diameter < criteria.min_percolation_diameter: + result.passed_filter = False + result.filter_reason = f"渗透直径 {result.percolation_diameter:.3f} < {criteria.min_percolation_diameter}" + continue + + # 检查 d 值 + if result.min_d < criteria.min_d_value: + result.passed_filter = False + result.filter_reason = f"d 值 {result.min_d:.3f} < {criteria.min_d_value}" + continue + + # 检查节点长度(如果有数据) + if result.max_node_length is not None: + if result.max_node_length > criteria.max_node_length: + result.passed_filter = False + result.filter_reason = f"节点长度 {result.max_node_length:.3f} > {criteria.max_node_length}" + continue + + # 通过所有筛选 + result.passed_filter = True + result.filter_reason = "通过" + passed_count += 1 + + print(f" 通过筛选: {passed_count}/{len(results)}") + return results + + def save_summary_csv( + self, + results: List[StructureResult], + output_path: str = None + ) -> str: + """ + 保存汇总 CSV 文件 + + Args: + results: 结构结果列表 + output_path: 输出路径(默认 workspace/results/summary.csv) + + Returns: + CSV 文件路径 + """ + if output_path is None: + os.makedirs(self.output_dir, exist_ok=True) + output_path = os.path.join(self.output_dir, "summary.csv") + + # 构建数据 + data = [] + for r in results: + data.append({ + 'Structure': r.structure_name, + 'Anion_Type': r.anion_type, + 'Percolation_Diameter_A': r.percolation_diameter, + 'Min_d': r.min_d, + 'Max_Node_Length_A': r.max_node_length, + 'Passed_Filter': r.passed_filter, + 'Filter_Reason': 
r.filter_reason + }) + + df = pd.DataFrame(data) + + # 按阴离子类型和结构名排序 + df = df.sort_values(['Anion_Type', 'Structure']) + + # 保存 + os.makedirs(os.path.dirname(output_path), exist_ok=True) + df.to_csv(output_path, index=False) + + print(f"\n汇总 CSV 已保存: {output_path}") + return output_path + + def save_anion_csv( + self, + results: List[StructureResult], + output_dir: str = None + ) -> List[str]: + """ + 按阴离子类型分别保存 CSV 文件 + + Args: + results: 结构结果列表 + output_dir: 输出目录 + + Returns: + 生成的 CSV 文件路径列表 + """ + if output_dir is None: + output_dir = self.output_dir + + # 按阴离子类型分组 + anion_groups: Dict[str, List[StructureResult]] = {} + for r in results: + if r.anion_type not in anion_groups: + anion_groups[r.anion_type] = [] + anion_groups[r.anion_type].append(r) + + csv_files = [] + + for anion_type, group_results in anion_groups.items(): + # 构建数据 + data = [] + for r in group_results: + data.append({ + 'Structure': r.structure_name, + 'Percolation_Diameter_A': r.percolation_diameter, + 'Min_d': r.min_d, + 'Max_Node_Length_A': r.max_node_length, + 'Passed_Filter': r.passed_filter, + 'Filter_Reason': r.filter_reason + }) + + df = pd.DataFrame(data) + df = df.sort_values('Structure') + + # 保存到对应目录 + anion_output_dir = os.path.join(output_dir, anion_type) + os.makedirs(anion_output_dir, exist_ok=True) + + csv_path = os.path.join(anion_output_dir, f"{anion_type}.csv") + df.to_csv(csv_path, index=False) + csv_files.append(csv_path) + + print(f" {anion_type}: {len(group_results)} 个结构 -> {csv_path}") + + return csv_files + + def copy_passed_structures( + self, + results: List[StructureResult], + output_dir: str = None + ) -> int: + """ + 将通过筛选的结构复制到新文件夹 + + Args: + results: 结构结果列表 + output_dir: 输出目录(默认 workspace/passed) + + Returns: + 复制的结构数量 + """ + if output_dir is None: + output_dir = os.path.join(self.workspace_path, "passed") + + passed_results = [r for r in results if r.passed_filter] + + if not passed_results: + print("\n没有通过筛选的结构") + return 0 + + print(f"\n正在复制 {len(passed_results)} 个通过筛选的结构...") + + copied = 0 + for r in passed_results: + # 目标目录:passed/阴离子类型/结构名/ + dst_dir = os.path.join(output_dir, r.anion_type, r.structure_name) + + try: + # 如果目标已存在,先删除 + if os.path.exists(dst_dir): + shutil.rmtree(dst_dir) + + # 复制整个目录 + shutil.copytree(r.work_dir, dst_dir) + copied += 1 + + except Exception as e: + print(f" ⚠️ 复制失败 {r.structure_name}: {e}") + + print(f" 已复制 {copied} 个结构到: {output_dir}") + return copied + + def process_and_filter( + self, + criteria: FilterCriteria = None, + save_csv: bool = True, + copy_passed: bool = True + ) -> Tuple[List[StructureResult], Dict]: + """ + 完整的处理流程:提取数据 -> 筛选 -> 保存 CSV -> 复制通过的结构 + + Args: + criteria: 筛选条件(如果为 None,则不筛选) + save_csv: 是否保存 CSV + copy_passed: 是否复制通过筛选的结构 + + Returns: + (结果列表, 统计信息字典) + """ + # 1. 提取所有结构的计算结果 + results = self.process_all_structures() + + if not results: + return results, {'total': 0, 'passed': 0, 'failed': 0} + + # 2. 应用筛选条件 + if criteria is not None: + results = self.apply_filter(results, criteria) + + # 3. 保存 CSV + if save_csv: + print("\n保存结果 CSV...") + self.save_summary_csv(results) + self.save_anion_csv(results) + + # 4. 
复制通过筛选的结构 + if copy_passed and criteria is not None: + self.copy_passed_structures(results) + + # 统计 + stats = { + 'total': len(results), + 'passed': sum(1 for r in results if r.passed_filter), + 'failed': sum(1 for r in results if not r.passed_filter), + 'missing_data': sum(1 for r in results if r.filter_reason == "数据缺失") + } + + return results, stats + + def print_summary(self, results: List[StructureResult], stats: Dict): + """打印结果摘要""" + print("\n" + "=" * 60) + print("【计算结果摘要】") + print("=" * 60) + print(f" 总结构数: {stats['total']}") + print(f" 通过筛选: {stats['passed']}") + print(f" 未通过筛选: {stats['failed']}") + print(f" 数据缺失: {stats.get('missing_data', 0)}") + + # 按阴离子类型统计 + anion_stats: Dict[str, Dict] = {} + for r in results: + if r.anion_type not in anion_stats: + anion_stats[r.anion_type] = {'total': 0, 'passed': 0} + anion_stats[r.anion_type]['total'] += 1 + if r.passed_filter: + anion_stats[r.anion_type]['passed'] += 1 + + print("\n 按阴离子类型:") + for anion, s in sorted(anion_stats.items()): + print(f" {anion}: {s['passed']}/{s['total']} 通过") + + print("=" * 60) diff --git a/src/computation/workspace_manager.py b/src/computation/workspace_manager.py new file mode 100644 index 0000000..7c8d887 --- /dev/null +++ b/src/computation/workspace_manager.py @@ -0,0 +1,288 @@ +""" +工作区管理器:管理计算工作区的创建和软链接 +""" +import os +import shutil +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple +from dataclasses import dataclass, field + + +@dataclass +class WorkspaceInfo: + """工作区信息""" + workspace_path: str + data_dir: str # workspace/data + tool_dir: str # tool 目录 + target_cation: str + target_anions: Set[str] + + # 统计信息 + total_structures: int = 0 + anion_counts: Dict[str, int] = field(default_factory=dict) + linked_structures: int = 0 # 已创建软链接的结构数 + + +class WorkspaceManager: + """ + 工作区管理器 + + 负责: + 1. 检测现有工作区 + 2. 创建软链接(yaml 配置文件和计算脚本放在每个结构目录下) + 3. 
准备计算任务 + """ + + # 支持的阴离子及其配置文件 + SUPPORTED_ANIONS = {'O', 'S', 'Cl', 'Br'} + + def __init__( + self, + workspace_path: str = "workspace", + tool_dir: str = "tool", + target_cation: str = "Li" + ): + """ + 初始化工作区管理器 + + Args: + workspace_path: 工作区根目录 + tool_dir: 工具目录(包含 yaml 配置和计算脚本) + target_cation: 目标阳离子 + """ + self.workspace_path = os.path.abspath(workspace_path) + self.tool_dir = os.path.abspath(tool_dir) + self.target_cation = target_cation + + # 数据目录 + self.data_dir = os.path.join(self.workspace_path, "data") + + def check_existing_workspace(self) -> Optional[WorkspaceInfo]: + """ + 检查现有工作区 + + Returns: + WorkspaceInfo 如果存在,否则 None + """ + if not os.path.exists(self.data_dir): + return None + + # 扫描数据目录 + anion_counts = {} + total = 0 + linked = 0 + + for item in os.listdir(self.data_dir): + item_path = os.path.join(self.data_dir, item) + if os.path.isdir(item_path): + # 可能是阴离子目录(如 O, S, O+S) + # 统计其中的结构数量 + count = 0 + for sub_item in os.listdir(item_path): + sub_path = os.path.join(item_path, sub_item) + if os.path.isdir(sub_path): + # 检查是否包含 CIF 文件 + cif_files = [f for f in os.listdir(sub_path) if f.endswith('.cif')] + if cif_files: + count += 1 + # 检查是否已有软链接 + yaml_files = [f for f in os.listdir(sub_path) if f.endswith('.yaml')] + if yaml_files: + linked += 1 + + if count > 0: + anion_counts[item] = count + total += count + + if total == 0: + return None + + return WorkspaceInfo( + workspace_path=self.workspace_path, + data_dir=self.data_dir, + tool_dir=self.tool_dir, + target_cation=self.target_cation, + target_anions=set(anion_counts.keys()), + total_structures=total, + anion_counts=anion_counts, + linked_structures=linked + ) + + def setup_workspace( + self, + target_anions: Set[str] = None, + force_relink: bool = False + ) -> WorkspaceInfo: + """ + 设置工作区:在每个结构目录下创建软链接 + + 软链接规则: + - yaml 文件:使用与阴离子目录同名的 yaml(如 O 目录用 O.yaml,Cl+O 目录用 Cl+O.yaml) + - python 脚本:analyze_voronoi_nodes.py + + Args: + target_anions: 目标阴离子集合 + force_relink: 是否强制重新创建软链接 + + Returns: + WorkspaceInfo + """ + if target_anions is None: + target_anions = self.SUPPORTED_ANIONS + + # 确保数据目录存在 + if not os.path.exists(self.data_dir): + raise FileNotFoundError(f"数据目录不存在: {self.data_dir}") + + # 获取计算脚本路径 + analyze_script = os.path.join(self.tool_dir, "analyze_voronoi_nodes.py") + if not os.path.exists(analyze_script): + raise FileNotFoundError(f"计算脚本不存在: {analyze_script}") + + anion_counts = {} + total = 0 + linked = 0 + + print("\n正在设置工作区软链接...") + + # 遍历数据目录中的阴离子子目录 + for anion_key in os.listdir(self.data_dir): + anion_dir = os.path.join(self.data_dir, anion_key) + if not os.path.isdir(anion_dir): + continue + + # 确定使用哪个 yaml 配置文件 + # 使用与阴离子目录同名的 yaml 文件(如 O.yaml, Cl+O.yaml) + yaml_name = f"{anion_key}.yaml" + yaml_source = os.path.join(self.tool_dir, self.target_cation, yaml_name) + + if not os.path.exists(yaml_source): + print(f" ⚠️ 配置文件不存在: {yaml_source}") + continue + + # 统计并处理该阴离子目录下的所有结构 + count = 0 + for struct_name in os.listdir(anion_dir): + struct_dir = os.path.join(anion_dir, struct_name) + + if not os.path.isdir(struct_dir): + continue + + # 检查是否包含 CIF 文件 + cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')] + if not cif_files: + continue + + count += 1 + + # 在结构目录下创建软链接 + yaml_link = os.path.join(struct_dir, yaml_name) + script_link = os.path.join(struct_dir, "analyze_voronoi_nodes.py") + + # 创建 yaml 软链接 + if os.path.exists(yaml_link) or os.path.islink(yaml_link): + if force_relink: + os.remove(yaml_link) + os.symlink(yaml_source, yaml_link) + linked += 1 + else: + 
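+                    # 软链接尚不存在:直接创建(已存在且未指定 force_relink 时,上方分支保持原链接不动)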
os.symlink(yaml_source, yaml_link) + linked += 1 + + # 创建计算脚本软链接 + if os.path.exists(script_link) or os.path.islink(script_link): + if force_relink: + os.remove(script_link) + os.symlink(analyze_script, script_link) + else: + os.symlink(analyze_script, script_link) + + if count > 0: + anion_counts[anion_key] = count + total += count + print(f" ✓ {anion_key}: {count} 个结构, 配置 -> {yaml_name}") + + print(f"\n 总计: {total} 个结构, 新建软链接: {linked}") + + return WorkspaceInfo( + workspace_path=self.workspace_path, + data_dir=self.data_dir, + tool_dir=self.tool_dir, + target_cation=self.target_cation, + target_anions=set(anion_counts.keys()), + total_structures=total, + anion_counts=anion_counts, + linked_structures=linked + ) + + def get_computation_tasks( + self, + workspace_info: WorkspaceInfo = None + ) -> List[Dict]: + """ + 获取所有计算任务 + + Returns: + 任务列表,每个任务包含: + - cif_path: CIF 文件路径 + - yaml_name: YAML 配置文件名(如 O.yaml) + - work_dir: 工作目录(结构目录) + - anion_type: 阴离子类型 + - structure_name: 结构名称 + """ + if workspace_info is None: + workspace_info = self.check_existing_workspace() + + if workspace_info is None: + return [] + + tasks = [] + + for anion_key in workspace_info.anion_counts.keys(): + anion_dir = os.path.join(self.data_dir, anion_key) + yaml_name = f"{anion_key}.yaml" + + # 遍历该阴离子目录下的所有结构 + for struct_name in os.listdir(anion_dir): + struct_dir = os.path.join(anion_dir, struct_name) + + if not os.path.isdir(struct_dir): + continue + + # 查找 CIF 文件 + cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')] + + # 检查是否有 yaml 软链接 + yaml_path = os.path.join(struct_dir, yaml_name) + if not os.path.exists(yaml_path): + continue + + for cif_file in cif_files: + cif_path = os.path.join(struct_dir, cif_file) + + tasks.append({ + 'cif_path': cif_path, + 'yaml_name': yaml_name, + 'work_dir': struct_dir, + 'anion_type': anion_key, + 'structure_name': struct_name, + 'cif_name': cif_file + }) + + return tasks + + def print_workspace_summary(self, workspace_info: WorkspaceInfo): + """打印工作区摘要""" + print("\n" + "=" * 60) + print("【工作区摘要】") + print("=" * 60) + print(f" 工作区路径: {workspace_info.workspace_path}") + print(f" 数据目录: {workspace_info.data_dir}") + print(f" 目标阳离子: {workspace_info.target_cation}") + print(f" 总结构数: {workspace_info.total_structures}") + print(f" 已配置软链接: {workspace_info.linked_structures}") + print() + print(" 阴离子分布:") + for anion, count in sorted(workspace_info.anion_counts.items()): + print(f" - {anion}: {count} 个结构") + print("=" * 60) diff --git a/src/computation/zeo_executor.py b/src/computation/zeo_executor.py new file mode 100644 index 0000000..48c2061 --- /dev/null +++ b/src/computation/zeo_executor.py @@ -0,0 +1,446 @@ +""" +Zeo++ 计算执行器:使用 SLURM 作业数组高效调度大量计算任务 +""" +import os +import subprocess +import time +import json +import tempfile +from typing import List, Dict, Optional, Callable, Any +from dataclasses import dataclass, field +from enum import Enum +import threading + +from ..core.progress import ProgressManager + + +@dataclass +class ZeoConfig: + """Zeo++ 计算配置""" + # 环境配置 + conda_env: str = "/cluster/home/koko125/anaconda3/envs/zeo" + + # SLURM 配置 + partition: str = "cpu" + time_limit: str = "2:00:00" # 单个任务时间限制 + memory_per_task: str = "4G" + + # 作业数组配置 + max_array_size: int = 1000 # SLURM 作业数组最大大小 + max_concurrent: int = 50 # 最大并发任务数 + + # 轮询配置 + poll_interval: float = 5.0 # 状态检查间隔(秒) + + # 过滤器配置 + filters: List[str] = field(default_factory=lambda: [ + "Ordered", "PropOxi", "VoroPerco", "Coulomb", "VoroBV", "VoroInfo", "MergeSite" + ]) + + +@dataclass +class 
ZeoTaskResult: + """单个任务结果""" + task_id: int + structure_name: str + cif_path: str + success: bool + output_files: List[str] = field(default_factory=list) + error_message: str = "" + duration: float = 0.0 + + +class ZeoExecutor: + """ + Zeo++ 计算执行器 + + 使用 SLURM 作业数组高效调度大量 Voronoi 分析任务 + """ + + def __init__(self, config: ZeoConfig = None): + self.config = config or ZeoConfig() + self.progress_manager = None + self._stop_event = threading.Event() + + def run_batch( + self, + tasks: List[Dict], + output_dir: str = None, + desc: str = "Zeo++ 计算" + ) -> List[ZeoTaskResult]: + """ + 批量执行 Zeo++ 计算 + + Args: + tasks: 任务列表,每个任务包含 cif_path, yaml_path, work_dir 等 + output_dir: SLURM 日志输出目录 + desc: 进度条描述 + + Returns: + ZeoTaskResult 列表 + """ + if not tasks: + print("⚠️ 没有任务需要执行") + return [] + + total = len(tasks) + + # 创建输出目录 + if output_dir is None: + output_dir = os.path.join(os.getcwd(), "slurm_logs") + os.makedirs(output_dir, exist_ok=True) + + print(f"\n{'='*60}") + print(f"【Zeo++ 批量计算】") + print(f"{'='*60}") + print(f" 总任务数: {total}") + print(f" Conda环境: {self.config.conda_env}") + print(f" SLURM分区: {self.config.partition}") + print(f" 最大并发: {self.config.max_concurrent}") + print(f" 日志目录: {output_dir}") + print(f"{'='*60}\n") + + # 保存任务列表到文件 + tasks_file = os.path.join(output_dir, "tasks.json") + with open(tasks_file, 'w') as f: + json.dump(tasks, f, indent=2) + + # 生成并提交作业数组 + if total <= self.config.max_array_size: + # 单个作业数组 + return self._submit_array_job(tasks, output_dir, desc) + else: + # 分批提交多个作业数组 + return self._submit_batched_arrays(tasks, output_dir, desc) + + def _submit_array_job( + self, + tasks: List[Dict], + output_dir: str, + desc: str + ) -> List[ZeoTaskResult]: + """提交单个作业数组""" + total = len(tasks) + + # 保存任务列表 + tasks_file = os.path.join(output_dir, "tasks.json") + with open(tasks_file, 'w') as f: + json.dump(tasks, f, indent=2) + + # 生成作业脚本 + script_content = self._generate_array_script( + tasks_file=tasks_file, + output_dir=output_dir, + array_range=f"0-{total-1}%{self.config.max_concurrent}" + ) + + script_path = os.path.join(output_dir, "submit_array.sh") + with open(script_path, 'w') as f: + f.write(script_content) + os.chmod(script_path, 0o755) + + print(f"生成作业脚本: {script_path}") + + # 提交作业 + result = subprocess.run( + ['sbatch', script_path], + capture_output=True, + text=True + ) + + if result.returncode != 0: + print(f"❌ 作业提交失败: {result.stderr}") + return [ZeoTaskResult( + task_id=i, + structure_name=t.get('structure_name', ''), + cif_path=t.get('cif_path', ''), + success=False, + error_message=f"提交失败: {result.stderr}" + ) for i, t in enumerate(tasks)] + + # 提取作业 ID + job_id = result.stdout.strip().split()[-1] + print(f"✓ 作业已提交: {job_id}") + print(f" 作业数组范围: 0-{total-1}") + print(f" 最大并发: {self.config.max_concurrent}") + + # 监控作业进度 + return self._monitor_array_job(job_id, tasks, output_dir, desc) + + def _submit_batched_arrays( + self, + tasks: List[Dict], + output_dir: str, + desc: str + ) -> List[ZeoTaskResult]: + """分批提交多个作业数组""" + total = len(tasks) + batch_size = self.config.max_array_size + num_batches = (total + batch_size - 1) // batch_size + + print(f"任务数超过作业数组限制,分 {num_batches} 批提交") + + all_results = [] + + for batch_idx in range(num_batches): + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, total) + batch_tasks = tasks[start_idx:end_idx] + + batch_output_dir = os.path.join(output_dir, f"batch_{batch_idx}") + os.makedirs(batch_output_dir, exist_ok=True) + + print(f"\n--- 批次 {batch_idx + 1}/{num_batches} ---") + print(f"任务范围: 
{start_idx} - {end_idx - 1}") + + batch_results = self._submit_array_job( + batch_tasks, + batch_output_dir, + f"{desc} (批次 {batch_idx + 1}/{num_batches})" + ) + + # 调整任务 ID + for r in batch_results: + r.task_id += start_idx + + all_results.extend(batch_results) + + return all_results + + def _generate_array_script( + self, + tasks_file: str, + output_dir: str, + array_range: str + ) -> str: + """生成 SLURM 作业数组脚本""" + + # 获取项目根目录 + project_root = os.getcwd() + + script = f"""#!/bin/bash +#SBATCH --job-name=zeo_array +#SBATCH --partition={self.config.partition} +#SBATCH --array={array_range} +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --mem={self.config.memory_per_task} +#SBATCH --time={self.config.time_limit} +#SBATCH --output={output_dir}/task_%a.out +#SBATCH --error={output_dir}/task_%a.err + +# ============================================ +# Zeo++ Voronoi 分析 - 作业数组 +# ============================================ + +echo "===== 任务信息 =====" +echo "作业ID: $SLURM_JOB_ID" +echo "数组任务ID: $SLURM_ARRAY_TASK_ID" +echo "节点: $SLURM_NODELIST" +echo "开始时间: $(date)" +echo "====================" + +# ============ 环境初始化 ============ +# 加载 bashrc +if [ -f ~/.bashrc ]; then + source ~/.bashrc +fi + +# 初始化 Conda +if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then + source ~/anaconda3/etc/profile.d/conda.sh +elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then + source /opt/anaconda3/etc/profile.d/conda.sh +fi + +# 激活 Zeo++ 环境 +conda activate {self.config.conda_env} + +echo "" +echo "===== 环境检查 =====" +echo "Conda环境: $CONDA_DEFAULT_ENV" +echo "Python路径: $(which python)" +echo "====================" +echo "" + +# ============ 读取任务信息 ============ +TASKS_FILE="{tasks_file}" +TASK_ID=$SLURM_ARRAY_TASK_ID + +# 使用 Python 解析任务 +TASK_INFO=$(python3 -c " +import json +with open('$TASKS_FILE', 'r') as f: + tasks = json.load(f) +if $TASK_ID < len(tasks): + task = tasks[$TASK_ID] + print(task['work_dir']) + print(task['yaml_name']) +else: + print('ERROR') +") + +WORK_DIR=$(echo "$TASK_INFO" | sed -n '1p') +YAML_NAME=$(echo "$TASK_INFO" | sed -n '2p') + +if [ "$WORK_DIR" == "ERROR" ]; then + echo "错误: 任务ID $TASK_ID 超出范围" + exit 1 +fi + +echo "工作目录: $WORK_DIR" +echo "配置文件: $YAML_NAME" +echo "" + +# ============ 执行计算 ============ +cd "$WORK_DIR" + +echo "开始 Voronoi 分析..." +# 软链接已在工作目录下,直接使用相对路径 +# 将输出重定向到 log.txt 以便后续提取结果 +python analyze_voronoi_nodes.py *.cif -i "$YAML_NAME" > log.txt 2>&1 + +EXIT_CODE=$? 
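+# NOTE: $? above captures the exit status of analyze_voronoi_nodes.py; it is echoed in
+# the summary below and also drives the SUCCESS/FAILED status file that the Python-side
+# monitor (_monitor_array_job) polls to update progress.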
+ +# 显示日志内容(用于调试) +echo "" +echo "===== 计算日志 =====" +cat log.txt +echo "====================" + +# ============ 完成 ============ +echo "" +echo "===== 任务完成 =====" +echo "结束时间: $(date)" +echo "退出代码: $EXIT_CODE" + +# 写入状态文件 +if [ $EXIT_CODE -eq 0 ]; then + echo "SUCCESS" > "{output_dir}/status_$TASK_ID.txt" +else + echo "FAILED" > "{output_dir}/status_$TASK_ID.txt" +fi + +echo "====================" +exit $EXIT_CODE +""" + return script + + def _monitor_array_job( + self, + job_id: str, + tasks: List[Dict], + output_dir: str, + desc: str + ) -> List[ZeoTaskResult]: + """监控作业数组进度""" + total = len(tasks) + + self.progress_manager = ProgressManager(total, desc) + self.progress_manager.start() + + results = [None] * total + completed = set() + + print(f"\n监控作业进度 (每 {self.config.poll_interval} 秒检查一次)...") + print("按 Ctrl+C 可中断监控(作业将继续在后台运行)\n") + + try: + while len(completed) < total: + time.sleep(self.config.poll_interval) + + # 检查状态文件 + for i in range(total): + if i in completed: + continue + + status_file = os.path.join(output_dir, f"status_{i}.txt") + if os.path.exists(status_file): + with open(status_file, 'r') as f: + status = f.read().strip() + + task = tasks[i] + success = (status == "SUCCESS") + + # 收集输出文件 + output_files = [] + if success: + work_dir = task['work_dir'] + for f in os.listdir(work_dir): + if f.endswith(('.cif', '.csv')) and f != task['cif_name']: + output_files.append(os.path.join(work_dir, f)) + + results[i] = ZeoTaskResult( + task_id=i, + structure_name=task.get('structure_name', ''), + cif_path=task.get('cif_path', ''), + success=success, + output_files=output_files + ) + + completed.add(i) + self.progress_manager.update(success=success) + self.progress_manager.display() + + # 检查作业是否还在运行 + if not self._is_job_running(job_id) and len(completed) < total: + # 作业已结束但有任务未完成 + print(f"\n⚠️ 作业已结束,但有 {total - len(completed)} 个任务未完成") + break + + except KeyboardInterrupt: + print("\n\n⚠️ 监控已中断,作业将继续在后台运行") + print(f" 可使用 'squeue -j {job_id}' 查看作业状态") + print(f" 可使用 'scancel {job_id}' 取消作业") + + self.progress_manager.finish() + + # 填充未完成的任务 + for i in range(total): + if results[i] is None: + task = tasks[i] + results[i] = ZeoTaskResult( + task_id=i, + structure_name=task.get('structure_name', ''), + cif_path=task.get('cif_path', ''), + success=False, + error_message="任务未完成或状态未知" + ) + + return results + + def _is_job_running(self, job_id: str) -> bool: + """检查作业是否还在运行""" + try: + result = subprocess.run( + ['squeue', '-j', job_id, '-h'], + capture_output=True, + text=True, + timeout=10 + ) + return bool(result.stdout.strip()) + except Exception: + return False + + def print_results_summary(self, results: List[ZeoTaskResult]): + """打印结果摘要""" + total = len(results) + success = sum(1 for r in results if r.success) + failed = total - success + + print("\n" + "=" * 60) + print("【计算结果摘要】") + print("=" * 60) + print(f" 总任务数: {total}") + print(f" 成功: {success} ({100*success/total:.1f}%)") + print(f" 失败: {failed} ({100*failed/total:.1f}%)") + + if failed > 0 and failed <= 10: + print("\n 失败的任务:") + for r in results: + if not r.success: + print(f" - {r.structure_name}: {r.error_message}") + elif failed > 10: + print(f"\n 失败任务过多,请检查日志文件") + + print("=" * 60) diff --git a/src/core/__init__.py b/src/core/__init__.py index e69de29..45e21bd 100644 --- a/src/core/__init__.py +++ b/src/core/__init__.py @@ -0,0 +1,18 @@ +""" +核心模块:调度器、执行器和进度管理 +""" +from .scheduler import ParallelScheduler, ResourceConfig, ExecutionMode as SchedulerMode +from .executor import TaskExecutor, ExecutorConfig, 
ExecutionMode, TaskResult, create_executor +from .progress import ProgressManager + +__all__ = [ + 'ParallelScheduler', + 'ResourceConfig', + 'SchedulerMode', + 'TaskExecutor', + 'ExecutorConfig', + 'ExecutionMode', + 'TaskResult', + 'create_executor', + 'ProgressManager', +] diff --git a/src/core/executor.py b/src/core/executor.py new file mode 100644 index 0000000..e56d6a2 --- /dev/null +++ b/src/core/executor.py @@ -0,0 +1,431 @@ +""" +任务执行器:支持本地执行和 SLURM 直接提交 +不生成脚本文件,直接在 Python 中管理任务 +""" +import os +import subprocess +import time +import json +from typing import List, Callable, Any, Optional, Dict, Tuple +from multiprocessing import Pool, cpu_count +from dataclasses import dataclass, field +from enum import Enum +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading + +from .progress import ProgressManager + + +class ExecutionMode(Enum): + """执行模式""" + LOCAL = "local" # 本地多进程 + SLURM_DIRECT = "slurm" # SLURM 直接提交(不生成脚本) + + +@dataclass +class ExecutorConfig: + """执行器配置""" + mode: ExecutionMode = ExecutionMode.LOCAL + max_workers: int = 4 + conda_env: str = "/cluster/home/koko125/anaconda3/envs/screen" + partition: str = "cpu" + time_limit: str = "7-00:00:00" + memory_per_task: str = "4G" + + # SLURM 相关 + poll_interval: float = 2.0 # 轮询间隔(秒) + max_concurrent_jobs: int = 50 # 最大并发作业数 + + +@dataclass +class TaskResult: + """任务结果""" + task_id: Any + success: bool + result: Any = None + error: str = None + duration: float = 0.0 + + +class TaskExecutor: + """ + 任务执行器 + + 支持两种模式: + 1. LOCAL: 本地多进程执行 + 2. SLURM_DIRECT: 直接提交 SLURM 作业,实时监控进度 + """ + + def __init__(self, config: ExecutorConfig = None): + self.config = config or ExecutorConfig() + self.progress_manager = None + self._stop_event = threading.Event() + + @staticmethod + def detect_environment() -> Dict[str, Any]: + """检测运行环境""" + env_info = { + 'hostname': os.uname().nodename, + 'total_cores': cpu_count(), + 'has_slurm': False, + 'slurm_partitions': [], + 'conda_env': os.environ.get('CONDA_PREFIX', ''), + } + + # 检测 SLURM + try: + result = subprocess.run( + ['sinfo', '-h', '-o', '%P %a %c %D'], + capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0: + env_info['has_slurm'] = True + lines = result.stdout.strip().split('\n') + for line in lines: + parts = line.split() + if len(parts) >= 4: + partition = parts[0].rstrip('*') + avail = parts[1] + if avail == 'up': + env_info['slurm_partitions'].append(partition) + except Exception: + pass + + return env_info + + def run( + self, + tasks: List[Any], + worker_func: Callable, + desc: str = "Processing" + ) -> List[TaskResult]: + """ + 执行任务 + + Args: + tasks: 任务列表 + worker_func: 工作函数,接收单个任务,返回结果 + desc: 进度条描述 + + Returns: + TaskResult 列表 + """ + if self.config.mode == ExecutionMode.LOCAL: + return self._run_local(tasks, worker_func, desc) + elif self.config.mode == ExecutionMode.SLURM_DIRECT: + return self._run_slurm_direct(tasks, worker_func, desc) + else: + raise ValueError(f"不支持的执行模式: {self.config.mode}") + + def _run_local( + self, + tasks: List[Any], + worker_func: Callable, + desc: str + ) -> List[TaskResult]: + """本地多进程执行""" + total = len(tasks) + num_workers = min(self.config.max_workers, total) + + print(f"\n{'='*60}") + print(f"本地执行配置:") + print(f" 总任务数: {total}") + print(f" Worker数: {num_workers}") + print(f"{'='*60}\n") + + self.progress_manager = ProgressManager(total, desc) + self.progress_manager.start() + + results = [] + + if num_workers == 1: + # 单进程执行 + for i, task in enumerate(tasks): + start_time = time.time() + try: + 
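# Single-worker path: run the task in this process; any exception is caught below
+                    # and recorded as a failed TaskResult instead of aborting the whole batch.
+                    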
result = worker_func(task) + duration = time.time() - start_time + results.append(TaskResult( + task_id=i, + success=True, + result=result, + duration=duration + )) + self.progress_manager.update(success=True) + except Exception as e: + duration = time.time() - start_time + results.append(TaskResult( + task_id=i, + success=False, + error=str(e), + duration=duration + )) + self.progress_manager.update(success=False) + self.progress_manager.display() + else: + # 多进程执行 + with Pool(processes=num_workers) as pool: + for i, result in enumerate(pool.imap_unordered(worker_func, tasks)): + if result is not None: + results.append(TaskResult( + task_id=i, + success=True, + result=result + )) + self.progress_manager.update(success=True) + else: + results.append(TaskResult( + task_id=i, + success=False, + error="Worker returned None" + )) + self.progress_manager.update(success=False) + self.progress_manager.display() + + self.progress_manager.finish() + return results + + def _run_slurm_direct( + self, + tasks: List[Any], + worker_func: Callable, + desc: str + ) -> List[TaskResult]: + """ + SLURM 直接提交模式 + + 注意:对于数据库分析这类快速任务,建议使用本地多进程模式 + SLURM 模式更适合耗时的计算任务(如 Zeo++ 分析) + + 这里回退到本地模式,因为 srun 在登录节点直接调用效率不高 + """ + print("\n⚠️ 注意:数据库分析阶段自动使用本地多进程模式") + print(" SLURM 模式将在后续耗时计算步骤中使用") + + # 回退到本地模式 + return self._run_local(tasks, worker_func, desc) + + +class SlurmJobManager: + """ + SLURM 作业管理器 + + 用于批量提交和监控 SLURM 作业 + """ + + def __init__(self, config: ExecutorConfig): + self.config = config + self.active_jobs = {} # job_id -> task_info + + def submit_batch( + self, + tasks: List[Tuple[str, str, set]], # (file_path, target_cation, target_anions) + output_dir: str, + desc: str = "Processing" + ) -> List[TaskResult]: + """ + 批量提交任务到 SLURM + + 使用 sbatch --wrap 直接提交,不生成脚本文件 + """ + total = len(tasks) + os.makedirs(output_dir, exist_ok=True) + + print(f"\n{'='*60}") + print(f"SLURM 批量提交:") + print(f" 总任务数: {total}") + print(f" 输出目录: {output_dir}") + print(f" Conda环境: {self.config.conda_env}") + print(f"{'='*60}\n") + + progress = ProgressManager(total, desc) + progress.start() + + results = [] + job_ids = [] + + # 提交所有任务 + for i, task in enumerate(tasks): + file_path, target_cation, target_anions = task + + # 构建 Python 命令 + anions_str = ','.join(target_anions) + python_cmd = ( + f"python -c \"" + f"import sys; sys.path.insert(0, '{os.getcwd()}'); " + f"from src.analysis.worker import analyze_single_file; " + f"result = analyze_single_file(('{file_path}', '{target_cation}', set('{anions_str}'.split(',')))); " + f"print('SUCCESS' if result and result.is_valid else 'FAILED')" + f"\"" + ) + + # 构建完整的 bash 命令 + bash_cmd = ( + f"source {os.path.dirname(self.config.conda_env)}/../../etc/profile.d/conda.sh && " + f"conda activate {self.config.conda_env} && " + f"{python_cmd}" + ) + + # 使用 sbatch --wrap 提交 + sbatch_cmd = [ + 'sbatch', + '--partition', self.config.partition, + '--ntasks', '1', + '--cpus-per-task', '1', + '--mem', self.config.memory_per_task, + '--time', '01:00:00', + '--output', os.path.join(output_dir, f'task_{i}.out'), + '--error', os.path.join(output_dir, f'task_{i}.err'), + '--wrap', bash_cmd + ] + + try: + result = subprocess.run( + sbatch_cmd, + capture_output=True, + text=True + ) + + if result.returncode == 0: + # 提取 job_id + job_id = result.stdout.strip().split()[-1] + job_ids.append((i, job_id, file_path)) + self.active_jobs[job_id] = { + 'task_index': i, + 'file_path': file_path, + 'status': 'PENDING' + } + else: + results.append(TaskResult( + task_id=i, + success=False, + error=f"提交失败: 
{result.stderr}" + )) + progress.update(success=False) + progress.display() + + except Exception as e: + results.append(TaskResult( + task_id=i, + success=False, + error=str(e) + )) + progress.update(success=False) + progress.display() + + print(f"\n已提交 {len(job_ids)} 个作业,等待完成...") + + # 监控作业状态 + while self.active_jobs: + time.sleep(self.config.poll_interval) + + # 检查作业状态 + completed_jobs = self._check_job_status() + + for job_id, status in completed_jobs: + job_info = self.active_jobs.pop(job_id, None) + if job_info: + task_idx = job_info['task_index'] + + if status == 'COMPLETED': + # 检查输出文件 + out_file = os.path.join(output_dir, f'task_{task_idx}.out') + success = False + if os.path.exists(out_file): + with open(out_file, 'r') as f: + content = f.read() + success = 'SUCCESS' in content + + results.append(TaskResult( + task_id=task_idx, + success=success, + result=job_info['file_path'] + )) + progress.update(success=success) + else: + # 作业失败 + err_file = os.path.join(output_dir, f'task_{task_idx}.err') + error_msg = status + if os.path.exists(err_file): + with open(err_file, 'r') as f: + error_msg = f.read()[:500] # 只取前500字符 + + results.append(TaskResult( + task_id=task_idx, + success=False, + error=error_msg + )) + progress.update(success=False) + + progress.display() + + progress.finish() + return results + + def _check_job_status(self) -> List[Tuple[str, str]]: + """检查作业状态,返回已完成的作业列表""" + if not self.active_jobs: + return [] + + job_ids = list(self.active_jobs.keys()) + + try: + result = subprocess.run( + ['sacct', '-j', ','.join(job_ids), '--format=JobID,State', '--noheader', '--parsable2'], + capture_output=True, + text=True, + timeout=30 + ) + + completed = [] + if result.returncode == 0: + for line in result.stdout.strip().split('\n'): + if line: + parts = line.split('|') + if len(parts) >= 2: + job_id = parts[0].split('.')[0] # 去掉 .batch 后缀 + status = parts[1] + + if job_id in self.active_jobs: + if status in ['COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL']: + completed.append((job_id, status)) + + return completed + + except Exception: + return [] + + +def create_executor( + mode: str = "local", + max_workers: int = None, + conda_env: str = None, + **kwargs +) -> TaskExecutor: + """ + 创建任务执行器的便捷函数 + + Args: + mode: "local" 或 "slurm" + max_workers: 最大工作进程数 + conda_env: Conda 环境路径 + **kwargs: 其他配置参数 + """ + env = TaskExecutor.detect_environment() + + if max_workers is None: + max_workers = min(env['total_cores'], 32) + + if conda_env is None: + conda_env = env.get('conda_env') or "/cluster/home/koko125/anaconda3/envs/screen" + + exec_mode = ExecutionMode.SLURM_DIRECT if mode.lower() == "slurm" else ExecutionMode.LOCAL + + config = ExecutorConfig( + mode=exec_mode, + max_workers=max_workers, + conda_env=conda_env, + **kwargs + ) + + return TaskExecutor(config) diff --git a/src/preprocessing/processor.py b/src/preprocessing/processor.py new file mode 100644 index 0000000..06c6709 --- /dev/null +++ b/src/preprocessing/processor.py @@ -0,0 +1,562 @@ +""" +结构预处理器:扩胞和添加化合价 +""" +import os +import re +import yaml +from typing import List, Dict, Optional, Tuple +from dataclasses import dataclass, field +from pymatgen.core.structure import Structure +from pymatgen.core.periodic_table import Specie +from pymatgen.core import Lattice, Species, PeriodicSite +from collections import defaultdict +from fractions import Fraction +from functools import reduce +import math +import random +import spglib +import numpy as np + + +@dataclass +class ProcessingResult: + """处理结果""" + 
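# One record per input CIF: the files written, whether supercell expansion was applied,
+    # the expansion factor used, and an error message when processing fails.
+    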
input_file: str + output_files: List[str] = field(default_factory=list) + success: bool = False + needs_expansion: bool = False + expansion_factor: int = 1 + error_message: str = "" + + +class StructureProcessor: + """结构预处理器""" + + # 默认化合价配置 + DEFAULT_VALENCE_PATH = os.path.join( + os.path.dirname(__file__), '..', '..', 'tool', 'valence_states.yaml' + ) + + def __init__( + self, + valence_yaml_path: str = None, + calculate_type: str = 'low', + max_expansion_factor: int = 64, + keep_number: int = 3, + target_cation: str = "Li" + ): + """ + 初始化处理器 + + Args: + valence_yaml_path: 化合价配置文件路径 + calculate_type: 扩胞计算精度 ('high', 'normal', 'low', 'very_low') + max_expansion_factor: 最大扩胞因子 + keep_number: 保留的扩胞结构数量 + target_cation: 目标阳离子 + """ + self.valence_yaml_path = valence_yaml_path or self.DEFAULT_VALENCE_PATH + self.calculate_type = calculate_type + self.max_expansion_factor = max_expansion_factor + self.keep_number = keep_number + self.target_cation = target_cation + self.explict_element = [target_cation, f"{target_cation}+"] + + # 加载化合价配置 + self.valences = self._load_valences() + + def _load_valences(self) -> Dict[str, int]: + """加载化合价配置""" + if os.path.exists(self.valence_yaml_path): + with open(self.valence_yaml_path, 'r') as f: + return yaml.safe_load(f) + return {} + + def process_file( + self, + input_path: str, + output_dir: str, + needs_expansion: bool = False + ) -> ProcessingResult: + """ + 处理单个CIF文件 + + Args: + input_path: 输入文件路径 + output_dir: 输出目录 + needs_expansion: 是否需要扩胞 + + Returns: + ProcessingResult: 处理结果 + """ + result = ProcessingResult(input_file=input_path) + + try: + # 读取结构 + structure = Structure.from_file(input_path) + base_name = os.path.splitext(os.path.basename(input_path))[0] + + # 检查是否需要扩胞 + occupation_list = self._process_cif_file(structure) + + if occupation_list and needs_expansion: + # 需要扩胞处理 + result.needs_expansion = True + output_files = self._expand_and_save( + structure, occupation_list, base_name, output_dir + ) + result.output_files = output_files + result.expansion_factor = occupation_list[0].get('denominator', 1) if occupation_list else 1 + else: + # 不需要扩胞,直接添加化合价 + output_path = os.path.join(output_dir, f"{base_name}.cif") + self._add_oxidation_states(structure) + structure.to(filename=output_path) + result.output_files = [output_path] + + result.success = True + + except Exception as e: + result.success = False + result.error_message = str(e) + + return result + + def _process_cif_file(self, structure: Structure) -> List[Dict]: + """ + 统计结构中各原子的occupation情况 + """ + occupation_dict = defaultdict(list) + split_dict = {} + + for i, site in enumerate(structure): + occu = self._get_occu(site.species_string) + + if occu != 1.0: + if site.species.chemical_system not in self.explict_element: + occupation_dict[occu].append(i + 1) + + # 提取元素名称列表 + elements = [] + if ':' in site.species_string: + parts = site.species_string.split(',') + for part in parts: + element_with_valence = part.strip().split(':')[0].strip() + element_match = re.match(r'([A-Z][a-z]?)', element_with_valence) + if element_match: + elements.append(element_match.group(1)) + else: + element_match = re.match(r'([A-Z][a-z]?)', site.species_string) + if element_match: + elements = [element_match.group(1)] + + split_dict[occu] = elements + + # 转换为列表格式 + occupation_list = [ + { + "occupation": occu, + "atom_serial": serials, + "numerator": None, + "denominator": None, + "split": split_dict.get(occu, []) + } + for occu, serials in occupation_dict.items() + ] + + return occupation_list + + def 
_get_occu(self, s_str: str) -> float: + """从物种字符串获取占据率""" + if not s_str.strip(): + return 1.0 + + pattern = r'([A-Za-z0-9+-]+):([0-9.]+)' + matches = re.findall(pattern, s_str) + + for species, occu in matches: + if species not in self.explict_element: + try: + return float(occu) + except ValueError: + continue + + return 1.0 + + def _calculate_expansion_factor(self, occupation_list: List[Dict]) -> Tuple[int, List[Dict]]: + """计算扩胞因子""" + if not occupation_list: + return 1, [] + + precision_limits = { + 'high': None, + 'normal': 100, + 'low': 10, + 'very_low': 5 + } + + limit = precision_limits.get(self.calculate_type) + + for entry in occupation_list: + occu = entry["occupation"] + if limit: + fraction = Fraction(occu).limit_denominator(limit) + else: + fraction = Fraction(occu).limit_denominator() + + entry["numerator"] = fraction.numerator + entry["denominator"] = fraction.denominator + + # 计算最小公倍数 + denominators = [entry["denominator"] for entry in occupation_list] + lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1) + + # 统一分母 + for entry in occupation_list: + denominator = entry["denominator"] + entry["numerator"] = entry["numerator"] * (lcm // denominator) + entry["denominator"] = lcm + + return lcm, occupation_list + + def _expand_and_save( + self, + structure: Structure, + occupation_list: List[Dict], + base_name: str, + output_dir: str + ) -> List[str]: + """扩胞并保存""" + lcm, oc_list = self._calculate_expansion_factor(occupation_list) + + if lcm > self.max_expansion_factor: + raise ValueError(f"扩胞因子 {lcm} 超过最大限制 {self.max_expansion_factor}") + + # 获取扩胞策略 + strategies = self._strategy_divide(structure, lcm) + + if not strategies: + raise ValueError("无法找到合适的扩胞策略") + + # 生成结构列表 + st_list = self._generate_structure_list(structure, oc_list) + + output_files = [] + keep_number = min(self.keep_number, len(strategies)) + + for index in range(keep_number): + merged = self._merge_structures(st_list, strategies[index]) + + # 添加化合价 + self._add_oxidation_states(merged) + + # 当只保存1个时不加后缀 + if keep_number == 1: + output_filename = f"{base_name}.cif" + else: + suffix = "x{}y{}z{}".format( + strategies[index]["x"], + strategies[index]["y"], + strategies[index]["z"] + ) + output_filename = f"{base_name}-{suffix}.cif" + + output_path = os.path.join(output_dir, output_filename) + merged.to(filename=output_path, fmt="cif") + output_files.append(output_path) + + return output_files + + def _add_oxidation_states(self, structure: Structure): + """添加化合价""" + # 检查是否已有化合价 + has_oxidation = all( + all(isinstance(sp, Specie) for sp in site.species.keys()) + for site in structure.sites + ) + + if not has_oxidation and self.valences: + structure.add_oxidation_state_by_element(self.valences) + + def _strategy_divide(self, structure: Structure, total: int) -> List[Dict]: + """根据晶体类型确定扩胞策略""" + try: + space_group_info = structure.get_space_group_info() + space_group_symbol = space_group_info[0] + + # 获取空间群类型 + all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)] + symbol = all_spacegroup_symbols[0] + for symbol_i in all_spacegroup_symbols: + if space_group_symbol == symbol_i.international_short: + symbol = symbol_i + break + + space_type = self._typejudge(symbol.number) + + if space_type == "Cubic": + return self._factorize_to_three_factors(total, "xyz") + else: + return self._factorize_to_three_factors(total) + except: + return self._factorize_to_three_factors(total) + + def _typejudge(self, number: int) -> str: + """判断晶体类型""" + if number in [1, 2]: + return "Triclinic" 
+ elif 3 <= number <= 15: + return "Monoclinic" + elif 16 <= number <= 74: + return "Orthorhombic" + elif 75 <= number <= 142: + return "Tetragonal" + elif 143 <= number <= 167: + return "Trigonal" + elif 168 <= number <= 194: + return "Hexagonal" + elif 195 <= number <= 230: + return "Cubic" + else: + return "Unknown" + + def _factorize_to_three_factors(self, n: int, type_sym: str = None) -> List[Dict]: + """分解为三个因子""" + factors = [] + + if type_sym == "xyz": + for x in range(1, n + 1): + if n % x == 0: + remaining_n = n // x + for y in range(1, remaining_n + 1): + if remaining_n % y == 0 and y <= x: + z = remaining_n // y + if z <= y: + factors.append({'x': x, 'y': y, 'z': z}) + else: + for x in range(1, n + 1): + if n % x == 0: + remaining_n = n // x + for y in range(1, remaining_n + 1): + if remaining_n % y == 0: + z = remaining_n // y + factors.append({'x': x, 'y': y, 'z': z}) + + # 排序 + def sort_key(item): + return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x']) + + return sorted(factors, key=sort_key) + + def _generate_structure_list( + self, + base_structure: Structure, + occupation_list: List[Dict] + ) -> List[Structure]: + """生成结构列表""" + if not occupation_list: + return [base_structure.copy()] + + lcm = occupation_list[0]["denominator"] + structure_list = [base_structure.copy() for _ in range(lcm)] + + for entry in occupation_list: + numerator = entry["numerator"] + denominator = entry["denominator"] + atom_indices = entry["atom_serial"] + + for atom_idx in atom_indices: + occupancy_dict = self._mark_atoms_randomly(numerator, denominator) + original_site = base_structure.sites[atom_idx - 1] + element = self._get_first_non_explicit_element(original_site.species_string) + + for copy_idx, occupy in occupancy_dict.items(): + structure_list[copy_idx].remove_sites([atom_idx - 1]) + oxi_state = self._extract_oxi_state(original_site.species_string, element) + + if len(entry["split"]) == 1: + if occupy: + new_site = PeriodicSite( + species=Species(element, oxi_state), + coords=original_site.frac_coords, + lattice=structure_list[copy_idx].lattice, + to_unit_cell=True, + label=original_site.label + ) + else: + species_dict = {Species(self.target_cation, 1.0): 0.0} + new_site = PeriodicSite( + species=species_dict, + coords=original_site.frac_coords, + lattice=structure_list[copy_idx].lattice, + to_unit_cell=True, + label=original_site.label + ) + else: + if occupy: + new_site = PeriodicSite( + species=Species(element, oxi_state), + coords=original_site.frac_coords, + lattice=structure_list[copy_idx].lattice, + to_unit_cell=True, + label=original_site.label + ) + else: + new_site = PeriodicSite( + species=Species(entry['split'][1], oxi_state), + coords=original_site.frac_coords, + lattice=structure_list[copy_idx].lattice, + to_unit_cell=True, + label=original_site.label + ) + + structure_list[copy_idx].sites.insert(atom_idx - 1, new_site) + + return structure_list + + def _mark_atoms_randomly(self, numerator: int, denominator: int) -> Dict[int, int]: + """随机标记原子""" + if numerator > denominator: + raise ValueError(f"numerator ({numerator}) 不能超过 denominator ({denominator})") + + atom_dice = list(range(denominator)) + selected_atoms = random.sample(atom_dice, numerator) + + return {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice} + + def _get_first_non_explicit_element(self, species_str: str) -> str: + """获取第一个非目标元素""" + if not species_str.strip(): + return "" + + species_parts = [part.strip() for part in species_str.split(",") if part.strip()] + + for part 
in species_parts: + element_with_charge = part.split(":")[0].strip() + pure_element = ''.join([c for c in element_with_charge if c.isalpha()]) + + if pure_element not in self.explict_element: + return pure_element + + return "" + + def _extract_oxi_state(self, species_str: str, element: str) -> int: + """提取氧化态""" + species_parts = [part.strip() for part in species_str.split(",") if part.strip()] + + for part in species_parts: + element_with_charge = part.split(":")[0].strip() + + if element in element_with_charge: + charge_part = element_with_charge[len(element):] + + if not any(c.isdigit() for c in charge_part): + if "+" in charge_part: + return 1 + elif "-" in charge_part: + return -1 + else: + return 0 + + sign = 1 + if "-" in charge_part: + sign = -1 + + digits = "" + for c in charge_part: + if c.isdigit(): + digits += c + + if digits: + return sign * int(digits) + + return 0 + + def _merge_structures(self, structure_list: List[Structure], merge_dict: Dict) -> Structure: + """合并结构""" + if not structure_list: + raise ValueError("结构列表不能为空") + + ref_lattice = structure_list[0].lattice + + total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1) + if len(structure_list) != total_merge: + raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配") + + a, b, c = ref_lattice.abc + alpha, beta, gamma = ref_lattice.angles + + new_a = a * merge_dict.get("x", 1) + new_b = b * merge_dict.get("y", 1) + new_c = c * merge_dict.get("z", 1) + new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma) + + all_sites = [] + for i, structure in enumerate(structure_list): + x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1) + y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1) + z_offset = i % merge_dict.get("z", 1) + + for site in structure: + coords = site.frac_coords.copy() + coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1) + coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1) + coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1) + all_sites.append({"species": site.species, "coords": coords}) + + return Structure( + new_lattice, + [site["species"] for site in all_sites], + [site["coords"] for site in all_sites] + ) + + +def process_batch( + input_files: List[str], + output_dir: str, + needs_expansion_flags: List[bool] = None, + valence_yaml_path: str = None, + calculate_type: str = 'low', + target_cation: str = "Li", + show_progress: bool = True +) -> List[ProcessingResult]: + """ + 批量处理CIF文件 + + Args: + input_files: 输入文件列表 + output_dir: 输出目录 + needs_expansion_flags: 是否需要扩胞的标记列表 + valence_yaml_path: 化合价配置文件路径 + calculate_type: 扩胞计算精度 + target_cation: 目标阳离子 + show_progress: 是否显示进度 + + Returns: + 处理结果列表 + """ + os.makedirs(output_dir, exist_ok=True) + + processor = StructureProcessor( + valence_yaml_path=valence_yaml_path, + calculate_type=calculate_type, + target_cation=target_cation + ) + + if needs_expansion_flags is None: + needs_expansion_flags = [False] * len(input_files) + + results = [] + total = len(input_files) + + for i, (input_file, needs_exp) in enumerate(zip(input_files, needs_expansion_flags)): + if show_progress: + print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="") + + result = processor.process_file(input_file, output_dir, needs_exp) + results.append(result) + + if show_progress: + print() + + return results diff --git a/tool/Li/Br+O.yaml b/tool/Li/Br+O.yaml new file mode 100644 index 0000000..02f3d68 --- /dev/null +++ 
b/tool/Li/Br+O.yaml @@ -0,0 +1,5 @@ +SPECIE: Li+ +ANION: Br +PERCO_R: 0.45 +NEIGHBOR: 1.8 +LONG: 2.2 diff --git a/tool/Li/Cl+O.yaml b/tool/Li/Cl+O.yaml new file mode 100644 index 0000000..a7cb541 --- /dev/null +++ b/tool/Li/Cl+O.yaml @@ -0,0 +1,5 @@ +SPECIE: Li+ +ANION: Cl +PERCO_R: 0.45 +NEIGHBOR: 1.8 +LONG: 2.2
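
A minimal end-to-end sketch (illustrative only, not part of this patch) of how the modules added above are meant to fit together once a workspace has been prepared. The WorkspaceManager constructor arguments and the output_dir value are assumptions based on how these names are used elsewhere in the project:

from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig

# Assumed setup: workspace/data/<anion>/<structure>/ already contains the processed CIFs
# plus YAML symlinks such as Cl+O.yaml above, created during preprocessing.
ws = WorkspaceManager(workspace_path="workspace", tool_dir="tool", target_cation="Li")

tasks = ws.get_computation_tasks()              # one dict per (structure, CIF) pair
executor = ZeoExecutor(ZeoConfig(partition="cpu", max_concurrent=50))
results = executor.run_batch(tasks, output_dir="slurm_logs")
executor.print_results_summary(results)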