Add parallel computation to preprocessing
main.py (157 changed lines)
@@ -1,71 +1,112 @@
 """
-High-throughput screening and supercell project - main entry
-Interactive command-line interface
+High-throughput screening and supercell project - main entry (parallel support)
 """
 import os
 import sys
 
-# Add src to the path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
 
-from src.analysis.database_analyzer import DatabaseAnalyzer
-from src.analysis.report_generator import ReportGenerator
+from analysis.database_analyzer import DatabaseAnalyzer
+from analysis.report_generator import ReportGenerator
+from core.scheduler import ParallelScheduler
 
 
-def get_user_input():
+def print_banner():
+    print("""
+    ╔═══════════════════════════════════════════════════════════════════╗
+    ║   High-Throughput Screening & Supercell - Database Analyzer v2.0   ║
+    ║            with high-performance parallel computing                ║
+    ╚═══════════════════════════════════════════════════════════════════╝
+    """)
+
+
+def detect_and_show_environment():
+    """Detect and display environment information"""
+    env = ParallelScheduler.detect_environment()
+
+    print("[Runtime environment]")
+    print(f"  Hostname: {env['hostname']}")
+    print(f"  Local CPU cores: {env['total_cores']}")
+    print(f"  SLURM cluster: {'✅ available' if env['has_slurm'] else '❌ not available'}")
+
+    if env['has_slurm'] and env['slurm_partitions']:
+        print(f"  Available partitions:")
+        for p in env['slurm_partitions']:
+            print(f"    - {p['name']}: {p['nodes']} nodes, {p['cores_per_node']} cores/node")
+
+    return env
+
+
+def get_user_input(env: dict):
     """Get user input"""
-    print("\n" + "=" * 70)
-    print("  High-throughput screening and supercell project - database analysis tool")
-    print("=" * 70)
-
-    # 1. Get the database path
+    # Database path
     while True:
-        db_path = input("\nEnter database path: ").strip()
+        db_path = input("\n📂 Enter database path: ").strip()
         if os.path.exists(db_path):
             break
         print(f"❌ Path does not exist: {db_path}")
 
-    # 2. Get the target cation
-    cation = input("Enter target cation [default: Li]: ").strip() or "Li"
+    # Target cation
+    cation = input("🎯 Enter target cation [default: Li]: ").strip() or "Li"
 
-    # 3. Get the target anions
-    anion_input = input("Enter target anions (comma-separated) [default: O,S,Cl,Br]: ").strip()
-    if anion_input:
-        anions = set(a.strip() for a in anion_input.split(','))
-    else:
-        anions = {'O', 'S', 'Cl', 'Br'}
+    # Target anions
+    anion_input = input("🎯 Enter target anions (comma-separated) [default: O,S,Cl,Br]: ").strip()
+    anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
 
-    # 4. Choose the anion mode
-    print("\nAnion mode selection:")
-    print("  1. Single-anion compounds only")
-    print("  2. Mixed-anion compounds only")
+    # Anion mode
+    print("\nAnion mode:")
+    print("  1. Single-anion only")
+    print("  2. Mixed-anion only")
     print("  3. All (default)")
     mode_choice = input("Choose [1/2/3]: ").strip()
+    anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all')
 
-    mode_map = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}
-    anion_mode = mode_map.get(mode_choice, 'all')
+    # Parallel configuration
+    print("\n" + "─" * 50)
+    print("[Parallel computing configuration]")
 
-    # 5. Parallelism
-    n_jobs_input = input("Number of parallel threads [default: 4]: ").strip()
-    n_jobs = int(n_jobs_input) if n_jobs_input.isdigit() else 4
+    default_cores = min(env['total_cores'], 32)
+    cores_input = input(f"💻 Maximum cores to use [default: {default_cores}]: ").strip()
+    max_cores = int(cores_input) if cores_input.isdigit() else default_cores
+
+    print("\nTask complexity (determines cores per worker):")
+    print("  1. Low (1 core/worker) - simple IO operations")
+    print("  2. Medium (2 cores/worker) - structure parsing + checks [default]")
+    print("  3. High (4 cores/worker) - heavy computation")
+    complexity_choice = input("Choose [1/2/3]: ").strip()
+    complexity = {'1': 'low', '2': 'medium', '3': 'high', '': 'medium'}.get(complexity_choice, 'medium')
+
+    # Execution mode
+    use_slurm = False
+    if env['has_slurm']:
+        slurm_choice = input("\nSubmit the job via SLURM? [y/N]: ").strip().lower()
+        use_slurm = slurm_choice == 'y'
 
     return {
         'database_path': db_path,
         'target_cation': cation,
         'target_anions': anions,
         'anion_mode': anion_mode,
-        'n_jobs': n_jobs
+        'max_cores': max_cores,
+        'task_complexity': complexity,
+        'use_slurm': use_slurm
     }
 
 
 def main():
     """Main function"""
-    # Get user input
-    params = get_user_input()
+    print_banner()
 
-    print("\n" + "-" * 70)
-    print("Starting database analysis...")
-    print("-" * 70)
+    # Environment detection
+    env = detect_and_show_environment()
+
+    # Get user input
+    params = get_user_input(env)
+
+    print("\n" + "═" * 60)
+    print("Starting database analysis...")
+    print("═" * 60)
 
     # Create the analyzer
     analyzer = DatabaseAnalyzer(
@@ -73,30 +114,42 @@ def main():
         target_cation=params['target_cation'],
         target_anions=params['target_anions'],
         anion_mode=params['anion_mode'],
-        n_jobs=params['n_jobs']
+        max_cores=params['max_cores'],
+        task_complexity=params['task_complexity']
     )
 
-    # Run the analysis
-    report = analyzer.analyze(show_progress=True)
+    print(f"\nFound {len(analyzer.cif_files)} CIF files")
 
-    # Print the report
-    ReportGenerator.print_report(report, detailed=True)
+    if params['use_slurm']:
+        # SLURM mode
+        output_dir = input("Output directory [default: ./slurm_output]: ").strip() or "./slurm_output"
+        job_id = analyzer.analyze_slurm(output_dir=output_dir)
+        print(f"\n✅ SLURM job submitted: {job_id}")
+        print(f"   Use 'squeue -j {job_id}' to check its status")
+        print(f"   Results will be written to: {output_dir}")
+    else:
+        # Local mode
+        report = analyzer.analyze(show_progress=True)
 
-    # Ask whether to export
-    export = input("\nExport detailed results to CSV? [y/N]: ").strip().lower()
-    if export == 'y':
-        output_path = input("Output file path [default: analysis_report.csv]: ").strip()
-        output_path = output_path or "analysis_report.csv"
-        ReportGenerator.export_to_csv(report, output_path)
+        # Print the report
+        ReportGenerator.print_report(report, detailed=True)
 
-    # Ask whether to continue with preprocessing
-    print("\n" + "-" * 70)
-    proceed = input("Continue with preprocessing? [y/N]: ").strip().lower()
-    if proceed == 'y':
-        print("Preprocessing will be implemented in the next stage...")
-        # TODO: call the preprocessing module
+        # Save options
+        save_choice = input("\nSave the report? [y/N]: ").strip().lower()
+        if save_choice == 'y':
+            output_path = input("Report path [default: analysis_report.json]: ").strip()
+            output_path = output_path or "analysis_report.json"
+            report.save(output_path)
+            print(f"✅ Report saved to: {output_path}")
 
-    print("\nAnalysis finished!")
+        # CSV export
+        csv_choice = input("Export detailed CSV? [y/N]: ").strip().lower()
+        if csv_choice == 'y':
+            csv_path = input("CSV path [default: analysis_details.csv]: ").strip()
+            csv_path = csv_path or "analysis_details.csv"
+            ReportGenerator.export_to_csv(report, csv_path)
 
+    print("\n✅ Analysis finished!")
 
 
 if __name__ == "__main__":
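
For reference, the new main.py flow can also be driven without the interactive prompts. A minimal sketch, using only the API surface shown in this commit and mirroring main.py's own sys.path setup; the database path and core counts are made-up example values:

import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))  # as main.py does

from analysis.database_analyzer import DatabaseAnalyzer
from analysis.report_generator import ReportGenerator

analyzer = DatabaseAnalyzer(
    database_path="/data/cif_db",        # hypothetical example path
    target_cation="Li",
    target_anions={'O', 'S', 'Cl', 'Br'},
    anion_mode="all",
    max_cores=16,
    task_complexity="medium",
)
report = analyzer.analyze(show_progress=True)   # local parallel mode
ReportGenerator.print_report(report, detailed=True)
report.save("analysis_report.json")
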
src/analysis/database_analyzer.py
@@ -1,13 +1,16 @@
 """
-Database analyzer: analyzes the composition and quality of an entire CIF database
+Database analyzer: supports high-performance parallel analysis
 """
 import os
-from dataclasses import dataclass, field
+import pickle
+import json
+from dataclasses import dataclass, field, asdict
 from typing import Dict, List, Set, Optional
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from tqdm import tqdm
+from pathlib import Path
 
 from .structure_inspector import StructureInspector, StructureInfo
+from .worker import analyze_single_file
+from ..core.scheduler import ParallelScheduler, ResourceConfig
 
 
 @dataclass
@@ -23,13 +26,13 @@ class DatabaseReport:
     # Target element statistics
     target_cation: str = ""
     target_anions: Set[str] = field(default_factory=set)
-    anion_mode: str = ""  # "single", "mixed", "all"
+    anion_mode: str = ""
 
     # Statistics for compounds containing the target cation
     cation_containing_count: int = 0
     cation_containing_ratio: float = 0.0
 
-    # Anion distribution (among compounds containing the target cation)
+    # Anion distribution
     anion_distribution: Dict[str, int] = field(default_factory=dict)
     anion_ratios: Dict[str, float] = field(default_factory=dict)
     single_anion_count: int = 0
@@ -38,11 +41,9 @@ class DatabaseReport:
     # Data quality statistics
     with_oxidation_states: int = 0
     without_oxidation_states: int = 0
-
-    needs_expansion_count: int = 0  # number needing supercell expansion
-    cation_partial_occupancy_count: int = 0  # cation partial occupancy
-    anion_partial_occupancy_count: int = 0  # anion partial occupancy
-
+    needs_expansion_count: int = 0
+    cation_partial_occupancy_count: int = 0
+    anion_partial_occupancy_count: int = 0
     binary_compound_count: int = 0
     has_water_count: int = 0
     has_radioactive_count: int = 0
@@ -56,17 +57,39 @@ class DatabaseReport:
     all_structures: List[StructureInfo] = field(default_factory=list)
     skip_reasons_summary: Dict[str, int] = field(default_factory=dict)
 
+    def to_dict(self) -> dict:
+        """Convert to a serializable dict"""
+        d = asdict(self)
+        d['target_anions'] = list(self.target_anions)
+        d['all_structures'] = [asdict(s) for s in self.all_structures]
+        return d
+
+    def save(self, path: str):
+        """Save the report"""
+        with open(path, 'w', encoding='utf-8') as f:
+            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
+
+    @classmethod
+    def load(cls, path: str) -> 'DatabaseReport':
+        """Load a report"""
+        with open(path, 'r', encoding='utf-8') as f:
+            d = json.load(f)
+        d['target_anions'] = set(d['target_anions'])
+        d['all_structures'] = [StructureInfo(**s) for s in d['all_structures']]
+        return cls(**d)
 
 
 class DatabaseAnalyzer:
-    """Database analyzer"""
+    """Database analyzer - supports high-performance parallelism"""
 
     def __init__(
         self,
        database_path: str,
         target_cation: str = "Li",
         target_anions: Set[str] = None,
-        anion_mode: str = "all",  # "single", "mixed", "all"
-        n_jobs: int = 4
+        anion_mode: str = "all",
+        max_cores: int = 4,
+        task_complexity: str = "medium"
     ):
         """
         Initialize the analyzer
@@ -75,53 +98,27 @@ class DatabaseAnalyzer:
             database_path: database path
             target_cation: target cation
             target_anions: set of target anions
-            anion_mode: anion mode ("single" = single only, "mixed" = mixed only, "all" = all)
-            n_jobs: degree of parallelism
+            anion_mode: anion mode
+            max_cores: maximum number of cores to use
+            task_complexity: task complexity ('low', 'medium', 'high')
         """
         self.database_path = database_path
         self.target_cation = target_cation
         self.target_anions = target_anions or {'O', 'S', 'Cl', 'Br'}
         self.anion_mode = anion_mode
-        self.n_jobs = n_jobs
+        self.max_cores = max_cores
+        self.task_complexity = task_complexity
 
-        self.inspector = StructureInspector(
-            target_cation=target_cation,
-            target_anions=self.target_anions
-        )
-
-    def analyze(self, show_progress: bool = True) -> DatabaseReport:
-        """
-        Analyze the database
-
-        Args:
-            show_progress: whether to show a progress bar
-
-        Returns:
-            DatabaseReport: the analysis report
-        """
-        report = DatabaseReport(
-            database_path=self.database_path,
-            target_cation=self.target_cation,
-            target_anions=self.target_anions,
-            anion_mode=self.anion_mode
-        )
-
-        # Gather all CIF files
-        cif_files = self._get_cif_files()
-        report.total_files = len(cif_files)
-
-        if report.total_files == 0:
-            print(f"Warning: no CIF files found in {self.database_path}")
-            return report
-
-        # Analyze all files in parallel
-        results = self._analyze_files(cif_files, show_progress)
-        report.all_structures = results
-
-        # Compute statistics
-        self._compute_statistics(report)
-
-        return report
+        # Collect the file list
+        self.cif_files = self._get_cif_files()
+
+        # Configure the scheduler
+        self.resource_config = ParallelScheduler.recommend_config(
+            num_tasks=len(self.cif_files),
+            task_complexity=task_complexity,
+            max_cores=max_cores
+        )
+        self.scheduler = ParallelScheduler(self.resource_config)
 
     def _get_cif_files(self) -> List[str]:
         """Collect all CIF file paths"""
@@ -136,57 +133,111 @@ class DatabaseAnalyzer:
             if f.endswith('.cif'):
                 cif_files.append(os.path.join(root, f))
 
-        return cif_files
+        return sorted(cif_files)
 
-    def _analyze_files(
-        self,
-        cif_files: List[str],
-        show_progress: bool
-    ) -> List[StructureInfo]:
-        """Analyze files in parallel"""
-        results = []
-
-        if self.n_jobs == 1:
-            # Single-threaded
-            iterator = tqdm(cif_files, desc="Analyzing CIF files") if show_progress else cif_files
-            for f in iterator:
-                results.append(self.inspector.inspect(f))
-        else:
-            # Multi-threaded
-            with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
-                futures = {executor.submit(self.inspector.inspect, f): f for f in cif_files}
-
-                iterator = tqdm(as_completed(futures), total=len(futures), desc="Analyzing CIF files") \
-                    if show_progress else as_completed(futures)
-
-                for future in iterator:
-                    try:
-                        results.append(future.result())
-                    except Exception as e:
-                        print(f"Analysis failed: {e}")
-
-        return results
+    def analyze(self, show_progress: bool = True) -> DatabaseReport:
+        """
+        Run the parallel analysis
+
+        Args:
+            show_progress: whether to show progress
+
+        Returns:
+            DatabaseReport: the analysis report
+        """
+        report = DatabaseReport(
+            database_path=self.database_path,
+            target_cation=self.target_cation,
+            target_anions=self.target_anions,
+            anion_mode=self.anion_mode,
+            total_files=len(self.cif_files)
+        )
+
+        if report.total_files == 0:
+            print(f"⚠️ Warning: no CIF files found in {self.database_path}")
+            return report
+
+        # Prepare the tasks
+        tasks = [
+            (f, self.target_cation, self.target_anions)
+            for f in self.cif_files
+        ]
+
+        # Run the parallel analysis
+        results = self.scheduler.run_local(
+            tasks=tasks,
+            worker_func=analyze_single_file,
+            desc="Analyzing CIF files"
+        )
+
+        # Filter out invalid results
+        report.all_structures = [r for r in results if r is not None]
+
+        # Compute statistics
+        self._compute_statistics(report)
+
+        return report
+
+    def analyze_slurm(
+        self,
+        output_dir: str,
+        job_name: str = "cif_analysis"
+    ) -> str:
+        """
+        Submit a SLURM job for the analysis
+
+        Args:
+            output_dir: output directory
+            job_name: job name
+
+        Returns:
+            the job ID
+        """
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save the task configuration
+        tasks_file = os.path.join(output_dir, "tasks.json")
+        with open(tasks_file, 'w') as f:
+            json.dump({
+                'files': self.cif_files,
+                'target_cation': self.target_cation,
+                'target_anions': list(self.target_anions),
+                'anion_mode': self.anion_mode
+            }, f)
+
+        # Generate the SLURM script
+        worker_script = os.path.join(
+            os.path.dirname(__file__), 'worker.py'
+        )
+        script = self.scheduler.generate_slurm_script(
+            tasks_file=tasks_file,
+            worker_script=worker_script,
+            output_dir=output_dir,
+            job_name=job_name
+        )
+
+        # Save and submit
+        script_path = os.path.join(output_dir, "submit.sh")
+        return self.scheduler.submit_slurm_job(script, script_path)
 
     def _compute_statistics(self, report: DatabaseReport):
         """Compute the statistics"""
 
         for info in report.all_structures:
-            # Validity statistics
             if info.is_valid:
                 report.valid_files += 1
             else:
                 report.invalid_files += 1
                 continue
 
-            # Target-cation statistics
             if not info.contains_target_cation:
                 continue
 
             report.cation_containing_count += 1
 
-            # Anion distribution
             for anion in info.anion_types:
-                report.anion_distribution[anion] = report.anion_distribution.get(anion, 0) + 1
+                report.anion_distribution[anion] = \
+                    report.anion_distribution.get(anion, 0) + 1
 
             if info.anion_mode == "single":
                 report.single_anion_count += 1
@@ -201,21 +252,18 @@ class DatabaseAnalyzer:
             if info.anion_mode == "none":
                 continue
 
-            # Oxidation state statistics
+            # Per-item statistics
             if info.has_oxidation_states:
                 report.with_oxidation_states += 1
             else:
                 report.without_oxidation_states += 1
 
-            # Partial occupancy statistics
             if info.needs_expansion:
                 report.needs_expansion_count += 1
             if info.cation_has_partial_occupancy:
                 report.cation_partial_occupancy_count += 1
             if info.anion_has_partial_occupancy:
                 report.anion_partial_occupancy_count += 1
 
-            # Other issue statistics
             if info.is_binary_compound:
                 report.binary_compound_count += 1
             if info.has_water_molecule:
@@ -223,7 +271,7 @@ class DatabaseAnalyzer:
             if info.has_radioactive_elements:
                 report.has_radioactive_count += 1
 
-            # Processability statistics
+            # Processability
             if info.can_process:
                 if info.needs_expansion:
                     report.needs_preprocessing += 1
@@ -231,7 +279,6 @@ class DatabaseAnalyzer:
                     report.directly_processable += 1
             else:
                 report.cannot_process += 1
-                # Tally the skip reasons
                 if info.skip_reason:
                     for reason in info.skip_reason.split("; "):
                         report.skip_reasons_summary[reason] = \
@@ -239,8 +286,10 @@ class DatabaseAnalyzer:
 
         # Compute ratios
         if report.valid_files > 0:
-            report.cation_containing_ratio = report.cation_containing_count / report.valid_files
+            report.cation_containing_ratio = \
+                report.cation_containing_count / report.valid_files
 
         if report.cation_containing_count > 0:
             for anion, count in report.anion_distribution.items():
-                report.anion_ratios[anion] = count / report.cation_containing_count
+                report.anion_ratios[anion] = \
+                    count / report.cation_containing_count
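
The new to_dict/save/load trio makes reports round-trippable through JSON (sets become lists on disk, StructureInfo entries become dicts). A minimal sketch, assuming the same sys.path setup as main.py; the file name is an example:

from analysis.database_analyzer import DatabaseReport

report = DatabaseReport(database_path="/data/cif_db", target_cation="Li",
                        target_anions={'O', 'S'})
report.save("analysis_report.json")            # target_anions serialized as a list
restored = DatabaseReport.load("analysis_report.json")
assert restored.target_anions == {'O', 'S'}    # restored back into a set
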
src/analysis/worker.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+"""
+Worker process: handles a single analysis task
+Designed to run standalone (for SLURM job arrays)
+"""
+import os
+import pickle
+from typing import List, Tuple, Optional
+from dataclasses import asdict
+
+from .structure_inspector import StructureInspector, StructureInfo
+
+
+def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
+    """
+    Analyze a single CIF file (worker function)
+
+    Args:
+        args: (file_path, target_cation, target_anions)
+
+    Returns:
+        StructureInfo, or None if the analysis failed
+    """
+    file_path, target_cation, target_anions = args
+
+    try:
+        inspector = StructureInspector(
+            target_cation=target_cation,
+            target_anions=target_anions
+        )
+        return inspector.inspect(file_path)
+    except Exception as e:
+        # Return a result marked as failed
+        return StructureInfo(
+            file_path=file_path,
+            file_name=os.path.basename(file_path),
+            is_valid=False,
+            error_message=str(e)
+        )
+
+
+def batch_analyze(
+    file_paths: List[str],
+    target_cation: str,
+    target_anions: set,
+    output_file: str = None
+) -> List[StructureInfo]:
+    """
+    Analyze files in batch (for SLURM subtasks)
+
+    Args:
+        file_paths: list of CIF file paths
+        target_cation: target cation
+        target_anions: set of target anions
+        output_file: output file path (pickle format)
+
+    Returns:
+        list of StructureInfo
+    """
+    results = []
+
+    inspector = StructureInspector(
+        target_cation=target_cation,
+        target_anions=target_anions
+    )
+
+    for file_path in file_paths:
+        try:
+            info = inspector.inspect(file_path)
+            results.append(info)
+        except Exception as e:
+            results.append(StructureInfo(
+                file_path=file_path,
+                file_name=os.path.basename(file_path),
+                is_valid=False,
+                error_message=str(e)
+            ))
+
+    # Save the results
+    if output_file:
+        with open(output_file, 'wb') as f:
+            pickle.dump([asdict(r) for r in results], f)
+
+    return results
+
+
+# Command-line entry point for SLURM job arrays
+if __name__ == "__main__":
+    import argparse
+    import json
+
+    parser = argparse.ArgumentParser(description="CIF Analysis Worker")
+    parser.add_argument("--tasks-file", required=True, help="Path to the task file (JSON)")
+    parser.add_argument("--output-dir", required=True, help="Output directory")
+    parser.add_argument("--task-id", type=int, default=0, help="Task ID (for array jobs)")
+    parser.add_argument("--num-workers", type=int, default=1, help="Number of parallel workers")
+
+    args = parser.parse_args()
+
+    # Load the tasks
+    with open(args.tasks_file, 'r') as f:
+        task_config = json.load(f)
+
+    file_paths = task_config['files']
+    target_cation = task_config['target_cation']
+    target_anions = set(task_config['target_anions'])
+
+    # For array jobs, process only the assigned slice
+    if args.task_id >= 0:
+        chunk_size = len(file_paths) // args.num_workers + 1
+        start_idx = args.task_id * chunk_size
+        end_idx = min(start_idx + chunk_size, len(file_paths))
+        file_paths = file_paths[start_idx:end_idx]
+
+    # Output file
+    output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")
+
+    # Run the analysis
+    print(f"Worker {args.task_id}: processing {len(file_paths)} files")
+    results = batch_analyze(file_paths, target_cation, target_anions, output_file)
+    print(f"Worker {args.task_id}: done, results saved to {output_file}")
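
Each worker invocation writes results_<task_id>.pkl containing a list of asdict()-ed StructureInfo records. A minimal merge sketch for collecting them after the job finishes, rebuilding StructureInfo the same way DatabaseReport.load does; the output directory here is the default suggested by main.py:

import glob
import os
import pickle

from analysis.structure_inspector import StructureInfo  # assumes 'src' is on sys.path

results = []
for pkl in sorted(glob.glob(os.path.join("./slurm_output", "results_*.pkl"))):
    with open(pkl, 'rb') as f:
        results.extend(StructureInfo(**d) for d in pickle.load(f))

print(f"Merged {len(results)} structure records")
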
src/core/progress.py (new file, 115 lines)
@@ -0,0 +1,115 @@
+"""
+Progress manager: real-time progress display with multiprocessing support
+"""
+import os
+import sys
+import time
+from multiprocessing import Manager, Value
+from typing import Optional
+from datetime import datetime, timedelta
+
+
+class ProgressManager:
+    """Process-safe progress manager"""
+
+    def __init__(self, total: int, desc: str = "Processing"):
+        self.total = total
+        self.desc = desc
+        self.manager = Manager()
+        self.completed = self.manager.Value('i', 0)
+        self.failed = self.manager.Value('i', 0)
+        self.start_time = None
+        self._lock = self.manager.Lock()
+
+    def start(self):
+        """Start timing"""
+        self.start_time = time.time()
+
+    def update(self, success: bool = True):
+        """Update progress (process-safe)"""
+        with self._lock:
+            if success:
+                self.completed.value += 1
+            else:
+                self.failed.value += 1
+
+    def get_progress(self) -> dict:
+        """Get the current progress"""
+        completed = self.completed.value
+        failed = self.failed.value
+        total_done = completed + failed
+
+        elapsed = time.time() - self.start_time if self.start_time else 0
+
+        if total_done > 0:
+            speed = total_done / elapsed  # items/sec
+            remaining = (self.total - total_done) / speed if speed > 0 else 0
+        else:
+            speed = 0
+            remaining = 0
+
+        return {
+            'total': self.total,
+            'completed': completed,
+            'failed': failed,
+            'total_done': total_done,
+            'percent': total_done / self.total * 100 if self.total > 0 else 0,
+            'elapsed': elapsed,
+            'remaining': remaining,
+            'speed': speed
+        }
+
+    def display(self):
+        """Render the progress bar"""
+        p = self.get_progress()
+
+        # Progress bar
+        bar_width = 40
+        filled = int(bar_width * p['total_done'] / p['total']) if p['total'] > 0 else 0
+        bar = '█' * filled + '░' * (bar_width - filled)
+
+        # Time formatting
+        elapsed_str = str(timedelta(seconds=int(p['elapsed'])))
+        remaining_str = str(timedelta(seconds=int(p['remaining'])))
+
+        # Build the status line
+        status = (
+            f"\r{self.desc}: |{bar}| "
+            f"{p['total_done']}/{p['total']} ({p['percent']:.1f}%) "
+            f"[{elapsed_str}<{remaining_str}, {p['speed']:.1f}it/s] "
+            f"✓{p['completed']} ✗{p['failed']}"
+        )
+
+        sys.stdout.write(status)
+        sys.stdout.flush()
+
+    def finish(self):
+        """Final display"""
+        self.display()
+        print()  # newline
+
+
+class ProgressReporter:
+    """Progress reporter (for background monitoring)"""
+
+    def __init__(self, progress_file: str = ".progress"):
+        self.progress_file = progress_file
+
+    def save(self, progress: dict):
+        """Save progress to a file"""
+        import json
+        with open(self.progress_file, 'w') as f:
+            json.dump(progress, f)
+
+    def load(self) -> Optional[dict]:
+        """Load progress from the file"""
+        import json
+        if os.path.exists(self.progress_file):
+            with open(self.progress_file, 'r') as f:
+                return json.load(f)
+        return None
+
+    def cleanup(self):
+        """Remove the progress file"""
+        if os.path.exists(self.progress_file):
+            os.remove(self.progress_file)
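
ProgressManager is consumed by the scheduler below (start, update per task, display, finish), but it can be driven standalone. A minimal sketch; the sleep stands in for real work:

import time

from core.progress import ProgressManager  # assumes 'src' is on sys.path

if __name__ == "__main__":
    pm = ProgressManager(total=5, desc="Demo")
    pm.start()
    for i in range(5):
        time.sleep(0.1)                  # stand-in for real work
        pm.update(success=(i % 2 == 0))  # mark odd iterations as failures
        pm.display()
    pm.finish()
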
src/core/scheduler.py (new file, 236 lines)
@@ -0,0 +1,236 @@
+"""
+Parallel scheduler: supports local multiprocessing and SLURM cluster scheduling
+"""
+import os
+import subprocess
+import tempfile
+from typing import List, Callable, Any, Optional, Dict
+from multiprocessing import Pool, cpu_count
+from dataclasses import dataclass
+from enum import Enum
+import math
+
+from .progress import ProgressManager
+
+
+class ExecutionMode(Enum):
+    """Execution mode"""
+    LOCAL_SINGLE = "local_single"  # single-threaded
+    LOCAL_MULTI = "local_multi"    # local multiprocessing
+    SLURM_SINGLE = "slurm_single"  # SLURM single node
+    SLURM_ARRAY = "slurm_array"    # SLURM job array
+
+
+@dataclass
+class ResourceConfig:
+    """Resource configuration"""
+    max_cores: int = 4              # maximum cores available
+    cores_per_worker: int = 1       # cores used per worker
+    memory_per_core: str = "4G"     # memory per core
+    partition: str = "cpu"          # SLURM partition
+    time_limit: str = "7-00:00:00"  # time limit
+
+    @property
+    def num_workers(self) -> int:
+        """Compute the number of workers"""
+        return max(1, self.max_cores // self.cores_per_worker)
+
+
+class ParallelScheduler:
+    """Parallel scheduler"""
+
+    # Recommended cores per worker by task complexity
+    COMPLEXITY_CORES = {
+        'low': 1,     # simple IO operations
+        'medium': 2,  # structure parsing + basic checks
+        'high': 4,    # heavy computation (supercell expansion, etc.)
+    }
+
+    def __init__(self, resource_config: ResourceConfig = None):
+        self.config = resource_config or ResourceConfig()
+        self.progress_manager = None
+
+    @staticmethod
+    def detect_environment() -> Dict[str, Any]:
+        """Detect the runtime environment"""
+        env_info = {
+            'hostname': os.uname().nodename,
+            'total_cores': cpu_count(),
+            'has_slurm': False,
+            'slurm_partitions': [],
+            'available_nodes': 0,
+        }
+
+        # Detect SLURM
+        try:
+            result = subprocess.run(
+                ['sinfo', '-h', '-o', '%P %a %c %D'],
+                capture_output=True, text=True, timeout=5
+            )
+            if result.returncode == 0:
+                env_info['has_slurm'] = True
+                lines = result.stdout.strip().split('\n')
+                for line in lines:
+                    parts = line.split()
+                    if len(parts) >= 4:
+                        partition = parts[0].rstrip('*')
+                        avail = parts[1]
+                        cores = int(parts[2])
+                        nodes = int(parts[3])
+                        if avail == 'up':
+                            env_info['slurm_partitions'].append({
+                                'name': partition,
+                                'cores_per_node': cores,
+                                'nodes': nodes
+                            })
+                            env_info['available_nodes'] += nodes
+        except Exception:
+            pass
+
+        return env_info
+
+    @staticmethod
+    def recommend_config(
+        num_tasks: int,
+        task_complexity: str = 'medium',
+        max_cores: int = None
+    ) -> ResourceConfig:
+        """Recommend a configuration based on task count and complexity"""
+
+        env = ParallelScheduler.detect_environment()
+
+        # Default maximum cores
+        if max_cores is None:
+            max_cores = min(env['total_cores'], 32)  # at most 32 cores
+
+        # Cores per worker
+        cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
+
+        # Compute the optimal number of workers
+        # Rule: workers = min(number of tasks, available cores / cores per worker)
+        max_workers = max_cores // cores_per_worker
+        optimal_workers = min(num_tasks, max_workers)
+
+        # Reallocate the cores
+        actual_cores = optimal_workers * cores_per_worker
+
+        config = ResourceConfig(
+            max_cores=actual_cores,
+            cores_per_worker=cores_per_worker,
+        )
+
+        return config
+
+    def run_local(
+        self,
+        tasks: List[Any],
+        worker_func: Callable,
+        desc: str = "Processing"
+    ) -> List[Any]:
+        """Run locally with multiprocessing"""
+
+        num_workers = self.config.num_workers
+        total = len(tasks)
+
+        print(f"\n{'=' * 60}")
+        print(f"Parallel configuration:")
+        print(f"  Total tasks: {total}")
+        print(f"  Workers: {num_workers}")
+        print(f"  Cores per worker: {self.config.cores_per_worker}")
+        print(f"  Total cores used: {self.config.max_cores}")
+        print(f"{'=' * 60}\n")
+
+        # Initialize the progress manager
+        self.progress_manager = ProgressManager(total, desc)
+        self.progress_manager.start()
+
+        results = []
+
+        if num_workers == 1:
+            # Single-process mode
+            for task in tasks:
+                try:
+                    result = worker_func(task)
+                    results.append(result)
+                    self.progress_manager.update(success=True)
+                except Exception as e:
+                    results.append(None)
+                    self.progress_manager.update(success=False)
+                self.progress_manager.display()
+        else:
+            # Multi-process mode
+            with Pool(processes=num_workers) as pool:
+                # imap_unordered gives better throughput
+                for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
+                    if result is not None:
+                        results.append(result)
+                        self.progress_manager.update(success=True)
+                    else:
+                        self.progress_manager.update(success=False)
+                    self.progress_manager.display()
+
+        self.progress_manager.finish()
+
+        return results
+
+    def generate_slurm_script(
+        self,
+        tasks_file: str,
+        worker_script: str,
+        output_dir: str,
+        job_name: str = "analysis"
+    ) -> str:
+        """Generate a SLURM job script"""
+
+        script = f"""#!/bin/bash
+#SBATCH --job-name={job_name}
+#SBATCH --partition={self.config.partition}
+#SBATCH --nodes=1
+#SBATCH --ntasks=1
+#SBATCH --cpus-per-task={self.config.max_cores}
+#SBATCH --mem-per-cpu={self.config.memory_per_core}
+#SBATCH --time={self.config.time_limit}
+#SBATCH --output={output_dir}/slurm_%j.log
+#SBATCH --error={output_dir}/slurm_%j.err
+
+# Environment setup
+source $(conda info --base)/etc/profile.d/conda.sh
+conda activate screen
+
+# Set the Python path
+cd $SLURM_SUBMIT_DIR
+export PYTHONPATH=$(pwd):$PYTHONPATH
+
+# Run the analysis
+python {worker_script} \\
+    --tasks-file {tasks_file} \\
+    --output-dir {output_dir} \\
+    --num-workers {self.config.num_workers}
+
+echo "Job completed at $(date)"
+"""
+        return script
+
+    def submit_slurm_job(self, script_content: str, script_path: str = None) -> str:
+        """Submit a SLURM job"""
+
+        if script_path is None:
+            fd, script_path = tempfile.mkstemp(suffix='.sh')
+            os.close(fd)
+
+        with open(script_path, 'w') as f:
+            f.write(script_content)
+
+        os.chmod(script_path, 0o755)
+
+        result = subprocess.run(
+            ['sbatch', script_path],
+            capture_output=True, text=True
+        )
+
+        if result.returncode == 0:
+            job_id = result.stdout.strip().split()[-1]
+            print(f"Job submitted: {job_id}")
+            return job_id
+        else:
+            raise RuntimeError(f"Submission failed: {result.stderr}")
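
Putting the pieces together, a minimal local run of the scheduler outside the CIF pipeline; the square() worker is a made-up stand-in (real use passes analyze_single_file, as database_analyzer.py does), and it must be a top-level function so Pool can pickle it:

from core.scheduler import ParallelScheduler  # assumes 'src' is on sys.path

def square(x):
    # Stand-in worker; returns a non-None result so run_local counts it as success
    return x * x

if __name__ == "__main__":
    config = ParallelScheduler.recommend_config(num_tasks=100, task_complexity='low')
    scheduler = ParallelScheduler(config)
    results = scheduler.run_local(tasks=list(range(100)), worker_func=square, desc="Squaring")
    print(len(results), "results")
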