"""
Parallel scheduler: supports local multiprocessing and SLURM cluster scheduling.
"""

import os
import platform
import shlex
import subprocess
import tempfile
from typing import List, Callable, Any, Optional, Dict
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass
from enum import Enum
import math

from .progress import ProgressManager


class ExecutionMode(Enum):
    """Execution modes the scheduler distinguishes between."""

    LOCAL_SINGLE = "local_single"
    LOCAL_MULTI = "local_multi"
    SLURM_SINGLE = "slurm_single"
    SLURM_ARRAY = "slurm_array"


@dataclass
class ResourceConfig:
    """Resource settings used for both local pools and generated SLURM scripts."""

    max_cores: int = 4
    cores_per_worker: int = 1
    memory_per_core: str = "4G"
    partition: str = "cpu"
    time_limit: str = "7-00:00:00"
    # Path of the Conda environment activated inside the SLURM job.
    conda_env_path: str = "~/anaconda3/envs/screen"

    @property
    def num_workers(self) -> int:
        """Return how many workers fit in ``max_cores`` (always at least 1)."""
        # Guard the divisor so a zero/negative cores_per_worker cannot raise
        # ZeroDivisionError; the result is still clamped to >= 1.
        return max(1, self.max_cores // max(1, self.cores_per_worker))


class ParallelScheduler:
    """Run tasks in a local process pool or hand them to SLURM."""

    # Cores assigned to each worker for a given task complexity label.
    COMPLEXITY_CORES = {
        'low': 1,
        'medium': 2,
        'high': 4,
    }

    def __init__(self, resource_config: Optional[ResourceConfig] = None):
        """Store the resource configuration (defaults to ``ResourceConfig()``)."""
        self.config = resource_config or ResourceConfig()
        self.progress_manager = None

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Probe the host and, when available, the SLURM cluster.

        Returns:
            A dict with the keys ``hostname``, ``total_cores``, ``has_slurm``,
            ``slurm_partitions`` (list of dicts with ``name``,
            ``cores_per_node`` and ``nodes`` for partitions reported as
            ``up``), ``available_nodes`` and ``conda_prefix``.

        SLURM detection is best-effort: if ``sinfo`` is missing, times out or
        fails, ``has_slurm`` stays ``False`` and no exception is raised.
        """
        env_info = {
            # platform.node() also works where os.uname() does not exist
            # (e.g. Windows).
            'hostname': platform.node(),
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'available_nodes': 0,
            'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
        }

        # Detect SLURM by asking sinfo for "partition availability cpus nodes".
        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True,
                text=True,
                timeout=5
            )
        except (OSError, subprocess.SubprocessError, ValueError):
            # sinfo not installed, not executable, timed out, or produced
            # undecodable output: treat as "no SLURM".
            return env_info

        if result.returncode != 0:
            return env_info

        env_info['has_slurm'] = True
        for line in result.stdout.splitlines():
            parts = line.split()
            if len(parts) < 4:
                continue
            partition = parts[0].rstrip('*')  # '*' marks the default partition
            avail = parts[1]
            try:
                # sinfo prints e.g. "16+" when nodes in a partition have
                # differing CPU counts; keep the minimum it reports.
                cores = int(parts[2].rstrip('+'))
                nodes = int(parts[3])
            except ValueError:
                # Skip only the unparsable line instead of losing the rest.
                continue
            if avail == 'up':
                env_info['slurm_partitions'].append({
                    'name': partition,
                    'cores_per_node': cores,
                    'nodes': nodes
                })
                env_info['available_nodes'] += nodes

        return env_info

    @staticmethod
    def recommend_config(
        num_tasks: int,
        task_complexity: str = 'medium',
        max_cores: Optional[int] = None,
        conda_env_path: Optional[str] = None
    ) -> ResourceConfig:
        """Recommend a configuration from the task count and complexity.

        Args:
            num_tasks: Number of tasks to be processed.
            task_complexity: Key of ``COMPLEXITY_CORES``; unknown values use
                2 cores per worker.
            max_cores: Upper bound on cores; defaults to the detected core
                count capped at 32.
            conda_env_path: Conda environment to activate; defaults to the
                detected ``CONDA_PREFIX``, or ``~/anaconda3/envs/screen`` when
                none is active.

        Returns:
            A ``ResourceConfig`` that always requests at least one core.
        """
        env = ParallelScheduler.detect_environment()

        if max_cores is None:
            max_cores = min(env['total_cores'], 32)
        max_cores = max(1, max_cores)

        # Never ask for more cores per worker than are available, otherwise
        # the worker count (and the requested core count) would drop to 0.
        cores_per_worker = min(
            ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2),
            max_cores,
        )
        max_workers = max_cores // cores_per_worker
        # Keep at least one worker even when num_tasks is 0.
        optimal_workers = max(1, min(num_tasks, max_workers))
        actual_cores = optimal_workers * cores_per_worker

        # Auto-detect the Conda environment path. 'conda_prefix' is always
        # present but is '' outside an activated environment, so fall back on
        # emptiness rather than on a missing key.
        if conda_env_path is None:
            conda_env_path = env.get('conda_prefix') or '~/anaconda3/envs/screen'

        config = ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
            conda_env_path=conda_env_path,
        )

        return config

    def run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing"
    ) -> List[Any]:
        """Run ``worker_func`` over ``tasks`` locally, with progress display.

        Args:
            tasks: Task objects, each passed as the single argument of
                ``worker_func``.
            worker_func: Callable applied to every task. With more than one
                worker it is sent to a ``multiprocessing.Pool``.
            desc: Label handed to the progress manager.

        Returns:
            With one worker: one entry per task in task order, ``None`` where
            ``worker_func`` raised. With several workers: the non-``None``
            results in completion order (``None`` results count as failures
            and are dropped).
        """
        num_workers = self.config.num_workers
        total = len(tasks)

        print(f"\n{'='*60}")
        print("并行配置:")
        print(f" 总任务数: {total}")
        print(f" Worker数: {num_workers}")
        print(f" 每Worker核数: {self.config.cores_per_worker}")
        print(f" 总使用核数: {self.config.max_cores}")
        print(f"{'='*60}\n")

        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        results = []

        if num_workers == 1:
            for task in tasks:
                try:
                    result = worker_func(task)
                    results.append(result)
                    self.progress_manager.update(success=True)
                except Exception:
                    # A failing task is recorded as None and reported through
                    # the progress manager; remaining tasks still run.
                    results.append(None)
                    self.progress_manager.update(success=False)
                self.progress_manager.display()
        else:
            # A fixed chunksize of 10 would hand a short task list to only one
            # or two workers. Scale it down for small inputs (about four
            # chunks per worker) and keep 10 as the upper bound.
            chunksize = max(1, min(10, total // (num_workers * 4)))
            with Pool(processes=num_workers) as pool:
                for result in pool.imap_unordered(
                    worker_func, tasks, chunksize=chunksize
                ):
                    if result is not None:
                        results.append(result)
                        self.progress_manager.update(success=True)
                    else:
                        self.progress_manager.update(success=False)
                    self.progress_manager.display()

        self.progress_manager.finish()

        return results

    def generate_slurm_script(
        self,
        tasks_file: str,
        worker_script: str,
        output_dir: str,
        job_name: str = "analysis"
    ) -> str:
        """Build the text of a single-node SLURM batch script.

        Args:
            tasks_file: File passed to the worker via ``--tasks-file``.
            worker_script: Python script executed inside the job.
            output_dir: Directory for SLURM logs and ``--output-dir``.
            job_name: Value of ``#SBATCH --job-name``.

        Returns:
            The script content. Nothing is written to disk here.
        """
        # Absolute path of the current working directory; the job changes
        # into it and prepends it to PYTHONPATH.
        submit_dir = os.getcwd()

        # Convert to absolute paths so the job does not depend on its own cwd.
        abs_tasks_file = os.path.abspath(tasks_file)
        abs_worker_script = os.path.abspath(worker_script)
        abs_output_dir = os.path.abspath(output_dir)

        # Shell-quote everything that ends up in a command line so paths with
        # spaces or shell metacharacters survive. shlex.quote leaves plain
        # paths untouched. The #SBATCH directives are not shell-parsed and the
        # Conda path may rely on '~' expansion, so those stay unquoted.
        q_submit_dir = shlex.quote(submit_dir)
        q_tasks_file = shlex.quote(abs_tasks_file)
        q_worker_script = shlex.quote(abs_worker_script)
        q_output_dir = shlex.quote(abs_output_dir)

        script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err

# ============================================
# SLURM作业脚本 - 自动生成
# ============================================

echo "===== 作业信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "节点: $SLURM_NODELIST"
echo "CPU数: $SLURM_CPUS_PER_TASK"
echo "开始时间: $(date)"
echo "===================="

# ============ 环境初始化 ============
# 关键:确保bashrc被加载
if [ -f ~/.bashrc ]; then
    source ~/.bashrc
fi

# 初始化Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
    source /opt/anaconda3/etc/profile.d/conda.sh
else
    echo "错误: 找不到conda.sh"
    exit 1
fi

# 激活环境 (使用完整路径)
conda activate {self.config.conda_env_path}

# 验证环境
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "Python版本: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
echo "===================="
echo ""

# 设置工作目录
cd {q_submit_dir} || exit 1
export PYTHONPATH={q_submit_dir}:$PYTHONPATH

echo "工作目录: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""

# ============ 运行分析 ============
echo "开始执行分析任务..."

python {q_worker_script} \\
    --tasks-file {q_tasks_file} \\
    --output-dir {q_output_dir} \\
    --num-workers {self.config.num_workers}

EXIT_CODE=$?

# ============ 完成 ============
echo ""
echo "===== 作业完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
echo "===================="

exit $EXIT_CODE
"""
        return script

    def submit_slurm_job(
        self, script_content: str, script_path: Optional[str] = None
    ) -> str:
        """Write the script to disk and submit it with ``sbatch``.

        Args:
            script_content: Full text of the batch script.
            script_path: Where to save the script. When ``None`` a temporary
                ``.sh`` file is created; it is kept after submission so it
                can be inspected.

        Returns:
            The last whitespace-separated token of ``sbatch``'s stdout, i.e.
            the job id for the usual "Submitted batch job <id>" output.

        Raises:
            RuntimeError: ``sbatch`` exited non-zero or printed nothing.
            OSError: The script could not be written or ``sbatch`` could not
                be executed.
        """
        if script_path is None:
            fd, script_path = tempfile.mkstemp(suffix='.sh')
            os.close(fd)

        # The script contains non-ASCII text, so do not rely on the locale's
        # default encoding.
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(script_content)

        os.chmod(script_path, 0o755)

        # Report where the script was saved, for debugging.
        print(f"\n生成的SLURM脚本保存到: {script_path}")

        result = subprocess.run(
            ['sbatch', script_path],
            capture_output=True,
            text=True
        )

        if result.returncode != 0:
            raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")

        tokens = result.stdout.split()
        if not tokens:
            # Avoid an opaque IndexError when sbatch succeeds silently.
            raise RuntimeError(f"SLURM提交成功但未返回作业ID:\nstderr: {result.stderr}")
        return tokens[-1]