"""
|
||
工作进程:处理单个分析任务
|
||
设计为可以独立运行(用于SLURM作业数组)
|
||
"""
|
||
import os
|
||
import pickle
|
||
from typing import List, Tuple, Optional
|
||
from dataclasses import asdict, fields
|
||
|
||
from .structure_inspector import StructureInspector, StructureInfo
|
||
|
||
|
||


def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
    """
    Analyze a single CIF file (worker function).

    Args:
        args: (file_path, target_cation, target_anions)

    Returns:
        The inspector's StructureInfo (possibly None); if an exception is
        raised, a StructureInfo with is_valid=False and the error message
    """
    file_path, target_cation, target_anions = args

    try:
        inspector = StructureInspector(
            target_cation=target_cation,
            target_anions=target_anions
        )
        return inspector.inspect(file_path)
    except Exception as e:
        # Return a result that records the failure instead of raising,
        # so one bad file does not kill a whole batch
        return StructureInfo(
            file_path=file_path,
            file_name=os.path.basename(file_path),
            is_valid=False,
            error_message=str(e)
        )
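

# Example (sketch): analyze_single_file takes one tuple argument so it can be
# mapped directly with multiprocessing.Pool. The paths and the chemistry
# ("Li", {"O", "S"}) below are hypothetical placeholders.
#
#     from multiprocessing import Pool
#
#     tasks = [(path, "Li", {"O", "S"}) for path in ["a.cif", "b.cif"]]
#     with Pool(processes=4) as pool:
#         infos = pool.map(analyze_single_file, tasks)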


def structure_info_to_dict(info: StructureInfo) -> dict:
    """
    Convert a StructureInfo into a serializable dict.

    Handles special field types such as sets and nested dataclasses.
    """
    result = {}
    for field in fields(info):
        value = getattr(info, field.name)

        # Handle set-typed fields
        if isinstance(value, set):
            result[field.name] = list(value)
        # Handle nested dataclasses (e.g. ExpansionInfo)
        elif hasattr(value, '__dataclass_fields__'):
            result[field.name] = asdict(value)
        # Handle lists that may contain dataclasses
        elif isinstance(value, list):
            result[field.name] = [
                asdict(item) if hasattr(item, '__dataclass_fields__') else item
                for item in value
            ]
        else:
            result[field.name] = value

    return result


def dict_to_structure_info(d: dict) -> StructureInfo:
    """
    Rebuild a StructureInfo object from its dict form.
    """
    from .structure_inspector import ExpansionInfo, OccupancyInfo

    # Work on a shallow copy so the caller's dict is not mutated
    d = dict(d)

    # Restore set-typed fields
    if 'elements' in d and isinstance(d['elements'], list):
        d['elements'] = set(d['elements'])
    if 'anion_types' in d and isinstance(d['anion_types'], list):
        d['anion_types'] = set(d['anion_types'])
    if 'target_anions' in d and isinstance(d['target_anions'], list):
        d['target_anions'] = set(d['target_anions'])

    # Restore the nested ExpansionInfo
    if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
        exp_dict = dict(d['expansion_info'])

        # Restore the list of OccupancyInfo entries
        if 'occupancy_details' in exp_dict:
            exp_dict['occupancy_details'] = [
                OccupancyInfo(**occ) if isinstance(occ, dict) else occ
                for occ in exp_dict['occupancy_details']
            ]

        d['expansion_info'] = ExpansionInfo(**exp_dict)

    return StructureInfo(**d)
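

# Example (sketch): round-trip a failure record through the dict form. The
# field values are made up; only the four constructor fields already used in
# this module are assumed.
#
#     info = StructureInfo(file_path="/tmp/x.cif", file_name="x.cif",
#                          is_valid=False, error_message="parse error")
#     restored = dict_to_structure_info(structure_info_to_dict(info))
#     assert restored == info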


def batch_analyze(
    file_paths: List[str],
    target_cation: str,
    target_anions: set,
    output_file: Optional[str] = None
) -> List[StructureInfo]:
    """
    Analyze a batch of files (for SLURM sub-tasks).

    Args:
        file_paths: list of CIF file paths
        target_cation: target cation
        target_anions: set of target anions
        output_file: optional output file path (pickle format)

    Returns:
        List of StructureInfo results
    """
    results = []

    inspector = StructureInspector(
        target_cation=target_cation,
        target_anions=target_anions
    )

    for file_path in file_paths:
        try:
            info = inspector.inspect(file_path)
            results.append(info)
        except Exception as e:
            results.append(StructureInfo(
                file_path=file_path,
                file_name=os.path.basename(file_path),
                is_valid=False,
                error_message=str(e)
            ))

    # Save the results as a list of plain dicts
    if output_file:
        serializable_results = [structure_info_to_dict(r) for r in results]
        with open(output_file, 'wb') as f:
            pickle.dump(serializable_results, f)

    return results
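

# Example (sketch): run a small batch in-process and keep only valid results.
# The paths and chemistry ("Na", {"O"}) are hypothetical.
#
#     results = batch_analyze(
#         ["structs/0001.cif", "structs/0002.cif"],
#         target_cation="Na",
#         target_anions={"O"},
#         output_file="results_0.pkl",
#     )
#     valid = [r for r in results if r.is_valid]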


def load_results(result_file: str) -> List[StructureInfo]:
    """
    Load results from a pickle file written by batch_analyze.
    """
    with open(result_file, 'rb') as f:
        data = pickle.load(f)

    return [dict_to_structure_info(d) for d in data]


def merge_results(result_files: List[str]) -> List[StructureInfo]:
    """
    Merge multiple result files (to aggregate SLURM job-array output).

    Missing files are skipped, so a failed array task does not block the merge.
    """
    all_results = []
    for f in result_files:
        if os.path.exists(f):
            all_results.extend(load_results(f))
    return all_results
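

# Example (sketch): gather every per-task pickle produced by a job array.
# The glob pattern matches the results_<task_id>.pkl naming used below.
#
#     import glob
#     all_infos = merge_results(sorted(glob.glob("out/results_*.pkl")))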


# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="CIF Analysis Worker")
    parser.add_argument("--tasks-file", required=True, help="Path to the tasks file (JSON)")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--task-id", type=int, default=0, help="Task ID (for array jobs)")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of parallel workers (array tasks the file list is split across)")

    args = parser.parse_args()

    # Load the task configuration
    with open(args.tasks_file, 'r') as f:
        task_config = json.load(f)

    file_paths = task_config['files']
    target_cation = task_config['target_cation']
    target_anions = set(task_config['target_anions'])

    # For an array job, process only this task's slice of the file list
    if args.task_id >= 0:
        # Ceiling division so the chunks cover every file exactly once
        chunk_size = -(-len(file_paths) // args.num_workers)
        start_idx = args.task_id * chunk_size
        end_idx = min(start_idx + chunk_size, len(file_paths))
        file_paths = file_paths[start_idx:end_idx]

    # Output file for this task
    output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")

    # Run the analysis
    print(f"Worker {args.task_id}: processing {len(file_paths)} files")
    results = batch_analyze(file_paths, target_cation, target_anions, output_file)
    print(f"Worker {args.task_id}: done, results saved to {output_file}")