"""
|
||
并行调度器:支持本地多进程和SLURM集群调度
|
||
"""
|
||
import os
|
||
import subprocess
|
||
import tempfile
|
||
from typing import List, Callable, Any, Optional, Dict
|
||
from multiprocessing import Pool, cpu_count
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
import math
|
||
|
||
from .progress import ProgressManager
|
||
|
||
|
||
class ExecutionMode(Enum):
    """Execution mode."""
    LOCAL_SINGLE = "local_single"
    LOCAL_MULTI = "local_multi"
    SLURM_SINGLE = "slurm_single"
    SLURM_ARRAY = "slurm_array"


@dataclass
class ResourceConfig:
    """Resource configuration."""
    max_cores: int = 4
    cores_per_worker: int = 1
    memory_per_core: str = "4G"
    partition: str = "cpu"
    time_limit: str = "7-00:00:00"
    conda_env_path: str = "~/anaconda3/envs/screen"  # Conda environment used for SLURM jobs

    @property
    def num_workers(self) -> int:
        return max(1, self.max_cores // self.cores_per_worker)


class ParallelScheduler:
    """Parallel scheduler."""

    # Cores allocated per worker for each task-complexity level
    COMPLEXITY_CORES = {
        'low': 1,
        'medium': 2,
        'high': 4,
    }

    def __init__(self, resource_config: Optional[ResourceConfig] = None):
        self.config = resource_config or ResourceConfig()
        self.progress_manager = None

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect the runtime environment."""
        env_info = {
            'hostname': os.uname().nodename,
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'available_nodes': 0,
            'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
        }

        # Detect SLURM via sinfo: partition, availability, cores per node, node count
        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True, text=True, timeout=5
            )
            if result.returncode == 0:
                env_info['has_slurm'] = True
                lines = result.stdout.strip().split('\n')
                for line in lines:
                    parts = line.split()
                    if len(parts) >= 4:
                        partition = parts[0].rstrip('*')
                        avail = parts[1]
                        cores = int(parts[2])
                        nodes = int(parts[3])
                        if avail == 'up':
                            env_info['slurm_partitions'].append({
                                'name': partition,
                                'cores_per_node': cores,
                                'nodes': nodes
                            })
                            env_info['available_nodes'] += nodes
        except Exception:
            pass

        return env_info

    @staticmethod
    def recommend_config(
        num_tasks: int,
        task_complexity: str = 'medium',
        max_cores: Optional[int] = None,
        conda_env_path: Optional[str] = None
    ) -> ResourceConfig:
        """Recommend a resource configuration based on task count and complexity."""

        env = ParallelScheduler.detect_environment()

        if max_cores is None:
            max_cores = min(env['total_cores'], 32)

        cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
        max_workers = max(1, max_cores // cores_per_worker)
        optimal_workers = min(num_tasks, max_workers)
        actual_cores = optimal_workers * cores_per_worker

        # Auto-detect the Conda environment path (fall back to the default
        # when CONDA_PREFIX is unset or empty)
        if conda_env_path is None:
            conda_env_path = env['conda_prefix'] or '~/anaconda3/envs/screen'

        config = ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
            conda_env_path=conda_env_path,
        )

        return config

    def run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing"
    ) -> List[Any]:
        """Run tasks locally with multiprocessing."""

        num_workers = self.config.num_workers
        total = len(tasks)

        print(f"\n{'='*60}")
        print("Parallel configuration:")
        print(f"  Total tasks:      {total}")
        print(f"  Workers:          {num_workers}")
        print(f"  Cores per worker: {self.config.cores_per_worker}")
        print(f"  Total cores used: {self.config.max_cores}")
        print(f"{'='*60}\n")

        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        results = []

        if num_workers == 1:
            # Serial fallback: failed tasks are recorded as None
            for task in tasks:
                try:
                    result = worker_func(task)
                    results.append(result)
                    self.progress_manager.update(success=True)
                except Exception:
                    results.append(None)
                    self.progress_manager.update(success=False)
                self.progress_manager.display()
        else:
            # Parallel execution: workers are expected to return None on failure
            with Pool(processes=num_workers) as pool:
                for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
                    if result is not None:
                        results.append(result)
                        self.progress_manager.update(success=True)
                    else:
                        self.progress_manager.update(success=False)
                    self.progress_manager.display()

        self.progress_manager.finish()

        return results

    def generate_slurm_script(
        self,
        tasks_file: str,
        worker_script: str,
        output_dir: str,
        job_name: str = "analysis"
    ) -> str:
        """Generate a SLURM job script (absolute paths, explicit Conda activation)."""

        # Absolute path of the submission (current working) directory
        submit_dir = os.getcwd()

        # Convert inputs to absolute paths so the job runs correctly on any node
        abs_tasks_file = os.path.abspath(tasks_file)
        abs_worker_script = os.path.abspath(worker_script)
        abs_output_dir = os.path.abspath(output_dir)

        script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err

# ============================================
# SLURM job script - auto-generated
# ============================================

echo "===== Job info ====="
echo "Job ID: $SLURM_JOB_ID"
echo "Node(s): $SLURM_NODELIST"
echo "CPUs: $SLURM_CPUS_PER_TASK"
echo "Start time: $(date)"
echo "===================="

# ============ Environment setup ============
# Important: make sure ~/.bashrc is loaded
if [ -f ~/.bashrc ]; then
    source ~/.bashrc
fi

# Initialize Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
    source /opt/anaconda3/etc/profile.d/conda.sh
else
    echo "Error: conda.sh not found"
    exit 1
fi

# Activate the environment (by full path)
conda activate {self.config.conda_env_path}

# Verify the environment
echo ""
echo "===== Environment check ====="
echo "Conda env: $CONDA_DEFAULT_ENV"
echo "Python path: $(which python)"
echo "Python version: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen version: {{pymatgen.__version__}}')" 2>/dev/null || echo "Warning: pymatgen is not installed"
echo "===================="
echo ""

# Set the working directory
cd {submit_dir}
export PYTHONPATH={submit_dir}:$PYTHONPATH

echo "Working directory: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""

# ============ Run the analysis ============
echo "Starting analysis tasks..."
python {abs_worker_script} \\
    --tasks-file {abs_tasks_file} \\
    --output-dir {abs_output_dir} \\
    --num-workers {self.config.num_workers}

EXIT_CODE=$?

# ============ Finish ============
echo ""
echo "===== Job finished ====="
echo "End time: $(date)"
echo "Exit code: $EXIT_CODE"
echo "===================="

exit $EXIT_CODE
"""
        return script

    def submit_slurm_job(self, script_content: str, script_path: Optional[str] = None) -> str:
        """Submit a SLURM job and return its job ID."""

        if script_path is None:
            fd, script_path = tempfile.mkstemp(suffix='.sh')
            os.close(fd)

        with open(script_path, 'w') as f:
            f.write(script_content)

        os.chmod(script_path, 0o755)

        # Report where the generated script was written (useful for debugging)
        print(f"\nGenerated SLURM script saved to: {script_path}")

        result = subprocess.run(
            ['sbatch', script_path],
            capture_output=True, text=True
        )

        if result.returncode == 0:
            job_id = result.stdout.strip().split()[-1]
            return job_id
        else:
            raise RuntimeError(
                f"SLURM submission failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
            )
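

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the scheduler API).
# It assumes a per-task worker function that takes one task object and returns
# a result (or None on failure); `_demo_square`, `tasks.json`, `run_worker.py`,
# `results/` and `submit_demo.sh` below are hypothetical placeholders.

def _demo_square(x):
    """Trivial stand-in for a real per-task analysis function."""
    return x * x


if __name__ == "__main__":
    tasks = list(range(20))

    # Let the scheduler pick a configuration for the detected machine
    config = ParallelScheduler.recommend_config(num_tasks=len(tasks),
                                                task_complexity='low')
    scheduler = ParallelScheduler(config)

    env = ParallelScheduler.detect_environment()
    if env['has_slurm']:
        # On a cluster: generate a batch script and hand it to sbatch.
        script = scheduler.generate_slurm_script(
            tasks_file='tasks.json',        # hypothetical task list
            worker_script='run_worker.py',  # hypothetical CLI entry point
            output_dir='results',
            job_name='demo'
        )
        job_id = scheduler.submit_slurm_job(script, script_path='submit_demo.sh')
        print(f"Submitted SLURM job {job_id}")
    else:
        # On a workstation: run the tasks in local worker processes.
        results = scheduler.run_local(tasks, _demo_square, desc="Squaring")
        print(f"Collected {len(results)} results")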