""" 工作进程:处理单个分析任务 设计为可以独立运行(用于SLURM作业数组) """ import os import pickle from typing import List, Tuple, Optional from dataclasses import asdict, fields from .structure_inspector import StructureInspector, StructureInfo def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]: """ 分析单个CIF文件(Worker函数) Args: args: (file_path, target_cation, target_anions) Returns: StructureInfo 或 None(如果分析失败) """ file_path, target_cation, target_anions = args try: inspector = StructureInspector( target_cation=target_cation, target_anions=target_anions ) return inspector.inspect(file_path) except Exception as e: # 返回一个标记失败的结果 return StructureInfo( file_path=file_path, file_name=os.path.basename(file_path), is_valid=False, error_message=str(e) ) def structure_info_to_dict(info: StructureInfo) -> dict: """ 将 StructureInfo 转换为可序列化的字典 处理 set、dataclass 等特殊类型 """ result = {} for field in fields(info): value = getattr(info, field.name) # 处理 set 类型 if isinstance(value, set): result[field.name] = list(value) # 处理嵌套的 dataclass (如 ExpansionInfo) elif hasattr(value, '__dataclass_fields__'): result[field.name] = asdict(value) # 处理 list 中可能包含的 dataclass elif isinstance(value, list): result[field.name] = [ asdict(item) if hasattr(item, '__dataclass_fields__') else item for item in value ] else: result[field.name] = value return result def dict_to_structure_info(d: dict) -> StructureInfo: """ 从字典恢复 StructureInfo 对象 """ from .structure_inspector import ExpansionInfo, OccupancyInfo # 处理 set 类型字段 if 'elements' in d and isinstance(d['elements'], list): d['elements'] = set(d['elements']) if 'anion_types' in d and isinstance(d['anion_types'], list): d['anion_types'] = set(d['anion_types']) if 'target_anions' in d and isinstance(d['target_anions'], list): d['target_anions'] = set(d['target_anions']) # 处理 ExpansionInfo if 'expansion_info' in d and isinstance(d['expansion_info'], dict): exp_dict = d['expansion_info'] # 处理 OccupancyInfo 列表 if 'occupancy_details' in exp_dict: exp_dict['occupancy_details'] = [ OccupancyInfo(**occ) if isinstance(occ, dict) else occ for occ in exp_dict['occupancy_details'] ] d['expansion_info'] = ExpansionInfo(**exp_dict) return StructureInfo(**d) def batch_analyze( file_paths: List[str], target_cation: str, target_anions: set, output_file: str = None ) -> List[StructureInfo]: """ 批量分析文件(用于SLURM子任务) Args: file_paths: CIF文件路径列表 target_cation: 目标阳离子 target_anions: 目标阴离子集合 output_file: 输出文件路径(pickle格式) Returns: StructureInfo列表 """ results = [] inspector = StructureInspector( target_cation=target_cation, target_anions=target_anions ) for file_path in file_paths: try: info = inspector.inspect(file_path) results.append(info) except Exception as e: results.append(StructureInfo( file_path=file_path, file_name=os.path.basename(file_path), is_valid=False, error_message=str(e) )) # 保存结果 if output_file: serializable_results = [structure_info_to_dict(r) for r in results] with open(output_file, 'wb') as f: pickle.dump(serializable_results, f) return results def load_results(result_file: str) -> List[StructureInfo]: """ 从pickle文件加载结果 """ with open(result_file, 'rb') as f: data = pickle.load(f) return [dict_to_structure_info(d) for d in data] def merge_results(result_files: List[str]) -> List[StructureInfo]: """ 合并多个结果文件(用于汇总SLURM作业数组的输出) """ all_results = [] for f in result_files: if os.path.exists(f): all_results.extend(load_results(f)) return all_results # 用于SLURM作业数组的命令行入口 if __name__ == "__main__": import argparse import json parser = argparse.ArgumentParser(description="CIF Analysis Worker") 
# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="CIF Analysis Worker")
    parser.add_argument("--tasks-file", required=True, help="Path to the task file (JSON)")
    parser.add_argument("--output-dir", required=True, help="Output directory")
    parser.add_argument("--task-id", type=int, default=0, help="Task ID (for array jobs)")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of parallel workers")
    args = parser.parse_args()

    # Load the task configuration
    with open(args.tasks_file, 'r') as f:
        task_config = json.load(f)

    file_paths = task_config['files']
    target_cation = task_config['target_cation']
    target_anions = set(task_config['target_anions'])

    # For array jobs, process only the chunk assigned to this task
    # (with the defaults task_id=0 and num_workers=1 this is the full list)
    if args.task_id >= 0:
        chunk_size = len(file_paths) // args.num_workers + 1
        start_idx = args.task_id * chunk_size
        end_idx = min(start_idx + chunk_size, len(file_paths))
        file_paths = file_paths[start_idx:end_idx]

    # Output file
    output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")

    # Run the analysis
    print(f"Worker {args.task_id}: processing {len(file_paths)} files")
    results = batch_analyze(file_paths, target_cation, target_anions, output_file)
    print(f"Worker {args.task_id}: done, results saved to {output_file}")
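
# ---------------------------------------------------------------------------
# Sketch of a matching SLURM array submission script. This is an assumption
# about how the entry point above is typically invoked; the module path
# "mypackage.worker", the array size and the resource limits are illustrative
# placeholders, not values defined in this repository.
#
#     #!/bin/bash
#     #SBATCH --job-name=cif-analysis
#     #SBATCH --array=0-9
#     #SBATCH --cpus-per-task=1
#     #SBATCH --time=01:00:00
#
#     python -m mypackage.worker \
#         --tasks-file tasks.json \
#         --output-dir results/ \
#         --task-id "${SLURM_ARRAY_TASK_ID}" \
#         --num-workers 10
#
# The tasks file is expected to contain
#     {"files": [...], "target_cation": "...", "target_anions": [...]}
# matching the keys read in the __main__ block above.
# ---------------------------------------------------------------------------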