Files
screen/src/core/scheduler.py
2025-12-14 15:53:11 +08:00

291 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
并行调度器支持本地多进程和SLURM集群调度
"""
import functools
import math
import os
import platform
import shlex
import subprocess
import tempfile
from dataclasses import dataclass
from enum import Enum
from multiprocessing import Pool, cpu_count
from typing import List, Callable, Any, Optional, Dict

from .progress import ProgressManager

class ExecutionMode(Enum):
    """Identifiers for the available execution modes.

    Nothing else in this module references these members; the meaning given
    for each one below is inferred from its name.
    """

    # Local machine, one process.
    LOCAL_SINGLE = "local_single"
    # Local machine, several processes.
    LOCAL_MULTI = "local_multi"
    # One SLURM job.
    SLURM_SINGLE = "slurm_single"
    # SLURM job array.
    SLURM_ARRAY = "slurm_array"
@dataclass
class ResourceConfig:
    """Resources requested for a run (local pool or SLURM job)."""

    # Total core budget; also emitted as SLURM --cpus-per-task.
    max_cores: int = 4
    # Cores reserved for each worker process.
    cores_per_worker: int = 1
    # Emitted as SLURM --mem-per-cpu.
    memory_per_core: str = "4G"
    # Emitted as SLURM --partition.
    partition: str = "cpu"
    # Emitted as SLURM --time (D-HH:MM:SS).
    time_limit: str = "7-00:00:00"
    # Conda environment activated inside the SLURM job (path form).
    conda_env_path: str = "~/anaconda3/envs/screen"

    @property
    def num_workers(self) -> int:
        """How many workers of ``cores_per_worker`` cores fit in ``max_cores``; at least 1."""
        workers = self.max_cores // self.cores_per_worker
        return workers if workers > 1 else 1
class ParallelScheduler:
    """Run tasks locally (multiprocessing) or build/submit SLURM batch jobs.

    The scheduler is driven by a ``ResourceConfig``, which supplies the core
    budget, cores per worker, SLURM partition/limits and the Conda environment
    activated on compute nodes.
    """

    # Cores given to one worker per task-complexity label; recommend_config
    # falls back to 2 for unknown labels.
    COMPLEXITY_CORES = {
        'low': 1,
        'medium': 2,
        'high': 4,
    }

    # Conda env used by recommend_config when none is given and $CONDA_PREFIX is unset.
    _DEFAULT_CONDA_ENV = '~/anaconda3/envs/screen'

    # Upper bound on the Pool.imap_unordered chunk size chosen by run_local.
    _MAX_CHUNKSIZE = 10

    def __init__(self, resource_config: Optional["ResourceConfig"] = None):
        """Create a scheduler.

        Args:
            resource_config: resources to use; a default ``ResourceConfig()`` is
                created when omitted (or falsy).
        """
        self.config = resource_config or ResourceConfig()
        # Created by run_local(); None until a local run starts.
        self.progress_manager = None

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Probe the local machine and, if available, the SLURM cluster.

        Returns:
            Dict with keys ``hostname``, ``total_cores`` (local CPU count),
            ``has_slurm`` (``sinfo`` ran and exited 0), ``slurm_partitions``
            (list of ``{'name', 'cores_per_node', 'nodes'}`` for partitions
            reported as ``up``), ``available_nodes`` (sum of node counts over
            those entries; nodes in several partitions are counted once per
            partition) and ``conda_prefix`` (``$CONDA_PREFIX`` or ``''``).
        """
        env_info = {
            # platform.node() works everywhere; os.uname() does not exist on Windows.
            'hostname': platform.node(),
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'available_nodes': 0,
            'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
        }

        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True, text=True, timeout=5
            )
        except (OSError, subprocess.SubprocessError):
            # sinfo missing or hung: best effort, report "no SLURM".
            return env_info
        if result.returncode != 0:
            return env_info

        env_info['has_slurm'] = True
        for line in result.stdout.splitlines():
            parts = line.split()
            if len(parts) < 4:
                continue
            partition = parts[0].rstrip('*')  # '*' marks the default partition
            avail = parts[1]
            try:
                # %c prints e.g. "48+" when CPU counts differ inside a partition.
                cores = int(parts[2].rstrip('+'))
                nodes = int(parts[3])
            except ValueError:
                # Skip only the unparseable line instead of losing all partitions.
                continue
            if avail == 'up':
                env_info['slurm_partitions'].append({
                    'name': partition,
                    'cores_per_node': cores,
                    'nodes': nodes
                })
                env_info['available_nodes'] += nodes
        return env_info

    @staticmethod
    def recommend_config(
        num_tasks: int,
        task_complexity: str = 'medium',
        max_cores: Optional[int] = None,
        conda_env_path: Optional[str] = None
    ) -> "ResourceConfig":
        """Recommend a ResourceConfig for a batch of tasks.

        Args:
            num_tasks: number of tasks; no more workers than tasks are planned.
            task_complexity: key of COMPLEXITY_CORES ('low'/'medium'/'high');
                unknown values use 2 cores per worker.
            max_cores: core budget; defaults to min(local CPU count, 32).
                Values below 1 are treated as 1.
            conda_env_path: Conda env for SLURM jobs; defaults to
                $CONDA_PREFIX, or ``~/anaconda3/envs/screen`` when that is unset.

        Returns:
            ResourceConfig with at least one worker and
            ``max_cores == workers * cores_per_worker`` (never above the budget).
        """
        env = ParallelScheduler.detect_environment()
        if max_cores is None:
            max_cores = min(env['total_cores'], 32)
        max_cores = max(1, max_cores)
        # A single worker may not request more cores than the whole budget.
        cores_per_worker = min(
            ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2), max_cores
        )
        max_workers = max_cores // cores_per_worker  # >= 1 because of the clamp above
        # Keep >= 1 worker even for num_tasks == 0 so SLURM never sees 0 CPUs.
        optimal_workers = max(1, min(num_tasks, max_workers))
        actual_cores = optimal_workers * cores_per_worker

        if conda_env_path is None:
            # 'conda_prefix' is always present but is '' when CONDA_PREFIX is unset.
            conda_env_path = env['conda_prefix'] or ParallelScheduler._DEFAULT_CONDA_ENV

        return ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
            conda_env_path=conda_env_path,
        )

    @staticmethod
    def _call_worker(worker_func: Callable, task: Any) -> Any:
        """Return ``worker_func(task)``, or None if it raises (None marks a failed task)."""
        try:
            return worker_func(task)
        except Exception:
            return None

    @staticmethod
    def _auto_chunksize(total: int, num_workers: int) -> int:
        """Chunk size giving each worker roughly 4 chunks, clamped to [1, _MAX_CHUNKSIZE]."""
        return max(1, min(ParallelScheduler._MAX_CHUNKSIZE, total // (num_workers * 4)))

    def _collect_results(self, outcomes) -> List[Any]:
        """Drain ``outcomes``, updating progress; keep only non-None results."""
        results = []
        for result in outcomes:
            succeeded = result is not None
            if succeeded:
                results.append(result)
            self.progress_manager.update(success=succeeded)
            self.progress_manager.display()
        return results

    def run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing",
        chunksize: Optional[int] = None,
    ) -> List[Any]:
        """Run ``worker_func`` over ``tasks`` on this machine.

        With ``config.num_workers == 1`` the tasks run sequentially in this
        process; otherwise a multiprocessing Pool of that size is used. Both
        paths follow the same rule. A task fails when ``worker_func`` raises or
        returns None, and only successful (non-None) results are returned. With
        a pool, results arrive in completion order, not task order.

        Args:
            tasks: items passed one at a time to ``worker_func``.
            worker_func: callable applied to each task. It must be picklable
                (e.g. a module-level function) when more than one worker is used.
            desc: label for the progress display.
            chunksize: tasks sent to a pool worker at once. By default it is
                chosen so that all workers get work on small task lists (at most 10).

        Returns:
            List of non-None results.
        """
        num_workers = self.config.num_workers
        total = len(tasks)
        print(f"\n{'='*60}")
        print(f"并行配置:")
        print(f" 总任务数: {total}")
        print(f" Worker数: {num_workers}")
        print(f" 每Worker核数: {self.config.cores_per_worker}")
        print(f" 总使用核数: {self.config.max_cores}")
        print(f"{'='*60}\n")

        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        # Exceptions become None inside the worker, so one bad task cannot
        # abort the whole pool.
        safe_worker = functools.partial(self._call_worker, worker_func)
        if num_workers == 1:
            results = self._collect_results(map(safe_worker, tasks))
        else:
            if chunksize is None:
                chunksize = self._auto_chunksize(total, num_workers)
            with Pool(processes=num_workers) as pool:
                # Consume inside the with-block: leaving it terminates the pool.
                results = self._collect_results(
                    pool.imap_unordered(safe_worker, tasks, chunksize=chunksize)
                )

        self.progress_manager.finish()
        return results

    def generate_slurm_script(
        self,
        tasks_file: str,
        worker_script: str,
        output_dir: str,
        job_name: str = "analysis"
    ) -> str:
        """Build the text of a SLURM batch script that runs ``worker_script``.

        The job requests 1 node / 1 task with ``config.max_cores`` CPUs. It
        activates ``config.conda_env_path``, changes into the current working
        directory (which is also prepended to PYTHONPATH) and runs:
        ``python <worker_script> --tasks-file <tasks_file> --output-dir <output_dir> --num-workers <config.num_workers>``.
        Relative paths are resolved against the current working directory.

        Args:
            tasks_file: file passed to the worker via ``--tasks-file``.
            worker_script: Python script executed on the compute node.
            output_dir: passed via ``--output-dir``; SLURM writes
                ``slurm_<jobid>.log/.err`` here.
            job_name: SLURM ``--job-name``.

        Returns:
            The script text; it is not written to disk (see submit_slurm_job).

        NOTE: SLURM does not create ``output_dir`` for the log files; make sure it
        exists before submitting.
        """
        submit_dir = os.getcwd()
        abs_tasks_file = os.path.abspath(tasks_file)
        abs_worker_script = os.path.abspath(worker_script)
        abs_output_dir = os.path.abspath(output_dir)

        # Shell-quote paths used in commands so spaces/metacharacters survive.
        q_submit_dir = shlex.quote(submit_dir)
        q_tasks_file = shlex.quote(abs_tasks_file)
        q_worker_script = shlex.quote(abs_worker_script)
        q_output_dir = shlex.quote(abs_output_dir)

        conda_env = self.config.conda_env_path
        conda_sh_candidates = []
        if conda_env:
            # conda.sh lives under the installation root: "<root>/envs/<name>" -> "<root>".
            # A path without "/envs/" is taken to be the root (base env) itself.
            conda_root = conda_env.rstrip('/').rsplit('/envs/', 1)[0]
            conda_sh_candidates.append(f"{conda_root}/etc/profile.d/conda.sh")
        conda_sh_candidates += [
            "~/anaconda3/etc/profile.d/conda.sh",
            "~/miniconda3/etc/profile.d/conda.sh",
            "/opt/anaconda3/etc/profile.d/conda.sh",
        ]
        # Deliberately unquoted so bash performs "~" expansion on each candidate.
        conda_sh_list = " ".join(conda_sh_candidates)

        script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err

# ============================================
# SLURM作业脚本 - 自动生成
# ============================================
echo "===== 作业信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "节点: $SLURM_NODELIST"
echo "CPU数: $SLURM_CPUS_PER_TASK"
echo "开始时间: $(date)"
echo "===================="

# ============ 环境初始化 ============
# 关键确保bashrc被加载
if [ -f ~/.bashrc ]; then
    source ~/.bashrc
fi

# 初始化Conda
CONDA_SH=""
for candidate in {conda_sh_list}; do
    if [ -f "$candidate" ]; then
        CONDA_SH="$candidate"
        break
    fi
done
if [ -z "$CONDA_SH" ]; then
    echo "错误: 找不到conda.sh"
    exit 1
fi
source "$CONDA_SH"

# 激活环境 (使用完整路径)
conda activate {conda_env} || {{ echo "错误: 无法激活Conda环境 {conda_env}"; exit 1; }}

# 验证环境
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "Python版本: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
echo "===================="
echo ""

# 设置工作目录
cd {q_submit_dir} || exit 1
export PYTHONPATH={q_submit_dir}${{PYTHONPATH:+:$PYTHONPATH}}
echo "工作目录: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""

# ============ 运行分析 ============
echo "开始执行分析任务..."
python {q_worker_script} \\
    --tasks-file {q_tasks_file} \\
    --output-dir {q_output_dir} \\
    --num-workers {self.config.num_workers}
EXIT_CODE=$?

# ============ 完成 ============
echo ""
echo "===== 作业完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
echo "===================="
exit $EXIT_CODE
"""
        return script

    def submit_slurm_job(self, script_content: str, script_path: Optional[str] = None) -> str:
        """Write ``script_content`` to a file and submit it with ``sbatch``.

        Args:
            script_content: complete batch script (e.g. from generate_slurm_script).
            script_path: destination of the script. When omitted, a temporary
                ``.sh`` file is created and kept so it can be inspected.

        Returns:
            The SLURM job ID as a string.

        Raises:
            RuntimeError: if sbatch exits non-zero or prints no job ID.
        """
        if script_path is None:
            fd, script_path = tempfile.mkstemp(suffix='.sh')
            os.close(fd)
        # Generated scripts contain Chinese text; don't rely on the locale's default encoding.
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(script_content)
        os.chmod(script_path, 0o755)

        # Report where the script was saved, for debugging.
        print(f"\n生成的SLURM脚本保存到: {script_path}")

        # --parsable prints only "<jobid>[;<cluster>]", which is robust to parse.
        result = subprocess.run(
            ['sbatch', '--parsable', script_path],
            capture_output=True, text=True
        )
        lines = result.stdout.strip().splitlines()
        if result.returncode != 0 or not lines:
            raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")
        return lines[-1].split(';')[0].strip()