Add parallel computation to preprocessing

2025-12-14 15:42:13 +08:00
parent c91998662a
commit ae4e7280b4
5 changed files with 720 additions and 147 deletions

src/analysis/worker.py (new file, 120 lines added)

@@ -0,0 +1,120 @@
"""
工作进程:处理单个分析任务
设计为可以独立运行用于SLURM作业数组
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict
from .structure_inspector import StructureInspector, StructureInfo
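
# Takes a single tuple argument so that each call is one picklable unit of
# work (e.g. it can be mapped over by a process pool).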
def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
    """
    Analyze a single CIF file (worker function).
    Args:
        args: (file_path, target_cation, target_anions)
    Returns:
        StructureInfo, or None if the analysis fails
    """
file_path, target_cation, target_anions = args
try:
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
return inspector.inspect(file_path)
except Exception as e:
        # Return a result that marks the failure
return StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=str(e)
)

def batch_analyze(
    file_paths: List[str],
    target_cation: str,
    target_anions: set,
    output_file: Optional[str] = None
) -> List[StructureInfo]:
    """
    Batch-analyze files (used by SLURM sub-tasks).
    Args:
        file_paths: list of CIF file paths
        target_cation: target cation
        target_anions: set of target anions
        output_file: output file path (pickle format)
    Returns:
        list of StructureInfo
    """
results = []
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
for file_path in file_paths:
try:
info = inspector.inspect(file_path)
results.append(info)
except Exception as e:
results.append(StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=str(e)
))
    # Save the results
if output_file:
with open(output_file, 'wb') as f:
pickle.dump([asdict(r) for r in results], f)
return results

# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
import argparse
import json
parser = argparse.ArgumentParser(description="CIF Analysis Worker")
parser.add_argument("--tasks-file", required=True, help="任务文件路径(JSON)")
parser.add_argument("--output-dir", required=True, help="输出目录")
parser.add_argument("--task-id", type=int, default=0, help="任务ID(用于数组作业)")
parser.add_argument("--num-workers", type=int, default=1, help="并行worker数")
args = parser.parse_args()
    # Load the task configuration
with open(args.tasks_file, 'r') as f:
task_config = json.load(f)
file_paths = task_config['files']
target_cation = task_config['target_cation']
target_anions = set(task_config['target_anions'])
    # For an array job, process only the slice assigned to this task
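    # e.g. 10 files, 4 workers: chunk_size = 10 // 4 + 1 = 3, so task 0 gets
    # files[0:3], task 1 files[3:6], task 2 files[6:9], task 3 files[9:10]
    # (when the count divides evenly, the last task may get an empty slice)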
if args.task_id >= 0:
chunk_size = len(file_paths) // args.num_workers + 1
start_idx = args.task_id * chunk_size
end_idx = min(start_idx + chunk_size, len(file_paths))
file_paths = file_paths[start_idx:end_idx]
    # Output file
output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")
    # Run the analysis
print(f"Worker {args.task_id}: 处理 {len(file_paths)} 个文件")
results = batch_analyze(file_paths, target_cation, target_anions, output_file)
print(f"Worker {args.task_id}: 完成,结果保存到 {output_file}")