Add parallel computation to preprocessing

2025-12-14 15:42:13 +08:00
parent c91998662a
commit ae4e7280b4
5 changed files with 720 additions and 147 deletions

src/core/scheduler.py Normal file

@@ -0,0 +1,236 @@
"""
并行调度器支持本地多进程和SLURM集群调度
"""
import os
import subprocess
import tempfile
from typing import List, Callable, Any, Optional, Dict
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass
from enum import Enum
import math
from .progress import ProgressManager


class ExecutionMode(Enum):
    """Execution mode."""
    LOCAL_SINGLE = "local_single"   # serial, single process
    LOCAL_MULTI = "local_multi"     # local multiprocessing
    SLURM_SINGLE = "slurm_single"   # single SLURM node
    SLURM_ARRAY = "slurm_array"     # SLURM job array


@dataclass
class ResourceConfig:
    """Resource configuration."""
    max_cores: int = 4              # maximum cores available
    cores_per_worker: int = 1       # cores used by each worker
    memory_per_core: str = "4G"     # memory per core
    partition: str = "cpu"          # SLURM partition
    time_limit: str = "7-00:00:00"  # wall-time limit

    @property
    def num_workers(self) -> int:
        """Number of workers that fit in the core budget."""
        return max(1, self.max_cores // self.cores_per_worker)


class ParallelScheduler:
    """Parallel scheduler."""

    # Recommended cores per worker, by task complexity
    COMPLEXITY_CORES = {
        'low': 1,     # simple I/O operations
        'medium': 2,  # structure parsing + basic checks
        'high': 4,    # heavy computation (supercell expansion, etc.)
    }

    def __init__(self, resource_config: Optional[ResourceConfig] = None):
        self.config = resource_config or ResourceConfig()
        self.progress_manager: Optional[ProgressManager] = None

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect the runtime environment."""
        env_info = {
            'hostname': os.uname().nodename,
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'available_nodes': 0,
        }
        # Probe for SLURM
        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True, text=True, timeout=5
            )
            if result.returncode == 0:
                env_info['has_slurm'] = True
                for line in result.stdout.strip().split('\n'):
                    parts = line.split()
                    if len(parts) >= 4:
                        partition = parts[0].rstrip('*')
                        avail = parts[1]
                        # sinfo may print e.g. "32+" on heterogeneous partitions
                        cores = int(parts[2].rstrip('+'))
                        nodes = int(parts[3])
                        if avail == 'up':
                            env_info['slurm_partitions'].append({
                                'name': partition,
                                'cores_per_node': cores,
                                'nodes': nodes
                            })
                            env_info['available_nodes'] += nodes
        except Exception:
            # sinfo missing or timed out: assume no SLURM
            pass
        return env_info
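
    # Illustrative example (hypothetical cluster, names are made up): if
    # `sinfo -h -o '%P %a %c %D'` prints "cpu* up 32 10", detect_environment()
    # would return roughly:
    #   {'hostname': 'login01', 'total_cores': 64, 'has_slurm': True,
    #    'slurm_partitions': [{'name': 'cpu', 'cores_per_node': 32, 'nodes': 10}],
    #    'available_nodes': 10}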

    @staticmethod
    def recommend_config(
        num_tasks: int,
        task_complexity: str = 'medium',
        max_cores: Optional[int] = None
    ) -> ResourceConfig:
        """Recommend a configuration based on task count and complexity."""
        env = ParallelScheduler.detect_environment()
        # Default core budget, capped at 32 cores
        if max_cores is None:
            max_cores = min(env['total_cores'], 32)
        # Cores per worker
        cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
        # Optimal worker count:
        #   workers = min(number of tasks, available cores / cores per worker)
        max_workers = max(1, max_cores // cores_per_worker)
        optimal_workers = min(num_tasks, max_workers)
        # Re-derive the actual core budget
        actual_cores = optimal_workers * cores_per_worker
        return ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
        )
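
    # Worked example (illustrative): 100 tasks, 'medium' complexity, 16 cores
    # -> cores_per_worker = 2, max_workers = 8, optimal_workers = min(100, 8) = 8,
    #    actual_cores = 16; i.e. 8 workers with 2 cores each.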

    def run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing"
    ) -> List[Any]:
        """Run tasks locally with multiprocessing.

        worker_func should catch its own exceptions and return None on
        failure; in the multiprocess branch an uncaught exception would
        abort the whole pool iteration. Returns successful results only.
        """
        num_workers = self.config.num_workers
        total = len(tasks)
        print(f"\n{'=' * 60}")
        print("Parallel configuration:")
        print(f"  Total tasks:      {total}")
        print(f"  Workers:          {num_workers}")
        print(f"  Cores per worker: {self.config.cores_per_worker}")
        print(f"  Total cores used: {self.config.max_cores}")
        print(f"{'=' * 60}\n")
        # Initialize the progress manager
        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()
        results = []
        if num_workers == 1:
            # Single-process mode
            for task in tasks:
                try:
                    results.append(worker_func(task))
                    self.progress_manager.update(success=True)
                except Exception:
                    # Mirror the multiprocess branch: count the failure
                    # but do not append a placeholder result
                    self.progress_manager.update(success=False)
                self.progress_manager.display()
        else:
            # Multiprocess mode
            # Cap the chunk size so small batches still spread across workers
            chunksize = max(1, min(10, total // num_workers))
            with Pool(processes=num_workers) as pool:
                # imap_unordered yields results as they complete
                for result in pool.imap_unordered(worker_func, tasks, chunksize=chunksize):
                    if result is not None:
                        results.append(result)
                        self.progress_manager.update(success=True)
                    else:
                        self.progress_manager.update(success=False)
                    self.progress_manager.display()
        self.progress_manager.finish()
        return results
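
    # Example worker contract (illustrative; parse_and_validate is hypothetical):
    #
    #   def check_structure(path):
    #       try:
    #           return parse_and_validate(path)
    #       except Exception:
    #           return None   # signals failure to run_local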

    def generate_slurm_script(
        self,
        tasks_file: str,
        worker_script: str,
        output_dir: str,
        job_name: str = "analysis"
    ) -> str:
        """Generate a SLURM batch script."""
        script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/slurm_%j.log
#SBATCH --error={output_dir}/slurm_%j.err

# Environment setup
source $(conda info --base)/etc/profile.d/conda.sh
conda activate screen

# Set the Python path
cd $SLURM_SUBMIT_DIR
export PYTHONPATH=$(pwd):$PYTHONPATH

# Run the analysis
python {worker_script} \\
    --tasks-file {tasks_file} \\
    --output-dir {output_dir} \\
    --num-workers {self.config.num_workers}

echo "Job completed at $(date)"
"""
        return script

    def submit_slurm_job(self, script_content: str, script_path: Optional[str] = None) -> str:
        """Submit a SLURM job and return its job ID."""
        if script_path is None:
            fd, script_path = tempfile.mkstemp(suffix='.sh')
            os.close(fd)
        with open(script_path, 'w') as f:
            f.write(script_content)
        os.chmod(script_path, 0o755)
        result = subprocess.run(
            ['sbatch', script_path],
            capture_output=True, text=True
        )
        if result.returncode == 0:
            # sbatch prints "Submitted batch job <id>"
            job_id = result.stdout.strip().split()[-1]
            print(f"Job submitted: {job_id}")
            return job_id
        raise RuntimeError(f"Submission failed: {result.stderr}")
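
A minimal usage sketch (not part of this commit; `tasks`, `process_one`, `tasks.json`, and `run_worker.py` are hypothetical caller-side names, and the import path assumes the repository root is on PYTHONPATH with src/ importable as a package):

    from src.core.scheduler import ParallelScheduler

    env = ParallelScheduler.detect_environment()
    config = ParallelScheduler.recommend_config(num_tasks=len(tasks), task_complexity='medium')
    scheduler = ParallelScheduler(config)
    if env['has_slurm']:
        # Batch path: generate and submit a SLURM script
        script = scheduler.generate_slurm_script('tasks.json', 'run_worker.py', 'out')
        scheduler.submit_slurm_job(script)
    else:
        # Local path: multiprocess over the task list
        results = scheduler.run_local(tasks, process_one, desc='Preprocessing')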