Preprocessing: add parallel computation

2025-12-14 15:42:13 +08:00
parent c91998662a
commit ae4e7280b4
5 changed files with 720 additions and 147 deletions

151
main.py

@@ -1,71 +1,112 @@
"""
高通量筛选与扩胞项目 - 主入口
交互式命令行界面
高通量筛选与扩胞项目 - 主入口(支持并行)
"""
import os
import sys
# 添加 src 到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
from analysis.database_analyzer import DatabaseAnalyzer
from analysis.report_generator import ReportGenerator
from core.scheduler import ParallelScheduler
def get_user_input():
def print_banner():
print("""
╔═══════════════════════════════════════════════════════════════════╗
║   High-throughput screening & supercell project - database analysis tool v2.0   ║
║                 With high-performance parallel computing                 ║
╚═══════════════════════════════════════════════════════════════════╝
""")
def detect_and_show_environment():
"""Detect and display environment information"""
env = ParallelScheduler.detect_environment()
print("[Runtime environment]")
print(f" Hostname: {env['hostname']}")
print(f" Local CPU cores: {env['total_cores']}")
print(f" SLURM cluster: {'✅ available' if env['has_slurm'] else '❌ not available'}")
if env['has_slurm'] and env['slurm_partitions']:
print(f" Available partitions:")
for p in env['slurm_partitions']:
print(f" - {p['name']}: {p['nodes']} nodes, {p['cores_per_node']} cores/node")
return env
def get_user_input(env: dict):
"""Get user input"""
print("\n" + "=" * 70)
print(" High-throughput screening & supercell project - database analysis tool")
print("=" * 70)
# 1. Get the database path
# Database path
while True:
db_path = input("\nEnter the database path: ").strip()
db_path = input("\n📂 Enter the database path: ").strip()
if os.path.exists(db_path):
break
print(f"❌ Path does not exist: {db_path}")
# 2. Get the target cation
cation = input("Enter the target cation [default: Li]: ").strip() or "Li"
# Target cation
cation = input("🎯 Enter the target cation [default: Li]: ").strip() or "Li"
# 3. Get the target anions
anion_input = input("Enter the target anions (comma-separated) [default: O,S,Cl,Br]: ").strip()
if anion_input:
anions = set(a.strip() for a in anion_input.split(','))
else:
anions = {'O', 'S', 'Cl', 'Br'}
# Target anions
anion_input = input("🎯 Enter the target anions (comma-separated) [default: O,S,Cl,Br]: ").strip()
anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
# 4. Select the anion mode
print("\nAnion mode selection:")
print(" 1. Single-anion compounds only")
print(" 2. Mixed-anion compounds only")
# Anion mode
print("\nAnion mode:")
print(" 1. Single anion only")
print(" 2. Mixed anions only")
print(" 3. All (default)")
mode_choice = input("Select [1/2/3]: ").strip()
anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all')
mode_map = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}
anion_mode = mode_map.get(mode_choice, 'all')
# Parallel configuration
print("\n" + "=" * 50)
print("[Parallel computing configuration]")
# 5. Parallelism
n_jobs_input = input("Number of parallel threads [default: 4]: ").strip()
n_jobs = int(n_jobs_input) if n_jobs_input.isdigit() else 4
default_cores = min(env['total_cores'], 32)
cores_input = input(f"💻 Maximum available cores [default: {default_cores}]: ").strip()
max_cores = int(cores_input) if cores_input.isdigit() else default_cores
print("\nTask complexity (determines cores assigned per worker):")
print(" 1. Low (1 core/worker) - simple I/O operations")
print(" 2. Medium (2 cores/worker) - structure parsing + checks [default]")
print(" 3. High (4 cores/worker) - heavy computation")
complexity_choice = input("Select [1/2/3]: ").strip()
complexity = {'1': 'low', '2': 'medium', '3': 'high', '': 'medium'}.get(complexity_choice, 'medium')
# Execution mode
use_slurm = False
if env['has_slurm']:
slurm_choice = input("\nSubmit as a SLURM job? [y/N]: ").strip().lower()
use_slurm = slurm_choice == 'y'
return {
'database_path': db_path,
'target_cation': cation,
'target_anions': anions,
'anion_mode': anion_mode,
'n_jobs': n_jobs
'max_cores': max_cores,
'task_complexity': complexity,
'use_slurm': use_slurm
}
def main():
"""Main function"""
# Get user input
params = get_user_input()
print_banner()
print("\n" + "-" * 70)
print("Starting database analysis...")
print("-" * 70)
# Environment detection
env = detect_and_show_environment()
# Get user input
params = get_user_input(env)
print("\n" + "=" * 60)
print("Starting the database analysis...")
print("=" * 60)
# Create the analyzer
analyzer = DatabaseAnalyzer(
@@ -73,30 +114,42 @@ def main():
target_cation=params['target_cation'],
target_anions=params['target_anions'],
anion_mode=params['anion_mode'],
n_jobs=params['n_jobs']
max_cores=params['max_cores'],
task_complexity=params['task_complexity']
)
# Run the analysis
print(f"\nFound {len(analyzer.cif_files)} CIF files")
if params['use_slurm']:
# SLURM mode
output_dir = input("Output directory [default: ./slurm_output]: ").strip() or "./slurm_output"
job_id = analyzer.analyze_slurm(output_dir=output_dir)
print(f"\n✅ SLURM job submitted: {job_id}")
print(f" Check its status with 'squeue -j {job_id}'")
print(f" Results will be saved to: {output_dir}")
else:
# Local mode
report = analyzer.analyze(show_progress=True)
# Print the report
ReportGenerator.print_report(report, detailed=True)
# Ask whether to export
export = input("\nExport detailed results to CSV? [y/N]: ").strip().lower()
if export == 'y':
output_path = input("Output file path [default: analysis_report.csv]: ").strip()
output_path = output_path or "analysis_report.csv"
ReportGenerator.export_to_csv(report, output_path)
# Save option
save_choice = input("\nSave the report? [y/N]: ").strip().lower()
if save_choice == 'y':
output_path = input("Report path [default: analysis_report.json]: ").strip()
output_path = output_path or "analysis_report.json"
report.save(output_path)
print(f"✅ Report saved to: {output_path}")
# Ask whether to continue with preprocessing
print("\n" + "-" * 70)
proceed = input("Continue with preprocessing? [y/N]: ").strip().lower()
if proceed == 'y':
print("Preprocessing will be implemented in the next stage...")
# TODO: call the preprocessing module
# CSV export
csv_choice = input("Export a detailed CSV? [y/N]: ").strip().lower()
if csv_choice == 'y':
csv_path = input("CSV path [default: analysis_details.csv]: ").strip()
csv_path = csv_path or "analysis_details.csv"
ReportGenerator.export_to_csv(report, csv_path)
print("\nAnalysis complete!")
print("\nAnalysis complete!")
if __name__ == "__main__":

src/analysis/database_analyzer.py

@@ -1,13 +1,16 @@
"""
数据库分析器:分析整个CIF数据库的构成和质量
数据库分析器:支持高性能并行分析
"""
import os
from dataclasses import dataclass, field
import pickle
import json
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Set, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from pathlib import Path
from .structure_inspector import StructureInspector, StructureInfo
from .worker import analyze_single_file
from ..core.scheduler import ParallelScheduler, ResourceConfig
@dataclass
@@ -23,13 +26,13 @@ class DatabaseReport:
# Target element statistics
target_cation: str = ""
target_anions: Set[str] = field(default_factory=set)
anion_mode: str = "" # "single", "mixed", "all"
anion_mode: str = ""
# Statistics for compounds containing the target cation
cation_containing_count: int = 0
cation_containing_ratio: float = 0.0
# Anion distribution (among compounds containing the target cation)
# Anion distribution
anion_distribution: Dict[str, int] = field(default_factory=dict)
anion_ratios: Dict[str, float] = field(default_factory=dict)
single_anion_count: int = 0
@@ -38,11 +41,9 @@ class DatabaseReport:
# Data quality statistics
with_oxidation_states: int = 0
without_oxidation_states: int = 0
needs_expansion_count: int = 0 # number requiring supercell expansion
cation_partial_occupancy_count: int = 0 # cation partial occupancy
anion_partial_occupancy_count: int = 0 # anion partial occupancy
needs_expansion_count: int = 0
cation_partial_occupancy_count: int = 0
anion_partial_occupancy_count: int = 0
binary_compound_count: int = 0
has_water_count: int = 0
has_radioactive_count: int = 0
@@ -56,17 +57,39 @@ class DatabaseReport:
all_structures: List[StructureInfo] = field(default_factory=list)
skip_reasons_summary: Dict[str, int] = field(default_factory=dict)
def to_dict(self) -> dict:
"""Convert to a serializable dict"""
d = asdict(self)
d['target_anions'] = list(self.target_anions)
d['all_structures'] = [asdict(s) for s in self.all_structures]
return d
def save(self, path: str):
"""Save the report"""
with open(path, 'w', encoding='utf-8') as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
@classmethod
def load(cls, path: str) -> 'DatabaseReport':
"""Load a report"""
with open(path, 'r', encoding='utf-8') as f:
d = json.load(f)
d['target_anions'] = set(d['target_anions'])
d['all_structures'] = [StructureInfo(**s) for s in d['all_structures']]
return cls(**d)
class DatabaseAnalyzer:
"""数据库分析器"""
"""数据库分析器 - 支持高性能并行"""
def __init__(
self,
database_path: str,
target_cation: str = "Li",
target_anions: Set[str] = None,
anion_mode: str = "all", # "single", "mixed", "all"
n_jobs: int = 4
anion_mode: str = "all",
max_cores: int = 4,
task_complexity: str = "medium"
):
"""
Initialize the analyzer
@@ -75,53 +98,27 @@ class DatabaseAnalyzer:
database_path: path to the database
target_cation: target cation
target_anions: set of target anions
anion_mode: anion mode ("single" = single only, "mixed" = mixed only, "all" = all)
n_jobs: number of parallel workers
anion_mode: anion mode
max_cores: maximum available cores
task_complexity: task complexity ('low', 'medium', 'high')
"""
self.database_path = database_path
self.target_cation = target_cation
self.target_anions = target_anions or {'O', 'S', 'Cl', 'Br'}
self.anion_mode = anion_mode
self.n_jobs = n_jobs
self.max_cores = max_cores
self.task_complexity = task_complexity
self.inspector = StructureInspector(
target_cation=target_cation,
target_anions=self.target_anions
# Collect the file list
self.cif_files = self._get_cif_files()
# Configure the scheduler
self.resource_config = ParallelScheduler.recommend_config(
num_tasks=len(self.cif_files),
task_complexity=task_complexity,
max_cores=max_cores
)
def analyze(self, show_progress: bool = True) -> DatabaseReport:
"""
Analyze the database
Args:
show_progress: whether to show a progress bar
Returns:
DatabaseReport: the analysis report
"""
report = DatabaseReport(
database_path=self.database_path,
target_cation=self.target_cation,
target_anions=self.target_anions,
anion_mode=self.anion_mode
)
# Collect all CIF files
cif_files = self._get_cif_files()
report.total_files = len(cif_files)
if report.total_files == 0:
print(f"Warning: no CIF files found in {self.database_path}")
return report
# Analyze all files in parallel
results = self._analyze_files(cif_files, show_progress)
report.all_structures = results
# Compute statistics
self._compute_statistics(report)
return report
self.scheduler = ParallelScheduler(self.resource_config)
def _get_cif_files(self) -> List[str]:
"""获取所有CIF文件路径"""
@@ -136,57 +133,111 @@ class DatabaseAnalyzer:
if f.endswith('.cif'):
cif_files.append(os.path.join(root, f))
return cif_files
return sorted(cif_files)
def _analyze_files(
def analyze(self, show_progress: bool = True) -> DatabaseReport:
"""
执行并行分析
Args:
show_progress: 是否显示进度
Returns:
DatabaseReport: 分析报告
"""
report = DatabaseReport(
database_path=self.database_path,
target_cation=self.target_cation,
target_anions=self.target_anions,
anion_mode=self.anion_mode,
total_files=len(self.cif_files)
)
if report.total_files == 0:
print(f"⚠️ 警告: 在 {self.database_path} 中未找到CIF文件")
return report
# 准备任务
tasks = [
(f, self.target_cation, self.target_anions)
for f in self.cif_files
]
# 执行并行分析
results = self.scheduler.run_local(
tasks=tasks,
worker_func=analyze_single_file,
desc="分析CIF文件"
)
# 过滤有效结果
report.all_structures = [r for r in results if r is not None]
# 统计结果
self._compute_statistics(report)
return report
def analyze_slurm(
self,
cif_files: List[str],
show_progress: bool
) -> List[StructureInfo]:
"""并行分析文件"""
results = []
output_dir: str,
job_name: str = "cif_analysis"
) -> str:
"""
Submit a SLURM job for the analysis
if self.n_jobs == 1:
# Single-threaded
iterator = tqdm(cif_files, desc="Analyzing CIF files") if show_progress else cif_files
for f in iterator:
results.append(self.inspector.inspect(f))
else:
# Multi-threaded
with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
futures = {executor.submit(self.inspector.inspect, f): f for f in cif_files}
Args:
output_dir: output directory
job_name: job name
iterator = tqdm(as_completed(futures), total=len(futures), desc="Analyzing CIF files") \
if show_progress else as_completed(futures)
Returns:
The SLURM job ID
"""
os.makedirs(output_dir, exist_ok=True)
for future in iterator:
try:
results.append(future.result())
except Exception as e:
print(f"分析失败: {e}")
# 保存任务配置
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump({
'files': self.cif_files,
'target_cation': self.target_cation,
'target_anions': list(self.target_anions),
'anion_mode': self.anion_mode
}, f)
return results
# Generate the SLURM script
worker_script = os.path.join(
os.path.dirname(__file__), 'worker.py'
)
script = self.scheduler.generate_slurm_script(
tasks_file=tasks_file,
worker_script=worker_script,
output_dir=output_dir,
job_name=job_name
)
# Save and submit
script_path = os.path.join(output_dir, "submit.sh")
return self.scheduler.submit_slurm_job(script, script_path)
def _compute_statistics(self, report: DatabaseReport):
"""计算统计数据"""
for info in report.all_structures:
# 有效性统计
if info.is_valid:
report.valid_files += 1
else:
report.invalid_files += 1
continue
# 含目标阳离子统计
if not info.contains_target_cation:
continue
report.cation_containing_count += 1
# 阴离子分布
for anion in info.anion_types:
report.anion_distribution[anion] = report.anion_distribution.get(anion, 0) + 1
report.anion_distribution[anion] = \
report.anion_distribution.get(anion, 0) + 1
if info.anion_mode == "single":
report.single_anion_count += 1
@@ -201,21 +252,18 @@ class DatabaseAnalyzer:
if info.anion_mode == "none":
continue
# Oxidation state statistics
# Per-item statistics
if info.has_oxidation_states:
report.with_oxidation_states += 1
else:
report.without_oxidation_states += 1
# Partial occupancy statistics
if info.needs_expansion:
report.needs_expansion_count += 1
if info.cation_has_partial_occupancy:
report.cation_partial_occupancy_count += 1
if info.anion_has_partial_occupancy:
report.anion_partial_occupancy_count += 1
# Other issue statistics
if info.is_binary_compound:
report.binary_compound_count += 1
if info.has_water_molecule:
@@ -223,7 +271,7 @@ class DatabaseAnalyzer:
if info.has_radioactive_elements:
report.has_radioactive_count += 1
# Processability statistics
# Processability
if info.can_process:
if info.needs_expansion:
report.needs_preprocessing += 1
@@ -231,7 +279,6 @@ class DatabaseAnalyzer:
report.directly_processable += 1
else:
report.cannot_process += 1
# Tally the skip reasons
if info.skip_reason:
for reason in info.skip_reason.split("; "):
report.skip_reasons_summary[reason] = \
@@ -239,8 +286,10 @@ class DatabaseAnalyzer:
# Compute ratios
if report.valid_files > 0:
report.cation_containing_ratio = report.cation_containing_count / report.valid_files
report.cation_containing_ratio = \
report.cation_containing_count / report.valid_files
if report.cation_containing_count > 0:
for anion, count in report.anion_distribution.items():
report.anion_ratios[anion] = count / report.cation_containing_count
report.anion_ratios[anion] = \
count / report.cation_containing_count

120
src/analysis/worker.py Normal file

@@ -0,0 +1,120 @@
"""
Worker process: handles a single analysis task
Designed to run standalone for SLURM job arrays
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict
from .structure_inspector import StructureInspector, StructureInfo
def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
"""
Analyze a single CIF file (worker function)
Args:
args: (file_path, target_cation, target_anions)
Returns:
StructureInfo, or a StructureInfo marked invalid if the analysis failed
"""
file_path, target_cation, target_anions = args
try:
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
return inspector.inspect(file_path)
except Exception as e:
# Return a result marked as failed
return StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=str(e)
)
def batch_analyze(
file_paths: List[str],
target_cation: str,
target_anions: set,
output_file: str = None
) -> List[StructureInfo]:
"""
Batch-analyze files (used by SLURM sub-tasks)
Args:
file_paths: list of CIF file paths
target_cation: target cation
target_anions: set of target anions
output_file: output file path (pickle format)
Returns:
List of StructureInfo
"""
results = []
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
for file_path in file_paths:
try:
info = inspector.inspect(file_path)
results.append(info)
except Exception as e:
results.append(StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=str(e)
))
# Save the results
if output_file:
with open(output_file, 'wb') as f:
pickle.dump([asdict(r) for r in results], f)
return results
# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
import argparse
import json
parser = argparse.ArgumentParser(description="CIF Analysis Worker")
parser.add_argument("--tasks-file", required=True, help="任务文件路径(JSON)")
parser.add_argument("--output-dir", required=True, help="输出目录")
parser.add_argument("--task-id", type=int, default=0, help="任务ID(用于数组作业)")
parser.add_argument("--num-workers", type=int, default=1, help="并行worker数")
args = parser.parse_args()
# Load the task configuration
with open(args.tasks_file, 'r') as f:
task_config = json.load(f)
file_paths = task_config['files']
target_cation = task_config['target_cation']
target_anions = set(task_config['target_anions'])
# For array jobs, process only the assigned chunk
if args.task_id >= 0:
chunk_size = len(file_paths) // args.num_workers + 1
start_idx = args.task_id * chunk_size
end_idx = min(start_idx + chunk_size, len(file_paths))
file_paths = file_paths[start_idx:end_idx]
# Output file
output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")
# Run the analysis
print(f"Worker {args.task_id}: processing {len(file_paths)} files")
results = batch_analyze(file_paths, target_cation, target_anions, output_file)
print(f"Worker {args.task_id}: 完成,结果保存到 {output_file}")

115
src/core/progress.py Normal file

@@ -0,0 +1,115 @@
"""
Progress manager: real-time progress display with multiprocessing support
"""
import os
import sys
import time
from multiprocessing import Manager, Value
from typing import Optional
from datetime import datetime, timedelta
class ProgressManager:
"""多进程安全的进度管理器"""
def __init__(self, total: int, desc: str = "Processing"):
self.total = total
self.desc = desc
self.manager = Manager()
self.completed = self.manager.Value('i', 0)
self.failed = self.manager.Value('i', 0)
self.start_time = None
self._lock = self.manager.Lock()
def start(self):
"""开始计时"""
self.start_time = time.time()
def update(self, success: bool = True):
"""更新进度(进程安全)"""
with self._lock:
if success:
self.completed.value += 1
else:
self.failed.value += 1
def get_progress(self) -> dict:
"""获取当前进度"""
completed = self.completed.value
failed = self.failed.value
total_done = completed + failed
elapsed = time.time() - self.start_time if self.start_time else 0
if total_done > 0 and elapsed > 0:
speed = total_done / elapsed # items/sec
remaining = (self.total - total_done) / speed if speed > 0 else 0
else:
speed = 0
remaining = 0
return {
'total': self.total,
'completed': completed,
'failed': failed,
'total_done': total_done,
'percent': total_done / self.total * 100 if self.total > 0 else 0,
'elapsed': elapsed,
'remaining': remaining,
'speed': speed
}
def display(self):
"""显示进度条"""
p = self.get_progress()
# Progress bar
bar_width = 40
filled = int(bar_width * p['total_done'] / p['total']) if p['total'] > 0 else 0
bar = '█' * filled + '░' * (bar_width - filled)
# Time formatting
elapsed_str = str(timedelta(seconds=int(p['elapsed'])))
remaining_str = str(timedelta(seconds=int(p['remaining'])))
# Build the display string
status = (
f"\r{self.desc}: |{bar}| "
f"{p['total_done']}/{p['total']} ({p['percent']:.1f}%) "
f"[{elapsed_str}<{remaining_str}, {p['speed']:.1f}it/s] "
f"{p['completed']}{p['failed']}"
)
sys.stdout.write(status)
sys.stdout.flush()
def finish(self):
"""完成显示"""
self.display()
print() # newline
class ProgressReporter:
"""进度报告器(用于后台监控)"""
def __init__(self, progress_file: str = ".progress"):
self.progress_file = progress_file
def save(self, progress: dict):
"""保存进度到文件"""
import json
with open(self.progress_file, 'w') as f:
json.dump(progress, f)
def load(self) -> Optional[dict]:
"""从文件加载进度"""
import json
if os.path.exists(self.progress_file):
with open(self.progress_file, 'r') as f:
return json.load(f)
return None
def cleanup(self):
"""清理进度文件"""
if os.path.exists(self.progress_file):
os.remove(self.progress_file)
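A minimal sketch of the ProgressManager lifecycle, assuming src/ is on PYTHONPATH; the sleep and the one forced failure are placeholders standing in for real work:

import time
from core.progress import ProgressManager

pm = ProgressManager(total=5, desc="Demo")
pm.start()
for i in range(5):
    time.sleep(0.1)                 # stand-in for a unit of work
    pm.update(success=(i != 3))     # mark one item as failed for illustration
    pm.display()
pm.finish()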

236
src/core/scheduler.py Normal file

@@ -0,0 +1,236 @@
"""
Parallel scheduler: supports local multiprocessing and SLURM cluster scheduling
"""
import os
import subprocess
import tempfile
from typing import List, Callable, Any, Optional, Dict
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass
from enum import Enum
import math
from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL_SINGLE = "local_single" # 单线程
LOCAL_MULTI = "local_multi" # 本地多进程
SLURM_SINGLE = "slurm_single" # SLURM单节点
SLURM_ARRAY = "slurm_array" # SLURM作业数组
@dataclass
class ResourceConfig:
"""资源配置"""
max_cores: int = 4 # 最大可用核数
cores_per_worker: int = 1 # 每个worker使用的核数
memory_per_core: str = "4G" # 每核内存
partition: str = "cpu" # SLURM分区
time_limit: str = "7-00:00:00" # 时间限制
@property
def num_workers(self) -> int:
"""计算worker数量"""
return max(1, self.max_cores // self.cores_per_worker)
class ParallelScheduler:
"""并行调度器"""
# 根据任务复杂度推荐的核数配置
COMPLEXITY_CORES = {
'low': 1, # 简单IO操作
'medium': 2, # 结构解析+基础检查
'high': 4, # 复杂计算(扩胞等)
}
def __init__(self, resource_config: ResourceConfig = None):
self.config = resource_config or ResourceConfig()
self.progress_manager = None
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""检测运行环境"""
env_info = {
'hostname': os.uname().nodename,
'total_cores': cpu_count(),
'has_slurm': False,
'slurm_partitions': [],
'available_nodes': 0,
}
# Detect SLURM
try:
result = subprocess.run(
['sinfo', '-h', '-o', '%P %a %c %D'],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
env_info['has_slurm'] = True
lines = result.stdout.strip().split('\n')
for line in lines:
parts = line.split()
if len(parts) >= 4:
partition = parts[0].rstrip('*')
avail = parts[1]
cores = int(parts[2])
nodes = int(parts[3])
if avail == 'up':
env_info['slurm_partitions'].append({
'name': partition,
'cores_per_node': cores,
'nodes': nodes
})
env_info['available_nodes'] += nodes
except Exception:
pass
return env_info
@staticmethod
def recommend_config(
num_tasks: int,
task_complexity: str = 'medium',
max_cores: int = None
) -> ResourceConfig:
"""根据任务量和复杂度推荐配置"""
env = ParallelScheduler.detect_environment()
# 默认最大核数
if max_cores is None:
max_cores = min(env['total_cores'], 32) # 最多32核
# 每个worker的核数
cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
# 计算最优worker数
# 原则worker数 = min(任务数, 可用核数/每worker核数)
max_workers = max_cores // cores_per_worker
optimal_workers = min(num_tasks, max_workers)
# 重新分配核数
actual_cores = optimal_workers * cores_per_worker
config = ResourceConfig(
max_cores=actual_cores,
cores_per_worker=cores_per_worker,
)
return config
def run_local(
self,
tasks: List[Any],
worker_func: Callable,
desc: str = "Processing"
) -> List[Any]:
"""本地多进程执行"""
num_workers = self.config.num_workers
total = len(tasks)
print(f"\n{'=' * 60}")
print(f"并行配置:")
print(f" 总任务数: {total}")
print(f" Worker数: {num_workers}")
print(f" 每Worker核数: {self.config.cores_per_worker}")
print(f" 总使用核数: {self.config.max_cores}")
print(f"{'=' * 60}\n")
# Initialize the progress manager
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
# Single-process mode
for task in tasks:
try:
result = worker_func(task)
results.append(result)
self.progress_manager.update(success=True)
except Exception as e:
results.append(None)
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
# Multi-process mode
with Pool(processes=num_workers) as pool:
# Use imap_unordered for better throughput
for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
if result is not None:
results.append(result)
self.progress_manager.update(success=True)
else:
self.progress_manager.update(success=False)
self.progress_manager.display()
self.progress_manager.finish()
return results
def generate_slurm_script(
self,
tasks_file: str,
worker_script: str,
output_dir: str,
job_name: str = "analysis"
) -> str:
"""生成SLURM作业脚本"""
script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/slurm_%j.log
#SBATCH --error={output_dir}/slurm_%j.err
# Environment setup
source $(conda info --base)/etc/profile.d/conda.sh
conda activate screen
# Set the Python path
cd $SLURM_SUBMIT_DIR
export PYTHONPATH=$(pwd):$PYTHONPATH
# Run the analysis
python {worker_script} \\
--tasks-file {tasks_file} \\
--output-dir {output_dir} \\
--num-workers {self.config.num_workers}
echo "Job completed at $(date)"
"""
return script
def submit_slurm_job(self, script_content: str, script_path: str = None) -> str:
"""提交SLURM作业"""
if script_path is None:
fd, script_path = tempfile.mkstemp(suffix='.sh')
os.close(fd)
with open(script_path, 'w') as f:
f.write(script_content)
os.chmod(script_path, 0o755)
result = subprocess.run(
['sbatch', script_path],
capture_output=True, text=True
)
if result.returncode == 0:
job_id = result.stdout.strip().split()[-1]
print(f"作业已提交: {job_id}")
return job_id
else:
raise RuntimeError(f"Submission failed: {result.stderr}")
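A minimal sketch of how the pieces introduced in this commit fit together, mirroring the calls DatabaseAnalyzer makes; it assumes src/ is on PYTHONPATH and uses a placeholder task list and worker rather than the real CIF pipeline:

from core.scheduler import ParallelScheduler

def count_chars(task):
    # placeholder worker; the real pipeline passes analyze_single_file here
    path, = task
    return len(path)

if __name__ == "__main__":
    env = ParallelScheduler.detect_environment()
    tasks = [("a.cif",), ("b.cif",), ("c.cif",)]   # placeholder tasks
    config = ParallelScheduler.recommend_config(
        num_tasks=len(tasks),
        task_complexity="medium",
        max_cores=min(env["total_cores"], 8),
    )
    results = ParallelScheduler(config).run_local(
        tasks=tasks, worker_func=count_chars, desc="Demo"
    )
    print(len(results), "results")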