预处理增加并行计算
This commit is contained in:
236
src/core/scheduler.py
Normal file
236
src/core/scheduler.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
并行调度器:支持本地多进程和SLURM集群调度
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import List, Callable, Any, Optional, Dict
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import math
|
||||
|
||||
from .progress import ProgressManager
|
||||
|
||||
|
||||
class ExecutionMode(Enum):
    """Supported execution back-ends for the scheduler."""

    LOCAL_SINGLE = "local_single"   # single-threaded, in-process
    LOCAL_MULTI = "local_multi"     # local multiprocessing pool
    SLURM_SINGLE = "slurm_single"   # one single-node SLURM job
    SLURM_ARRAY = "slurm_array"     # SLURM job array
|
||||
|
||||
|
||||
@dataclass
class ResourceConfig:
    """Compute-resource settings shared by local and SLURM execution."""

    max_cores: int = 4              # upper bound on cores the run may use
    cores_per_worker: int = 1       # cores allotted to each worker
    memory_per_core: str = "4G"     # SLURM per-core memory request
    partition: str = "cpu"          # SLURM partition name
    time_limit: str = "7-00:00:00"  # SLURM wall-time limit

    @property
    def num_workers(self) -> int:
        """Workers that fit into ``max_cores`` — never fewer than one."""
        fitted = self.max_cores // self.cores_per_worker
        return fitted if fitted > 1 else 1
|
||||
|
||||
|
||||
class ParallelScheduler:
    """Parallel scheduler: local multiprocessing and SLURM job submission.

    Responsibilities:
      * probe the host environment (core count, SLURM availability),
      * recommend a :class:`ResourceConfig` for a batch of tasks,
      * run a worker function over a task list with a local process pool,
      * generate and submit SLURM batch scripts.
    """

    # Recommended cores-per-worker keyed by task complexity.
    COMPLEXITY_CORES = {
        'low': 1,     # simple I/O work
        'medium': 2,  # structure parsing + basic checks
        'high': 4,    # heavy computation (e.g. supercell expansion)
    }

    def __init__(self, resource_config: Optional[ResourceConfig] = None):
        """Create a scheduler; use default resources when none are given.

        Fix: the parameter was annotated ``ResourceConfig = None`` — it is
        optional, so annotate it as such.
        """
        self.config = resource_config or ResourceConfig()
        self.progress_manager = None  # created lazily by run_local()

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Probe the host: hostname, core count and SLURM partition info.

        Returns a dict with keys ``hostname``, ``total_cores``, ``has_slurm``,
        ``slurm_partitions`` (list of per-partition dicts) and
        ``available_nodes``. Detection is best-effort: when ``sinfo`` is
        missing, times out, or prints unparsable fields, SLURM is simply
        reported as unavailable.
        """
        env_info = {
            'hostname': os.uname().nodename,
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'available_nodes': 0,
        }

        # Detect SLURM via sinfo.  Fix: the original bare `except Exception`
        # hid every bug; catch only the expected failure modes (binary not
        # found / timeout / non-numeric core or node counts).
        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True, text=True, timeout=5
            )
            if result.returncode == 0:
                env_info['has_slurm'] = True
                for line in result.stdout.strip().split('\n'):
                    parts = line.split()
                    if len(parts) >= 4:
                        partition = parts[0].rstrip('*')  # '*' marks the default partition
                        avail = parts[1]
                        cores = int(parts[2])
                        nodes = int(parts[3])
                        if avail == 'up':
                            env_info['slurm_partitions'].append({
                                'name': partition,
                                'cores_per_node': cores,
                                'nodes': nodes
                            })
                            env_info['available_nodes'] += nodes
        except (OSError, subprocess.SubprocessError, ValueError):
            # Best-effort probe: report "no SLURM" instead of crashing.
            pass

        return env_info

    @staticmethod
    def recommend_config(
        num_tasks: int,
        task_complexity: str = 'medium',
        max_cores: Optional[int] = None
    ) -> ResourceConfig:
        """Recommend a :class:`ResourceConfig` for ``num_tasks`` tasks.

        Args:
            num_tasks: number of tasks to schedule.
            task_complexity: 'low' | 'medium' | 'high' (unknown values fall
                back to the 'medium' setting of 2 cores per worker).
            max_cores: core budget; defaults to the host core count capped
                at 32.

        Returns:
            A config whose worker count is ``min(num_tasks, budget // cores
            per worker)``, clamped to at least one worker.
        """
        env = ParallelScheduler.detect_environment()

        # Default core budget: host cores, capped at 32.
        if max_cores is None:
            max_cores = min(env['total_cores'], 32)

        cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)

        # Optimal worker count = min(task count, budget / cores-per-worker).
        # Fix: clamp to >= 1 so num_tasks == 0 or a budget smaller than
        # cores_per_worker no longer yields a 0-core configuration.
        max_workers = max(1, max_cores // cores_per_worker)
        optimal_workers = max(1, min(num_tasks, max_workers))

        # Re-derive the core total actually consumed by those workers.
        actual_cores = optimal_workers * cores_per_worker

        config = ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
        )

        return config

    def run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing"
    ) -> List[Any]:
        """Run ``worker_func`` over ``tasks`` with a local process pool.

        Workers signal a per-task failure by returning ``None``; only
        non-None results are collected. Progress is reported through
        :class:`ProgressManager`.

        Fix: the sequential branch used to append ``None`` placeholders on
        exception while the parallel branch filtered them out, so the two
        modes returned differently-shaped lists. Both branches now return
        only successful (non-None) results.

        NOTE(review): in the multi-process branch an exception raised inside
        ``worker_func`` is re-raised by ``imap_unordered`` and aborts the
        whole run (unchanged from the original); workers should catch their
        own errors and return None instead.
        """
        num_workers = self.config.num_workers
        total = len(tasks)

        print(f"\n{'=' * 60}")
        print(f"并行配置:")
        print(f" 总任务数: {total}")
        print(f" Worker数: {num_workers}")
        print(f" 每Worker核数: {self.config.cores_per_worker}")
        print(f" 总使用核数: {self.config.max_cores}")
        print(f"{'=' * 60}\n")

        # Initialise the progress reporter.
        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        results = []

        if num_workers == 1:
            # Sequential path: an exception or a None return counts as a
            # failed task, mirroring the parallel path below.
            for task in tasks:
                try:
                    result = worker_func(task)
                except Exception:
                    result = None
                if result is not None:
                    results.append(result)
                    self.progress_manager.update(success=True)
                else:
                    self.progress_manager.update(success=False)
                self.progress_manager.display()
        else:
            # Parallel path.
            with Pool(processes=num_workers) as pool:
                # imap_unordered yields results as soon as they finish.
                for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
                    if result is not None:
                        results.append(result)
                        self.progress_manager.update(success=True)
                    else:
                        self.progress_manager.update(success=False)
                    self.progress_manager.display()

        self.progress_manager.finish()

        return results

    def generate_slurm_script(
        self,
        tasks_file: str,
        worker_script: str,
        output_dir: str,
        job_name: str = "analysis"
    ) -> str:
        """Render a single-node SLURM batch script as a string.

        Args:
            tasks_file: path handed to the worker via ``--tasks-file``.
            worker_script: Python entry point executed inside the job.
            output_dir: directory for SLURM logs and worker output.
            job_name: SLURM job name.

        NOTE(review): paths are interpolated into the shell script verbatim —
        callers must not pass paths containing shell metacharacters.
        """
        script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/slurm_%j.log
#SBATCH --error={output_dir}/slurm_%j.err

# 环境设置
source $(conda info --base)/etc/profile.d/conda.sh
conda activate screen

# 设置Python路径
cd $SLURM_SUBMIT_DIR
export PYTHONPATH=$(pwd):$PYTHONPATH

# 运行分析
python {worker_script} \\
    --tasks-file {tasks_file} \\
    --output-dir {output_dir} \\
    --num-workers {self.config.num_workers}

echo "Job completed at $(date)"
"""
        return script

    def submit_slurm_job(self, script_content: str, script_path: Optional[str] = None) -> str:
        """Write ``script_content`` to disk and submit it with ``sbatch``.

        When ``script_path`` is None a temporary ``.sh`` file is created and
        deliberately left on disk so SLURM can still read it after submission.

        Returns:
            The SLURM job id parsed from sbatch's stdout.

        Raises:
            RuntimeError: when sbatch exits with a non-zero status.
        """
        if script_path is None:
            fd, script_path = tempfile.mkstemp(suffix='.sh')
            os.close(fd)

        with open(script_path, 'w') as f:
            f.write(script_content)

        os.chmod(script_path, 0o755)  # make the batch script executable

        result = subprocess.run(
            ['sbatch', script_path],
            capture_output=True, text=True
        )

        if result.returncode == 0:
            # sbatch prints e.g. "Submitted batch job 12345" -> last token.
            job_id = result.stdout.strip().split()[-1]
            print(f"作业已提交: {job_id}")
            return job_id
        else:
            raise RuntimeError(f"提交失败: {result.stderr}")
|
||||
Reference in New Issue
Block a user