Add parallel computation to preprocessing
120
src/analysis/worker.py
Normal file
@@ -0,0 +1,120 @@
"""
Worker process: handles a single analysis task.

Designed to run standalone (for SLURM job arrays).
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict

from .structure_inspector import StructureInspector, StructureInfo


def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
    """
    Analyze a single CIF file (worker function).

    Args:
        args: (file_path, target_cation, target_anions)

    Returns:
        StructureInfo, or None; analysis errors are captured as a
        StructureInfo with is_valid=False rather than raised
    """
    file_path, target_cation, target_anions = args

    try:
        inspector = StructureInspector(
            target_cation=target_cation,
            target_anions=target_anions
        )
        return inspector.inspect(file_path)
    except Exception as e:
        # Return a result that marks the failure
        return StructureInfo(
            file_path=file_path,
            file_name=os.path.basename(file_path),
            is_valid=False,
            error_message=str(e)
        )
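# Usage sketch (not part of this commit): the single-tuple signature of
# analyze_single_file matches what multiprocessing.Pool.map expects, so
# local parallelism could look like this (the pool size is an arbitrary choice):
#
#   from multiprocessing import Pool
#   tasks = [(path, target_cation, target_anions) for path in cif_paths]
#   with Pool(processes=8) as pool:
#       results = pool.map(analyze_single_file, tasks)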
def batch_analyze(
    file_paths: List[str],
    target_cation: str,
    target_anions: set,
    output_file: Optional[str] = None
) -> List[StructureInfo]:
    """
    Analyze files in batch (for SLURM sub-tasks).

    Args:
        file_paths: list of CIF file paths
        target_cation: target cation
        target_anions: set of target anions
        output_file: output file path (pickle format)

    Returns:
        list of StructureInfo
    """
    results = []

    inspector = StructureInspector(
        target_cation=target_cation,
        target_anions=target_anions
    )

    for file_path in file_paths:
        try:
            info = inspector.inspect(file_path)
            results.append(info)
        except Exception as e:
            results.append(StructureInfo(
                file_path=file_path,
                file_name=os.path.basename(file_path),
                is_valid=False,
                error_message=str(e)
            ))

    # Save results
    if output_file:
        with open(output_file, 'wb') as f:
            pickle.dump([asdict(r) for r in results], f)

    return results
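# Usage sketch (assumption): since results are pickled as plain dicts via
# asdict, a driver could merge the per-task outputs along these lines, with
# the results_*.pkl name pattern taken from the entry point below:
#
#   import glob
#   import pickle
#   merged = []
#   for path in sorted(glob.glob("output_dir/results_*.pkl")):
#       with open(path, "rb") as fh:
#           merged.extend(pickle.load(fh))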
# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="CIF Analysis Worker")
    parser.add_argument("--tasks-file", required=True, help="Path to the tasks file (JSON)")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--task-id", type=int, default=0, help="Task ID (for array jobs)")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of parallel workers")

    args = parser.parse_args()

    # Load the task list
    with open(args.tasks_file, 'r') as f:
        task_config = json.load(f)

    file_paths = task_config['files']
    target_cation = task_config['target_cation']
    target_anions = set(task_config['target_anions'])
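    # Expected tasks-file shape, inferred from the keys read above
    # (the element symbols are placeholder values):
    #   {
    #     "files": ["structures/abc.cif", "structures/def.cif"],
    #     "target_cation": "Li",
    #     "target_anions": ["O", "S"]
    #   }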

    # For an array job, process only the assigned slice
    # (a negative task-id skips the slicing and processes everything)
    if args.task_id >= 0:
        chunk_size = len(file_paths) // args.num_workers + 1
        start_idx = args.task_id * chunk_size
        end_idx = min(start_idx + chunk_size, len(file_paths))
        file_paths = file_paths[start_idx:end_idx]
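    # Worked example (illustration only): with 100 files and 8 workers,
    # chunk_size = 100 // 8 + 1 = 13, so task 7 gets files[91:100].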
    # Output file
    output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")

    # Run the analysis
    print(f"Worker {args.task_id}: processing {len(file_paths)} files")
    results = batch_analyze(file_paths, target_cation, target_anions, output_file)
    print(f"Worker {args.task_id}: done, results saved to {output_file}")
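# Example SLURM array invocation (sketch; paths and array bounds are
# placeholders). The relative import at the top means the module should be
# run with -m from a directory where the `analysis` package is importable:
#
#   python -m analysis.worker --tasks-file tasks.json --output-dir results \
#       --task-id ${SLURM_ARRAY_TASK_ID} --num-workers ${SLURM_ARRAY_TASK_COUNT}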