diff --git a/main.py b/main.py index 4994d7b..3220d07 100644 --- a/main.py +++ b/main.py @@ -1,71 +1,112 @@ """ -高通量筛选与扩胞项目 - 主入口 -交互式命令行界面 +高通量筛选与扩胞项目 - 主入口(支持并行) """ import os import sys -# 添加 src 到路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) -from src.analysis.database_analyzer import DatabaseAnalyzer -from src.analysis.report_generator import ReportGenerator +from analysis.database_analyzer import DatabaseAnalyzer +from analysis.report_generator import ReportGenerator +from core.scheduler import ParallelScheduler -def get_user_input(): +def print_banner(): + print(""" +╔═══════════════════════════════════════════════════════════════════╗ +║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.0 ║ +║ 支持高性能并行计算 ║ +╚═══════════════════════════════════════════════════════════════════╝ + """) + + +def detect_and_show_environment(): + """检测并显示环境信息""" + env = ParallelScheduler.detect_environment() + + print("【运行环境检测】") + print(f" 主机名: {env['hostname']}") + print(f" 本地CPU核数: {env['total_cores']}") + print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}") + + if env['has_slurm'] and env['slurm_partitions']: + print(f" 可用分区:") + for p in env['slurm_partitions']: + print(f" - {p['name']}: {p['nodes']}节点, {p['cores_per_node']}核/节点") + + return env + + +def get_user_input(env: dict): """获取用户输入""" - print("\n" + "=" * 70) - print(" 高通量筛选与扩胞项目 - 数据库分析工具") - print("=" * 70) - # 1. 获取数据库路径 + # 数据库路径 while True: - db_path = input("\n请输入数据库路径: ").strip() + db_path = input("\n📂 请输入数据库路径: ").strip() if os.path.exists(db_path): break print(f"❌ 路径不存在: {db_path}") - # 2. 获取目标阳离子 - cation = input("请输入目标阳离子 [默认: Li]: ").strip() or "Li" + # 目标阳离子 + cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li" - # 3. 获取目标阴离子 - anion_input = input("请输入目标阴离子 (用逗号分隔) [默认: O,S,Cl,Br]: ").strip() - if anion_input: - anions = set(a.strip() for a in anion_input.split(',')) - else: - anions = {'O', 'S', 'Cl', 'Br'} + # 目标阴离子 + anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip() + anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'} - # 4. 选择阴离子模式 - print("\n阴离子模式选择:") - print(" 1. 仅单一阴离子化合物") - print(" 2. 仅复合阴离子化合物") + # 阴离子模式 + print("\n阴离子模式:") + print(" 1. 仅单一阴离子") + print(" 2. 仅复合阴离子") print(" 3. 全部 (默认)") mode_choice = input("请选择 [1/2/3]: ").strip() + anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all') - mode_map = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'} - anion_mode = mode_map.get(mode_choice, 'all') + # 并行配置 + print("\n" + "─" * 50) + print("【并行计算配置】") - # 5. 并行数 - n_jobs_input = input("并行线程数 [默认: 4]: ").strip() - n_jobs = int(n_jobs_input) if n_jobs_input.isdigit() else 4 + default_cores = min(env['total_cores'], 32) + cores_input = input(f"💻 最大可用核数 [默认: {default_cores}]: ").strip() + max_cores = int(cores_input) if cores_input.isdigit() else default_cores + + print("\n任务复杂度 (影响每个Worker分配的核数):") + print(" 1. 低 (1核/Worker) - 简单IO操作") + print(" 2. 中 (2核/Worker) - 结构解析+检查 [默认]") + print(" 3. 高 (4核/Worker) - 复杂计算") + complexity_choice = input("请选择 [1/2/3]: ").strip() + complexity = {'1': 'low', '2': 'medium', '3': 'high', '': 'medium'}.get(complexity_choice, 'medium') + + # 执行模式 + use_slurm = False + if env['has_slurm']: + slurm_choice = input("\n是否使用SLURM提交作业? 
[y/N]: ").strip().lower() + use_slurm = slurm_choice == 'y' return { 'database_path': db_path, 'target_cation': cation, 'target_anions': anions, 'anion_mode': anion_mode, - 'n_jobs': n_jobs + 'max_cores': max_cores, + 'task_complexity': complexity, + 'use_slurm': use_slurm } def main(): """主函数""" - # 获取用户输入 - params = get_user_input() + print_banner() - print("\n" + "-" * 70) - print("开始分析数据库...") - print("-" * 70) + # 环境检测 + env = detect_and_show_environment() + + # 获取用户输入 + params = get_user_input(env) + + print("\n" + "═" * 60) + print("开始数据库分析...") + print("═" * 60) # 创建分析器 analyzer = DatabaseAnalyzer( @@ -73,30 +114,42 @@ def main(): target_cation=params['target_cation'], target_anions=params['target_anions'], anion_mode=params['anion_mode'], - n_jobs=params['n_jobs'] + max_cores=params['max_cores'], + task_complexity=params['task_complexity'] ) - # 执行分析 - report = analyzer.analyze(show_progress=True) + print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件") - # 打印报告 - ReportGenerator.print_report(report, detailed=True) + if params['use_slurm']: + # SLURM模式 + output_dir = input("输出目录 [默认: ./slurm_output]: ").strip() or "./slurm_output" + job_id = analyzer.analyze_slurm(output_dir=output_dir) + print(f"\n✅ SLURM作业已提交: {job_id}") + print(f" 使用 'squeue -j {job_id}' 查看状态") + print(f" 结果将保存到: {output_dir}") + else: + # 本地模式 + report = analyzer.analyze(show_progress=True) - # 询问是否导出 - export = input("\n是否导出详细结果到CSV? [y/N]: ").strip().lower() - if export == 'y': - output_path = input("输出文件路径 [默认: analysis_report.csv]: ").strip() - output_path = output_path or "analysis_report.csv" - ReportGenerator.export_to_csv(report, output_path) + # 打印报告 + ReportGenerator.print_report(report, detailed=True) - # 询问是否继续处理 - print("\n" + "-" * 70) - proceed = input("是否继续进行预处理? [y/N]: ").strip().lower() - if proceed == 'y': - print("预处理功能将在下一阶段实现...") - # TODO: 调用预处理模块 + # 保存选项 + save_choice = input("\n是否保存报告? [y/N]: ").strip().lower() + if save_choice == 'y': + output_path = input("报告路径 [默认: analysis_report.json]: ").strip() + output_path = output_path or "analysis_report.json" + report.save(output_path) + print(f"✅ 报告已保存到: {output_path}") - print("\n分析完成!") + # CSV导出 + csv_choice = input("是否导出详细CSV? 
[y/N]: ").strip().lower() + if csv_choice == 'y': + csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip() + csv_path = csv_path or "analysis_details.csv" + ReportGenerator.export_to_csv(report, csv_path) + + print("\n✅ 分析完成!") if __name__ == "__main__": diff --git a/src/analysis/database_analyzer.py b/src/analysis/database_analyzer.py index 8ca776a..5008ea8 100644 --- a/src/analysis/database_analyzer.py +++ b/src/analysis/database_analyzer.py @@ -1,13 +1,16 @@ """ -数据库分析器:分析整个CIF数据库的构成和质量 +数据库分析器:支持高性能并行分析 """ import os -from dataclasses import dataclass, field +import pickle +import json +from dataclasses import dataclass, field, asdict from typing import Dict, List, Set, Optional -from concurrent.futures import ThreadPoolExecutor, as_completed -from tqdm import tqdm +from pathlib import Path from .structure_inspector import StructureInspector, StructureInfo +from .worker import analyze_single_file +from ..core.scheduler import ParallelScheduler, ResourceConfig @dataclass @@ -23,13 +26,13 @@ class DatabaseReport: # 目标元素统计 target_cation: str = "" target_anions: Set[str] = field(default_factory=set) - anion_mode: str = "" # "single", "mixed", "all" + anion_mode: str = "" # 含目标阳离子的统计 cation_containing_count: int = 0 cation_containing_ratio: float = 0.0 - # 阴离子分布 (在含目标阳离子的化合物中) + # 阴离子分布 anion_distribution: Dict[str, int] = field(default_factory=dict) anion_ratios: Dict[str, float] = field(default_factory=dict) single_anion_count: int = 0 @@ -38,11 +41,9 @@ class DatabaseReport: # 数据质量统计 with_oxidation_states: int = 0 without_oxidation_states: int = 0 - - needs_expansion_count: int = 0 # 需要扩胞的数量 - cation_partial_occupancy_count: int = 0 # 阳离子共占位 - anion_partial_occupancy_count: int = 0 # 阴离子共占位 - + needs_expansion_count: int = 0 + cation_partial_occupancy_count: int = 0 + anion_partial_occupancy_count: int = 0 binary_compound_count: int = 0 has_water_count: int = 0 has_radioactive_count: int = 0 @@ -56,17 +57,39 @@ class DatabaseReport: all_structures: List[StructureInfo] = field(default_factory=list) skip_reasons_summary: Dict[str, int] = field(default_factory=dict) + def to_dict(self) -> dict: + """转换为可序列化的字典""" + d = asdict(self) + d['target_anions'] = list(self.target_anions) + d['all_structures'] = [asdict(s) for s in self.all_structures] + return d + + def save(self, path: str): + """保存报告""" + with open(path, 'w', encoding='utf-8') as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) + + @classmethod + def load(cls, path: str) -> 'DatabaseReport': + """加载报告""" + with open(path, 'r', encoding='utf-8') as f: + d = json.load(f) + d['target_anions'] = set(d['target_anions']) + d['all_structures'] = [StructureInfo(**s) for s in d['all_structures']] + return cls(**d) + class DatabaseAnalyzer: - """数据库分析器""" + """数据库分析器 - 支持高性能并行""" def __init__( - self, - database_path: str, - target_cation: str = "Li", - target_anions: Set[str] = None, - anion_mode: str = "all", # "single", "mixed", "all" - n_jobs: int = 4 + self, + database_path: str, + target_cation: str = "Li", + target_anions: Set[str] = None, + anion_mode: str = "all", + max_cores: int = 4, + task_complexity: str = "medium" ): """ 初始化分析器 @@ -75,53 +98,27 @@ class DatabaseAnalyzer: database_path: 数据库路径 target_cation: 目标阳离子 target_anions: 目标阴离子集合 - anion_mode: 阴离子模式 ("single"=仅单一, "mixed"=仅复合, "all"=全部) - n_jobs: 并行数 + anion_mode: 阴离子模式 + max_cores: 最大可用核数 + task_complexity: 任务复杂度 ('low', 'medium', 'high') """ self.database_path = database_path self.target_cation = target_cation self.target_anions = target_anions or 
{'O', 'S', 'Cl', 'Br'} self.anion_mode = anion_mode - self.n_jobs = n_jobs + self.max_cores = max_cores + self.task_complexity = task_complexity - self.inspector = StructureInspector( - target_cation=target_cation, - target_anions=self.target_anions + # 获取文件列表 + self.cif_files = self._get_cif_files() + + # 配置调度器 + self.resource_config = ParallelScheduler.recommend_config( + num_tasks=len(self.cif_files), + task_complexity=task_complexity, + max_cores=max_cores ) - - def analyze(self, show_progress: bool = True) -> DatabaseReport: - """ - 分析数据库 - - Args: - show_progress: 是否显示进度条 - - Returns: - DatabaseReport: 分析报告 - """ - report = DatabaseReport( - database_path=self.database_path, - target_cation=self.target_cation, - target_anions=self.target_anions, - anion_mode=self.anion_mode - ) - - # 获取所有CIF文件 - cif_files = self._get_cif_files() - report.total_files = len(cif_files) - - if report.total_files == 0: - print(f"警告: 在 {self.database_path} 中未找到CIF文件") - return report - - # 并行分析所有文件 - results = self._analyze_files(cif_files, show_progress) - report.all_structures = results - - # 统计结果 - self._compute_statistics(report) - - return report + self.scheduler = ParallelScheduler(self.resource_config) def _get_cif_files(self) -> List[str]: """获取所有CIF文件路径""" @@ -136,57 +133,111 @@ class DatabaseAnalyzer: if f.endswith('.cif'): cif_files.append(os.path.join(root, f)) - return cif_files + return sorted(cif_files) - def _analyze_files( - self, - cif_files: List[str], - show_progress: bool - ) -> List[StructureInfo]: - """并行分析文件""" - results = [] + def analyze(self, show_progress: bool = True) -> DatabaseReport: + """ + 执行并行分析 - if self.n_jobs == 1: - # 单线程 - iterator = tqdm(cif_files, desc="分析CIF文件") if show_progress else cif_files - for f in iterator: - results.append(self.inspector.inspect(f)) - else: - # 多线程 - with ThreadPoolExecutor(max_workers=self.n_jobs) as executor: - futures = {executor.submit(self.inspector.inspect, f): f for f in cif_files} + Args: + show_progress: 是否显示进度 - iterator = tqdm(as_completed(futures), total=len(futures), desc="分析CIF文件") \ - if show_progress else as_completed(futures) + Returns: + DatabaseReport: 分析报告 + """ + report = DatabaseReport( + database_path=self.database_path, + target_cation=self.target_cation, + target_anions=self.target_anions, + anion_mode=self.anion_mode, + total_files=len(self.cif_files) + ) - for future in iterator: - try: - results.append(future.result()) - except Exception as e: - print(f"分析失败: {e}") + if report.total_files == 0: + print(f"⚠️ 警告: 在 {self.database_path} 中未找到CIF文件") + return report - return results + # 准备任务 + tasks = [ + (f, self.target_cation, self.target_anions) + for f in self.cif_files + ] + + # 执行并行分析 + results = self.scheduler.run_local( + tasks=tasks, + worker_func=analyze_single_file, + desc="分析CIF文件" + ) + + # 过滤有效结果 + report.all_structures = [r for r in results if r is not None] + + # 统计结果 + self._compute_statistics(report) + + return report + + def analyze_slurm( + self, + output_dir: str, + job_name: str = "cif_analysis" + ) -> str: + """ + 提交SLURM作业进行分析 + + Args: + output_dir: 输出目录 + job_name: 作业名称 + + Returns: + 作业ID + """ + os.makedirs(output_dir, exist_ok=True) + + # 保存任务配置 + tasks_file = os.path.join(output_dir, "tasks.json") + with open(tasks_file, 'w') as f: + json.dump({ + 'files': self.cif_files, + 'target_cation': self.target_cation, + 'target_anions': list(self.target_anions), + 'anion_mode': self.anion_mode + }, f) + + # 生成SLURM脚本 + worker_script = os.path.join( + os.path.dirname(__file__), 'worker.py' + ) + 
script = self.scheduler.generate_slurm_script( + tasks_file=tasks_file, + worker_script=worker_script, + output_dir=output_dir, + job_name=job_name + ) + + # 保存并提交 + script_path = os.path.join(output_dir, "submit.sh") + return self.scheduler.submit_slurm_job(script, script_path) def _compute_statistics(self, report: DatabaseReport): """计算统计数据""" for info in report.all_structures: - # 有效性统计 if info.is_valid: report.valid_files += 1 else: report.invalid_files += 1 continue - # 含目标阳离子统计 if not info.contains_target_cation: continue report.cation_containing_count += 1 - # 阴离子分布 for anion in info.anion_types: - report.anion_distribution[anion] = report.anion_distribution.get(anion, 0) + 1 + report.anion_distribution[anion] = \ + report.anion_distribution.get(anion, 0) + 1 if info.anion_mode == "single": report.single_anion_count += 1 @@ -201,21 +252,18 @@ class DatabaseAnalyzer: if info.anion_mode == "none": continue - # 氧化态统计 + # 各项统计 if info.has_oxidation_states: report.with_oxidation_states += 1 else: report.without_oxidation_states += 1 - # 共占位统计 if info.needs_expansion: report.needs_expansion_count += 1 if info.cation_has_partial_occupancy: report.cation_partial_occupancy_count += 1 if info.anion_has_partial_occupancy: report.anion_partial_occupancy_count += 1 - - # 其他问题统计 if info.is_binary_compound: report.binary_compound_count += 1 if info.has_water_molecule: @@ -223,7 +271,7 @@ class DatabaseAnalyzer: if info.has_radioactive_elements: report.has_radioactive_count += 1 - # 可处理性统计 + # 可处理性 if info.can_process: if info.needs_expansion: report.needs_preprocessing += 1 @@ -231,7 +279,6 @@ class DatabaseAnalyzer: report.directly_processable += 1 else: report.cannot_process += 1 - # 统计跳过原因 if info.skip_reason: for reason in info.skip_reason.split("; "): report.skip_reasons_summary[reason] = \ @@ -239,8 +286,10 @@ class DatabaseAnalyzer: # 计算比例 if report.valid_files > 0: - report.cation_containing_ratio = report.cation_containing_count / report.valid_files + report.cation_containing_ratio = \ + report.cation_containing_count / report.valid_files if report.cation_containing_count > 0: for anion, count in report.anion_distribution.items(): - report.anion_ratios[anion] = count / report.cation_containing_count \ No newline at end of file + report.anion_ratios[anion] = \ + count / report.cation_containing_count \ No newline at end of file diff --git a/src/analysis/worker.py b/src/analysis/worker.py new file mode 100644 index 0000000..92ebe56 --- /dev/null +++ b/src/analysis/worker.py @@ -0,0 +1,120 @@ +""" +工作进程:处理单个分析任务 +设计为可以独立运行(用于SLURM作业数组) +""" +import os +import pickle +from typing import List, Tuple, Optional +from dataclasses import asdict + +from .structure_inspector import StructureInspector, StructureInfo + + +def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]: + """ + 分析单个CIF文件(Worker函数) + + Args: + args: (file_path, target_cation, target_anions) + + Returns: + StructureInfo 或 None(如果分析失败) + """ + file_path, target_cation, target_anions = args + + try: + inspector = StructureInspector( + target_cation=target_cation, + target_anions=target_anions + ) + return inspector.inspect(file_path) + except Exception as e: + # 返回一个标记失败的结果 + return StructureInfo( + file_path=file_path, + file_name=os.path.basename(file_path), + is_valid=False, + error_message=str(e) + ) + + +def batch_analyze( + file_paths: List[str], + target_cation: str, + target_anions: set, + output_file: str = None +) -> List[StructureInfo]: + """ + 批量分析文件(用于SLURM子任务) + + Args: + file_paths: CIF文件路径列表 
+ target_cation: 目标阳离子 + target_anions: 目标阴离子集合 + output_file: 输出文件路径(pickle格式) + + Returns: + StructureInfo列表 + """ + results = [] + + inspector = StructureInspector( + target_cation=target_cation, + target_anions=target_anions + ) + + for file_path in file_paths: + try: + info = inspector.inspect(file_path) + results.append(info) + except Exception as e: + results.append(StructureInfo( + file_path=file_path, + file_name=os.path.basename(file_path), + is_valid=False, + error_message=str(e) + )) + + # 保存结果 + if output_file: + with open(output_file, 'wb') as f: + pickle.dump([asdict(r) for r in results], f) + + return results + + +# 用于SLURM作业数组的命令行入口 +if __name__ == "__main__": + import argparse + import json + + parser = argparse.ArgumentParser(description="CIF Analysis Worker") + parser.add_argument("--tasks-file", required=True, help="任务文件路径(JSON)") + parser.add_argument("--output-dir", required=True, help="输出目录") + parser.add_argument("--task-id", type=int, default=0, help="任务ID(用于数组作业)") + parser.add_argument("--num-workers", type=int, default=1, help="并行worker数") + + args = parser.parse_args() + + # 加载任务 + with open(args.tasks_file, 'r') as f: + task_config = json.load(f) + + file_paths = task_config['files'] + target_cation = task_config['target_cation'] + target_anions = set(task_config['target_anions']) + + # 如果是数组作业,只处理分配的部分 + if args.task_id >= 0: + chunk_size = len(file_paths) // args.num_workers + 1 + start_idx = args.task_id * chunk_size + end_idx = min(start_idx + chunk_size, len(file_paths)) + file_paths = file_paths[start_idx:end_idx] + + # 输出文件 + output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl") + + # 执行分析 + print(f"Worker {args.task_id}: 处理 {len(file_paths)} 个文件") + results = batch_analyze(file_paths, target_cation, target_anions, output_file) + print(f"Worker {args.task_id}: 完成,结果保存到 {output_file}") \ No newline at end of file diff --git a/src/core/progress.py b/src/core/progress.py new file mode 100644 index 0000000..07f157d --- /dev/null +++ b/src/core/progress.py @@ -0,0 +1,115 @@ +""" +进度管理器:支持多进程的实时进度显示 +""" +import os +import sys +import time +from multiprocessing import Manager, Value +from typing import Optional +from datetime import datetime, timedelta + + +class ProgressManager: + """多进程安全的进度管理器""" + + def __init__(self, total: int, desc: str = "Processing"): + self.total = total + self.desc = desc + self.manager = Manager() + self.completed = self.manager.Value('i', 0) + self.failed = self.manager.Value('i', 0) + self.start_time = None + self._lock = self.manager.Lock() + + def start(self): + """开始计时""" + self.start_time = time.time() + + def update(self, success: bool = True): + """更新进度(进程安全)""" + with self._lock: + if success: + self.completed.value += 1 + else: + self.failed.value += 1 + + def get_progress(self) -> dict: + """获取当前进度""" + completed = self.completed.value + failed = self.failed.value + total_done = completed + failed + + elapsed = time.time() - self.start_time if self.start_time else 0 + + if total_done > 0: + speed = total_done / elapsed # items/sec + remaining = (self.total - total_done) / speed if speed > 0 else 0 + else: + speed = 0 + remaining = 0 + + return { + 'total': self.total, + 'completed': completed, + 'failed': failed, + 'total_done': total_done, + 'percent': total_done / self.total * 100 if self.total > 0 else 0, + 'elapsed': elapsed, + 'remaining': remaining, + 'speed': speed + } + + def display(self): + """显示进度条""" + p = self.get_progress() + + # 进度条 + bar_width = 40 + filled = int(bar_width * 
p['total_done'] / p['total']) if p['total'] > 0 else 0 + bar = '█' * filled + '░' * (bar_width - filled) + + # 时间格式化 + elapsed_str = str(timedelta(seconds=int(p['elapsed']))) + remaining_str = str(timedelta(seconds=int(p['remaining']))) + + # 构建显示字符串 + status = ( + f"\r{self.desc}: |{bar}| " + f"{p['total_done']}/{p['total']} ({p['percent']:.1f}%) " + f"[{elapsed_str}<{remaining_str}, {p['speed']:.1f}it/s] " + f"✓{p['completed']} ✗{p['failed']}" + ) + + sys.stdout.write(status) + sys.stdout.flush() + + def finish(self): + """完成显示""" + self.display() + print() # 换行 + + +class ProgressReporter: + """进度报告器(用于后台监控)""" + + def __init__(self, progress_file: str = ".progress"): + self.progress_file = progress_file + + def save(self, progress: dict): + """保存进度到文件""" + import json + with open(self.progress_file, 'w') as f: + json.dump(progress, f) + + def load(self) -> Optional[dict]: + """从文件加载进度""" + import json + if os.path.exists(self.progress_file): + with open(self.progress_file, 'r') as f: + return json.load(f) + return None + + def cleanup(self): + """清理进度文件""" + if os.path.exists(self.progress_file): + os.remove(self.progress_file) \ No newline at end of file diff --git a/src/core/scheduler.py b/src/core/scheduler.py new file mode 100644 index 0000000..f584a4f --- /dev/null +++ b/src/core/scheduler.py @@ -0,0 +1,236 @@ +""" +并行调度器:支持本地多进程和SLURM集群调度 +""" +import os +import subprocess +import tempfile +from typing import List, Callable, Any, Optional, Dict +from multiprocessing import Pool, cpu_count +from dataclasses import dataclass +from enum import Enum +import math + +from .progress import ProgressManager + + +class ExecutionMode(Enum): + """执行模式""" + LOCAL_SINGLE = "local_single" # 单线程 + LOCAL_MULTI = "local_multi" # 本地多进程 + SLURM_SINGLE = "slurm_single" # SLURM单节点 + SLURM_ARRAY = "slurm_array" # SLURM作业数组 + + +@dataclass +class ResourceConfig: + """资源配置""" + max_cores: int = 4 # 最大可用核数 + cores_per_worker: int = 1 # 每个worker使用的核数 + memory_per_core: str = "4G" # 每核内存 + partition: str = "cpu" # SLURM分区 + time_limit: str = "7-00:00:00" # 时间限制 + + @property + def num_workers(self) -> int: + """计算worker数量""" + return max(1, self.max_cores // self.cores_per_worker) + + +class ParallelScheduler: + """并行调度器""" + + # 根据任务复杂度推荐的核数配置 + COMPLEXITY_CORES = { + 'low': 1, # 简单IO操作 + 'medium': 2, # 结构解析+基础检查 + 'high': 4, # 复杂计算(扩胞等) + } + + def __init__(self, resource_config: ResourceConfig = None): + self.config = resource_config or ResourceConfig() + self.progress_manager = None + + @staticmethod + def detect_environment() -> Dict[str, Any]: + """检测运行环境""" + env_info = { + 'hostname': os.uname().nodename, + 'total_cores': cpu_count(), + 'has_slurm': False, + 'slurm_partitions': [], + 'available_nodes': 0, + } + + # 检测SLURM + try: + result = subprocess.run( + ['sinfo', '-h', '-o', '%P %a %c %D'], + capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0: + env_info['has_slurm'] = True + lines = result.stdout.strip().split('\n') + for line in lines: + parts = line.split() + if len(parts) >= 4: + partition = parts[0].rstrip('*') + avail = parts[1] + cores = int(parts[2]) + nodes = int(parts[3]) + if avail == 'up': + env_info['slurm_partitions'].append({ + 'name': partition, + 'cores_per_node': cores, + 'nodes': nodes + }) + env_info['available_nodes'] += nodes + except Exception: + pass + + return env_info + + @staticmethod + def recommend_config( + num_tasks: int, + task_complexity: str = 'medium', + max_cores: int = None + ) -> ResourceConfig: + """根据任务量和复杂度推荐配置""" + + env = 
ParallelScheduler.detect_environment() + + # 默认最大核数 + if max_cores is None: + max_cores = min(env['total_cores'], 32) # 最多32核 + + # 每个worker的核数 + cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2) + + # 计算最优worker数 + # 原则:worker数 = min(任务数, 可用核数/每worker核数) + max_workers = max_cores // cores_per_worker + optimal_workers = min(num_tasks, max_workers) + + # 重新分配核数 + actual_cores = optimal_workers * cores_per_worker + + config = ResourceConfig( + max_cores=actual_cores, + cores_per_worker=cores_per_worker, + ) + + return config + + def run_local( + self, + tasks: List[Any], + worker_func: Callable, + desc: str = "Processing" + ) -> List[Any]: + """本地多进程执行""" + + num_workers = self.config.num_workers + total = len(tasks) + + print(f"\n{'=' * 60}") + print(f"并行配置:") + print(f" 总任务数: {total}") + print(f" Worker数: {num_workers}") + print(f" 每Worker核数: {self.config.cores_per_worker}") + print(f" 总使用核数: {self.config.max_cores}") + print(f"{'=' * 60}\n") + + # 初始化进度管理器 + self.progress_manager = ProgressManager(total, desc) + self.progress_manager.start() + + results = [] + + if num_workers == 1: + # 单进程模式 + for task in tasks: + try: + result = worker_func(task) + results.append(result) + self.progress_manager.update(success=True) + except Exception as e: + results.append(None) + self.progress_manager.update(success=False) + self.progress_manager.display() + else: + # 多进程模式 + with Pool(processes=num_workers) as pool: + # 使用imap_unordered获取更好的性能 + for result in pool.imap_unordered(worker_func, tasks, chunksize=10): + if result is not None: + results.append(result) + self.progress_manager.update(success=True) + else: + self.progress_manager.update(success=False) + self.progress_manager.display() + + self.progress_manager.finish() + + return results + + def generate_slurm_script( + self, + tasks_file: str, + worker_script: str, + output_dir: str, + job_name: str = "analysis" + ) -> str: + """生成SLURM作业脚本""" + + script = f"""#!/bin/bash +#SBATCH --job-name={job_name} +#SBATCH --partition={self.config.partition} +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task={self.config.max_cores} +#SBATCH --mem-per-cpu={self.config.memory_per_core} +#SBATCH --time={self.config.time_limit} +#SBATCH --output={output_dir}/slurm_%j.log +#SBATCH --error={output_dir}/slurm_%j.err + +# 环境设置 +source $(conda info --base)/etc/profile.d/conda.sh +conda activate screen + +# 设置Python路径 +cd $SLURM_SUBMIT_DIR +export PYTHONPATH=$(pwd):$PYTHONPATH + +# 运行分析 +python {worker_script} \\ + --tasks-file {tasks_file} \\ + --output-dir {output_dir} \\ + --num-workers {self.config.num_workers} + +echo "Job completed at $(date)" +""" + return script + + def submit_slurm_job(self, script_content: str, script_path: str = None) -> str: + """提交SLURM作业""" + + if script_path is None: + fd, script_path = tempfile.mkstemp(suffix='.sh') + os.close(fd) + + with open(script_path, 'w') as f: + f.write(script_content) + + os.chmod(script_path, 0o755) + + result = subprocess.run( + ['sbatch', script_path], + capture_output=True, text=True + ) + + if result.returncode == 0: + job_id = result.stdout.strip().split()[-1] + print(f"作业已提交: {job_id}") + return job_id + else: + raise RuntimeError(f"提交失败: {result.stderr}") \ No newline at end of file
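
Usage note (not part of the patch): a minimal, non-interactive sketch of how the parallel API introduced above might be driven from a script, assuming the repository layout used by main.py (src/ prepended to sys.path) and that the database path below is a placeholder for a real CIF directory. Only names defined in this diff are used (DatabaseAnalyzer, ReportGenerator, ParallelScheduler); everything else is illustrative.

import os
import sys

# Mirror main.py: make the src/ package importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from analysis.database_analyzer import DatabaseAnalyzer
from analysis.report_generator import ReportGenerator
from core.scheduler import ParallelScheduler

# Inspect the machine first: local core count and SLURM availability.
env = ParallelScheduler.detect_environment()

# "/path/to/cif_db" is a placeholder; point it at an actual CIF database.
analyzer = DatabaseAnalyzer(
    database_path="/path/to/cif_db",
    target_cation="Li",
    target_anions={"O", "S", "Cl", "Br"},
    anion_mode="all",
    max_cores=min(env["total_cores"], 32),
    task_complexity="medium",
)

if env["has_slurm"]:
    # Cluster path: submit a batch job and monitor it with squeue.
    job_id = analyzer.analyze_slurm(output_dir="./slurm_output")
    print(f"SLURM job submitted: {job_id}")
else:
    # Local path: run in-process via the multiprocessing scheduler.
    report = analyzer.analyze(show_progress=True)
    ReportGenerator.print_report(report, detailed=True)
    report.save("analysis_report.json")
    ReportGenerator.export_to_csv(report, "analysis_details.csv")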