""" 高通量筛选与扩胞项目 - 主入口(支持断点续做) """ import os import sys import json sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) from src.analysis.database_analyzer import DatabaseAnalyzer from src.analysis.report_generator import ReportGenerator from src.core.executor import TaskExecutor from src.preprocessing.processor import StructureProcessor from src.computation.workspace_manager import WorkspaceManager from src.computation.zeo_executor import ZeoExecutor, ZeoConfig from src.computation.result_processor import ResultProcessor, FilterCriteria def print_banner(): print(""" ╔═══════════════════════════════════════════════════════════════════╗ ║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.2 ║ ║ 支持断点续做与高性能并行计算 ║ ╚═══════════════════════════════════════════════════════════════════╝ """) def detect_and_show_environment(): """检测并显示环境信息""" env = TaskExecutor.detect_environment() print("【运行环境检测】") print(f" 主机名: {env['hostname']}") print(f" 本地CPU核数: {env['total_cores']}") print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}") if env['has_slurm'] and env['slurm_partitions']: print(f" 可用分区: {', '.join(env['slurm_partitions'])}") if env['conda_env']: print(f" 当前Conda: {env['conda_env']}") return env def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict: """ 检测工作流程状态,确定可以从哪一步继续 Returns: dict: { 'has_processed_data': bool, # 是否有扩胞处理后的数据 'has_zeo_results': bool, # 是否有 Zeo++ 计算结果 'total_structures': int, # 总结构数 'structures_with_log': int, # 有 log.txt 的结构数 'workspace_info': object, # 工作区信息 'ws_manager': object # 工作区管理器 } """ status = { 'has_processed_data': False, 'has_zeo_results': False, 'total_structures': 0, 'structures_with_log': 0, 'workspace_info': None, 'ws_manager': None } # 创建工作区管理器 ws_manager = WorkspaceManager( workspace_path=workspace_path, tool_dir="tool", target_cation=target_cation ) status['ws_manager'] = ws_manager # 检查是否有处理后的数据 data_dir = os.path.join(workspace_path, "data") if os.path.exists(data_dir): existing = ws_manager.check_existing_workspace() if existing and existing.total_structures > 0: status['has_processed_data'] = True status['total_structures'] = existing.total_structures status['workspace_info'] = existing # 检查有多少结构有 log.txt(即已完成 Zeo++ 计算) log_count = 0 for anion_key in os.listdir(data_dir): anion_dir = os.path.join(data_dir, anion_key) if not os.path.isdir(anion_dir): continue for struct_name in os.listdir(anion_dir): struct_dir = os.path.join(anion_dir, struct_name) if os.path.isdir(struct_dir): log_path = os.path.join(struct_dir, "log.txt") if os.path.exists(log_path): log_count += 1 status['structures_with_log'] = log_count # 如果大部分结构都有 log.txt,认为 Zeo++ 计算已完成 if log_count > 0 and log_count >= status['total_structures'] * 0.5: status['has_zeo_results'] = True return status def print_workflow_status(status: dict): """打印工作流程状态""" print("\n" + "─" * 50) print("【工作流程状态检测】") print("─" * 50) if not status['has_processed_data']: print(" Step 1 (扩胞+化合价): ❌ 未完成") print(" Step 2-4 (Zeo++ 计算): ❌ 未完成") print(" Step 5 (结果筛选): ❌ 未完成") else: print(f" Step 1 (扩胞+化合价): ✅ 已完成 ({status['total_structures']} 个结构)") if status['has_zeo_results']: print(f" Step 2-4 (Zeo++ 计算): ✅ 已完成 ({status['structures_with_log']}/{status['total_structures']} 有日志)") print(" Step 5 (结果筛选): ⏳ 可执行") else: print(f" Step 2-4 (Zeo++ 计算): ⏳ 可执行 ({status['structures_with_log']}/{status['total_structures']} 有日志)") print(" Step 5 (结果筛选): ❌ 需先完成 Zeo++ 计算") print("─" * 50) def get_user_choice(status: dict) -> str: """ 根据工作流程状态获取用户选择 Returns: 'step1': 从头开始(数据库分析 + 扩胞) 'step2': 从 Zeo++ 计算开始 'step5': 直接进行结果筛选 'exit': 退出 """ print("\n请选择操作:") options = [] if status['has_zeo_results']: options.append(('5', '直接进行结果筛选 (Step 5)')) options.append(('2', '重新运行 Zeo++ 计算 (Step 2-4)')) options.append(('1', '从头开始 (数据库分析 + 扩胞)')) elif status['has_processed_data']: options.append(('2', '运行 Zeo++ 计算 (Step 2-4)')) options.append(('1', '从头开始 (数据库分析 + 扩胞)')) else: options.append(('1', '从头开始 (数据库分析 + 扩胞)')) options.append(('q', '退出')) for key, desc in options: print(f" {key}. {desc}") choice = input("\n请选择 [默认: " + options[0][0] + "]: ").strip().lower() if not choice: choice = options[0][0] if choice == 'q': return 'exit' elif choice == '5': return 'step5' elif choice == '2': return 'step2' else: return 'step1' def run_step1_database_analysis(env: dict, cation: str) -> dict: """ Step 1: 数据库分析与扩胞处理 Returns: 处理参数字典,如果用户取消则返回 None """ print("\n" + "═" * 60) print("【Step 1: 数据库分析与扩胞处理】") print("═" * 60) # 数据库路径 while True: db_path = input("\n📂 请输入数据库路径: ").strip() if os.path.exists(db_path): break print(f"❌ 路径不存在: {db_path}") # 检测当前Conda环境路径 default_conda = "/cluster/home/koko125/anaconda3/envs/screen" conda_env_path = env.get('conda_env', '') or default_conda print(f"\n检测到Conda环境: {conda_env_path}") custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip() if custom_env.lower() == 'n': conda_env_path = input("请输入Conda环境完整路径: ").strip() elif custom_env and custom_env.lower() != 'y': conda_env_path = custom_env # 目标阴离子 anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip() anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'} # 阴离子模式 print("\n阴离子模式:") print(" 1. 仅单一阴离子") print(" 2. 仅复合阴离子") print(" 3. 全部 (默认)") mode_choice = input("请选择 [1/2/3]: ").strip() anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all') # 并行配置 print("\n" + "─" * 50) print("【并行计算配置】") default_cores = min(env['total_cores'], 32) cores_input = input(f"💻 最大可用核数/Worker数 [默认: {default_cores}]: ").strip() max_workers = int(cores_input) if cores_input.isdigit() else default_cores params = { 'database_path': db_path, 'target_cation': cation, 'target_anions': anions, 'anion_mode': anion_mode, 'max_workers': max_workers, 'conda_env': conda_env_path, } print("\n" + "═" * 60) print("开始数据库分析...") print("═" * 60) # 创建分析器 analyzer = DatabaseAnalyzer( database_path=params['database_path'], target_cation=params['target_cation'], target_anions=params['target_anions'], anion_mode=params['anion_mode'], max_cores=params['max_workers'], task_complexity='medium' ) print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件") # 执行分析 report = analyzer.analyze(show_progress=True) # 打印报告 ReportGenerator.print_report(report, detailed=True) # 保存选项 save_choice = input("\n是否保存报告? [y/N]: ").strip().lower() if save_choice == 'y': output_path = input("报告路径 [默认: analysis_report.json]: ").strip() output_path = output_path or "analysis_report.json" report.save(output_path) print(f"✅ 报告已保存到: {output_path}") # CSV导出 csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower() if csv_choice == 'y': csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip() csv_path = csv_path or "analysis_details.csv" ReportGenerator.export_to_csv(report, csv_path) # 生成最终数据库 process_choice = input("\n是否生成最终可用的数据库(扩胞+添加化合价)? [Y/n]: ").strip().lower() if process_choice == 'n': print("\n已跳过扩胞处理,可稍后继续") return params # 输出目录设置 print("\n输出目录设置:") flat_dir = input(" 原始格式输出目录 [默认: workspace/processed]: ").strip() flat_dir = flat_dir or "workspace/processed" analysis_dir = input(" 分析格式输出目录 [默认: workspace/data]: ").strip() analysis_dir = analysis_dir or "workspace/data" # 扩胞保存数量 keep_input = input("\n扩胞结构保存数量 [默认: 1]: ").strip() keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1 # 扩胞精度选择 print("\n扩胞计算精度:") print(" 1. 高精度 (精确分数)") print(" 2. 普通精度 (分母≤100)") print(" 3. 低精度 (分母≤10) [默认]") print(" 4. 极低精度 (分母≤5)") precision_choice = input("请选择 [1/2/3/4]: ").strip() calculate_type = { '1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low' }.get(precision_choice, 'low') # 获取可处理文件 processable = report.get_processable_files(include_needs_expansion=True) if not processable: print("⚠️ 没有可处理的文件") return params print(f"\n发现 {len(processable)} 个可处理的文件") # 准备文件列表和扩胞标记 input_files = [info.file_path for info in processable] needs_expansion_flags = [info.needs_expansion for info in processable] anion_types_list = [info.anion_types for info in processable] direct_count = sum(1 for f in needs_expansion_flags if not f) expansion_count = sum(1 for f in needs_expansion_flags if f) print(f" - 可直接处理: {direct_count}") print(f" - 需要扩胞: {expansion_count}") print(f" - 扩胞保存数: {keep_number}") confirm = input("\n确认开始处理? [Y/n]: ").strip().lower() if confirm == 'n': print("\n已取消处理") return params print("\n" + "═" * 60) print("开始处理结构文件...") print("═" * 60) # 创建处理器 processor = StructureProcessor( calculate_type=calculate_type, keep_number=keep_number, target_cation=params['target_cation'] ) # 创建输出目录 os.makedirs(flat_dir, exist_ok=True) os.makedirs(analysis_dir, exist_ok=True) results = [] total = len(input_files) import shutil for i, (input_file, needs_exp, anion_types) in enumerate( zip(input_files, needs_expansion_flags, anion_types_list) ): print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="") # 处理文件到原始格式目录 result = processor.process_file(input_file, flat_dir, needs_exp) results.append(result) if result.success: # 同时保存到分析格式目录 # 按阴离子类型创建子目录 anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other' anion_dir = os.path.join(analysis_dir, anion_key) os.makedirs(anion_dir, exist_ok=True) # 获取基础文件名 base_name = os.path.splitext(os.path.basename(input_file))[0] # 创建以文件名命名的子目录 file_dir = os.path.join(anion_dir, base_name) os.makedirs(file_dir, exist_ok=True) # 复制生成的文件到分析格式目录 for output_file in result.output_files: dst_path = os.path.join(file_dir, os.path.basename(output_file)) shutil.copy2(output_file, dst_path) print() # 统计结果 success_count = sum(1 for r in results if r.success) fail_count = sum(1 for r in results if not r.success) total_output = sum(len(r.output_files) for r in results if r.success) print("\n" + "-" * 60) print("【处理结果统计】") print("-" * 60) print(f" 成功处理: {success_count}") print(f" 处理失败: {fail_count}") print(f" 生成文件: {total_output}") print(f"\n 原始格式目录: {flat_dir}") print(f" 分析格式目录: {analysis_dir}") print(f" └── 结构: data/阴离子类型/文件名/文件名.cif") # 显示失败的文件 if fail_count > 0: print("\n失败的文件:") for r in results: if not r.success: print(f" - {os.path.basename(r.input_file)}: {r.error_message}") print("\n✅ Step 1 完成!") return params def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info = None): """ Step 2-4: Zeo++ Voronoi 分析 Args: params: 参数字典,包含 target_cation 等 ws_manager: 工作区管理器(可选,如果已有则直接使用) workspace_info: 工作区信息(可选,如果已有则直接使用) """ print("\n" + "═" * 60) print("【Step 2-4: Zeo++ Voronoi 分析】") print("═" * 60) # 如果没有传入 ws_manager,则创建新的 if ws_manager is None: # 工作区路径 workspace_path = input("\n工作区路径 [默认: workspace]: ").strip() or "workspace" tool_dir = input("工具目录路径 [默认: tool]: ").strip() or "tool" # 创建工作区管理器 ws_manager = WorkspaceManager( workspace_path=workspace_path, tool_dir=tool_dir, target_cation=params.get('target_cation', 'Li') ) # 如果没有传入 workspace_info,则检查现有工作区 if workspace_info is None: existing = ws_manager.check_existing_workspace() if existing and existing.total_structures > 0: ws_manager.print_workspace_summary(existing) workspace_info = existing else: print("⚠️ 工作区数据目录不存在或为空") print(f" 请先运行 Step 1 生成处理后的数据到: {ws_manager.data_dir}") return # 检查是否需要创建软链接 if workspace_info.linked_structures < workspace_info.total_structures: print("\n正在创建配置文件软链接...") workspace_info = ws_manager.setup_workspace(force_relink=False) # 获取计算任务 tasks = ws_manager.get_computation_tasks(workspace_info) if not tasks: print("⚠️ 没有找到可计算的任务") return print(f"\n发现 {len(tasks)} 个计算任务") # Zeo++ 环境配置 print("\n" + "-" * 50) print("【Zeo++ 计算配置】") default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo" zeo_env = input(f"Zeo++ Conda环境 [默认: {default_zeo_env}]: ").strip() zeo_env = zeo_env or default_zeo_env # SLURM 配置 partition = input("SLURM分区 [默认: cpu]: ").strip() or "cpu" max_concurrent_input = input("最大并发任务数 [默认: 50]: ").strip() max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50 time_limit = input("单任务时间限制 [默认: 2:00:00]: ").strip() or "2:00:00" # 创建配置 zeo_config = ZeoConfig( conda_env=zeo_env, partition=partition, max_concurrent=max_concurrent, time_limit=time_limit ) # 确认执行 print("\n" + "-" * 50) print("【计算任务确认】") print(f" 总任务数: {len(tasks)}") print(f" Conda环境: {zeo_config.conda_env}") print(f" SLURM分区: {zeo_config.partition}") print(f" 最大并发: {zeo_config.max_concurrent}") print(f" 时间限制: {zeo_config.time_limit}") confirm = input("\n确认提交计算任务? [Y/n]: ").strip().lower() if confirm == 'n': print("已取消") return # 创建执行器并运行 executor = ZeoExecutor(zeo_config) log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs") results = executor.run_batch(tasks, output_dir=log_dir) # 打印结果摘要 executor.print_results_summary(results) # 保存结果 results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json") results_data = [ { 'task_id': r.task_id, 'structure_name': r.structure_name, 'cif_path': r.cif_path, 'success': r.success, 'output_files': r.output_files, 'error_message': r.error_message } for r in results ] with open(results_file, 'w') as f: json.dump(results_data, f, indent=2) print(f"\n结果已保存到: {results_file}") print("\n✅ Step 2-4 完成!") def run_step5_result_processing(workspace_path: str = "workspace"): """ Step 5: 结果处理与筛选 Args: workspace_path: 工作区路径 """ print("\n" + "═" * 60) print("【Step 5: 结果处理与筛选】") print("═" * 60) # 创建结果处理器 processor = ResultProcessor(workspace_path=workspace_path) # 获取筛选条件 print("\n设置筛选条件:") print(" (直接回车使用默认值,输入 0 表示不限制)") # 最小渗透直径(默认改为 1.0) perc_input = input(" 最小渗透直径 (Å) [默认: 1.0]: ").strip() min_percolation = float(perc_input) if perc_input else 1.0 # 最小 d 值 d_input = input(" 最小 d 值 [默认: 2.0]: ").strip() min_d = float(d_input) if d_input else 2.0 # 最大节点长度 node_input = input(" 最大节点长度 (Å) [默认: 不限制]: ").strip() max_node = float(node_input) if node_input else float('inf') # 创建筛选条件 criteria = FilterCriteria( min_percolation_diameter=min_percolation, min_d_value=min_d, max_node_length=max_node ) # 执行处理 results, stats = processor.process_and_filter( criteria=criteria, save_csv=True, copy_passed=True ) # 打印摘要 processor.print_summary(results, stats) # 显示输出位置 print("\n输出文件位置:") print(f" 汇总 CSV: {workspace_path}/results/summary.csv") print(f" 分类 CSV: {workspace_path}/results/阴离子类型/阴离子类型.csv") print(f" 通过筛选: {workspace_path}/passed/阴离子类型/结构名/") print("\n✅ Step 5 完成!") def main(): """主函数""" print_banner() # 环境检测 env = detect_and_show_environment() # 询问目标阳离子 cation = input("\n🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li" # 检测工作流程状态 status = detect_workflow_status(workspace_path="workspace", target_cation=cation) print_workflow_status(status) # 获取用户选择 choice = get_user_choice(status) if choice == 'exit': print("\n👋 再见!") return # 根据选择执行相应步骤 params = {'target_cation': cation} if choice == 'step1': # 从头开始 params = run_step1_database_analysis(env, cation) if params is None: return # 询问是否继续 continue_choice = input("\n是否继续进行 Zeo++ 计算? [Y/n]: ").strip().lower() if continue_choice == 'n': print("\n可稍后运行程序继续 Zeo++ 计算") return # 重新检测状态 status = detect_workflow_status(workspace_path="workspace", target_cation=cation) run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info']) # 询问是否继续筛选 continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower() if continue_choice == 'n': print("\n可稍后运行程序继续结果筛选") return run_step5_result_processing("workspace") elif choice == 'step2': # 从 Zeo++ 计算开始 run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info']) # 询问是否继续筛选 continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower() if continue_choice == 'n': print("\n可稍后运行程序继续结果筛选") return run_step5_result_processing("workspace") elif choice == 'step5': # 直接进行结果筛选 run_step5_result_processing("workspace") print("\n✅ 全部完成!") if __name__ == "__main__": main()