screen/src/analysis/worker.py
"""
工作进程:处理单个分析任务
设计为可以独立运行用于SLURM作业数组
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict, fields
from .structure_inspector import StructureInspector, StructureInfo


def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
    """
    Analyze a single CIF file (worker function).

    Args:
        args: (file_path, target_cation, target_anions)

    Returns:
        StructureInfo; if the analysis fails, a StructureInfo flagged
        is_valid=False carrying the error message
    """
    file_path, target_cation, target_anions = args
    try:
        inspector = StructureInspector(
            target_cation=target_cation,
            target_anions=target_anions
        )
        return inspector.inspect(file_path)
    except Exception as e:
        # Return a result flagged as failed
        return StructureInfo(
            file_path=file_path,
            file_name=os.path.basename(file_path),
            is_valid=False,
            error_message=str(e)
        )
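
# The single-tuple signature above makes the function easy to map over a
# process pool. A minimal sketch (illustrative only; the paths, "Li", and the
# anion set are placeholder inputs, not values from this project):
#
#     from multiprocessing import Pool
#
#     tasks = [(p, "Li", {"O", "S"}) for p in ["cifs/a.cif", "cifs/b.cif"]]
#     with Pool(processes=4) as pool:
#         infos = pool.map(analyze_single_file, tasks)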


def structure_info_to_dict(info: StructureInfo) -> dict:
    """
    Convert a StructureInfo into a serializable dict.
    Handles special field types such as set and nested dataclasses.
    """
    result = {}
    for field in fields(info):
        value = getattr(info, field.name)
        # Convert set fields to lists
        if isinstance(value, set):
            result[field.name] = list(value)
        # Handle nested dataclasses (e.g. ExpansionInfo)
        elif hasattr(value, '__dataclass_fields__'):
            result[field.name] = asdict(value)
        # Handle lists that may contain dataclasses
        elif isinstance(value, list):
            result[field.name] = [
                asdict(item) if hasattr(item, '__dataclass_fields__') else item
                for item in value
            ]
        else:
            result[field.name] = value
    return result


def dict_to_structure_info(d: dict) -> StructureInfo:
    """
    Rebuild a StructureInfo object from a dict.
    """
    from .structure_inspector import ExpansionInfo, OccupancyInfo

    # Restore set-typed fields
    if 'elements' in d and isinstance(d['elements'], list):
        d['elements'] = set(d['elements'])
    if 'anion_types' in d and isinstance(d['anion_types'], list):
        d['anion_types'] = set(d['anion_types'])
    if 'target_anions' in d and isinstance(d['target_anions'], list):
        d['target_anions'] = set(d['target_anions'])
    # Restore the nested ExpansionInfo
    if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
        exp_dict = d['expansion_info']
        # Restore the list of OccupancyInfo entries
        if 'occupancy_details' in exp_dict:
            exp_dict['occupancy_details'] = [
                OccupancyInfo(**occ) if isinstance(occ, dict) else occ
                for occ in exp_dict['occupancy_details']
            ]
        d['expansion_info'] = ExpansionInfo(**exp_dict)
    return StructureInfo(**d)
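
# The two helpers above are inverses, so results can be round-tripped through
# plain dicts for pickling. A sketch (assumes `info` is a StructureInfo
# returned by analyze_single_file):
#
#     as_dict = structure_info_to_dict(info)
#     restored = dict_to_structure_info(as_dict)
#     assert restored.file_name == info.file_name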


def batch_analyze(
    file_paths: List[str],
    target_cation: str,
    target_anions: set,
    output_file: Optional[str] = None
) -> List[StructureInfo]:
    """
    Analyze a batch of files (used by SLURM sub-tasks).

    Args:
        file_paths: list of CIF file paths
        target_cation: target cation
        target_anions: set of target anions
        output_file: output file path (pickle format)

    Returns:
        list of StructureInfo
    """
    results = []
    inspector = StructureInspector(
        target_cation=target_cation,
        target_anions=target_anions
    )
    for file_path in file_paths:
        try:
            info = inspector.inspect(file_path)
            results.append(info)
        except Exception as e:
            results.append(StructureInfo(
                file_path=file_path,
                file_name=os.path.basename(file_path),
                is_valid=False,
                error_message=str(e)
            ))
    # Save the results
    if output_file:
        serializable_results = [structure_info_to_dict(r) for r in results]
        with open(output_file, 'wb') as f:
            pickle.dump(serializable_results, f)
    return results


def load_results(result_file: str) -> List[StructureInfo]:
    """
    Load results from a pickle file.
    """
    with open(result_file, 'rb') as f:
        data = pickle.load(f)
    return [dict_to_structure_info(d) for d in data]


def merge_results(result_files: List[str]) -> List[StructureInfo]:
    """
    Merge multiple result files (used to aggregate the output of a SLURM job array).
    """
    all_results = []
    for f in result_files:
        if os.path.exists(f):
            all_results.extend(load_results(f))
    return all_results


# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="CIF Analysis Worker")
    parser.add_argument("--tasks-file", required=True, help="Path to the tasks file (JSON)")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--task-id", type=int, default=0, help="Task ID (for array jobs)")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of parallel workers")
    args = parser.parse_args()

    # Load the task configuration
    with open(args.tasks_file, 'r') as f:
        task_config = json.load(f)
    file_paths = task_config['files']
    target_cation = task_config['target_cation']
    target_anions = set(task_config['target_anions'])

    # For array jobs, process only the assigned slice
    if args.task_id >= 0:
        chunk_size = len(file_paths) // args.num_workers + 1
        start_idx = args.task_id * chunk_size
        end_idx = min(start_idx + chunk_size, len(file_paths))
        file_paths = file_paths[start_idx:end_idx]

    # Output file
    output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")

    # Run the analysis
    print(f"Worker {args.task_id}: processing {len(file_paths)} files")
    results = batch_analyze(file_paths, target_cation, target_anions, output_file)
    print(f"Worker {args.task_id}: done, results saved to {output_file}")