增加扩胞逻辑
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
import os
|
||||
import pickle
|
||||
from typing import List, Tuple, Optional
|
||||
from dataclasses import asdict
|
||||
from dataclasses import asdict, fields
|
||||
|
||||
from .structure_inspector import StructureInspector, StructureInfo
|
||||
|
||||
@@ -38,11 +38,68 @@ def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
|
||||
)
|
||||
|
||||
|
||||
def structure_info_to_dict(info: StructureInfo) -> dict:
|
||||
"""
|
||||
将 StructureInfo 转换为可序列化的字典
|
||||
处理 set、dataclass 等特殊类型
|
||||
"""
|
||||
result = {}
|
||||
for field in fields(info):
|
||||
value = getattr(info, field.name)
|
||||
|
||||
# 处理 set 类型
|
||||
if isinstance(value, set):
|
||||
result[field.name] = list(value)
|
||||
# 处理嵌套的 dataclass (如 ExpansionInfo)
|
||||
elif hasattr(value, '__dataclass_fields__'):
|
||||
result[field.name] = asdict(value)
|
||||
# 处理 list 中可能包含的 dataclass
|
||||
elif isinstance(value, list):
|
||||
result[field.name] = [
|
||||
asdict(item) if hasattr(item, '__dataclass_fields__') else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[field.name] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def dict_to_structure_info(d: dict) -> StructureInfo:
|
||||
"""
|
||||
从字典恢复 StructureInfo 对象
|
||||
"""
|
||||
from .structure_inspector import ExpansionInfo, OccupancyInfo
|
||||
|
||||
# 处理 set 类型字段
|
||||
if 'elements' in d and isinstance(d['elements'], list):
|
||||
d['elements'] = set(d['elements'])
|
||||
if 'anion_types' in d and isinstance(d['anion_types'], list):
|
||||
d['anion_types'] = set(d['anion_types'])
|
||||
if 'target_anions' in d and isinstance(d['target_anions'], list):
|
||||
d['target_anions'] = set(d['target_anions'])
|
||||
|
||||
# 处理 ExpansionInfo
|
||||
if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
|
||||
exp_dict = d['expansion_info']
|
||||
|
||||
# 处理 OccupancyInfo 列表
|
||||
if 'occupancy_details' in exp_dict:
|
||||
exp_dict['occupancy_details'] = [
|
||||
OccupancyInfo(**occ) if isinstance(occ, dict) else occ
|
||||
for occ in exp_dict['occupancy_details']
|
||||
]
|
||||
|
||||
d['expansion_info'] = ExpansionInfo(**exp_dict)
|
||||
|
||||
return StructureInfo(**d)
|
||||
|
||||
|
||||
def batch_analyze(
|
||||
file_paths: List[str],
|
||||
target_cation: str,
|
||||
target_anions: set,
|
||||
output_file: str = None
|
||||
file_paths: List[str],
|
||||
target_cation: str,
|
||||
target_anions: set,
|
||||
output_file: str = None
|
||||
) -> List[StructureInfo]:
|
||||
"""
|
||||
批量分析文件(用于SLURM子任务)
|
||||
@@ -77,12 +134,34 @@ def batch_analyze(
|
||||
|
||||
# 保存结果
|
||||
if output_file:
|
||||
serializable_results = [structure_info_to_dict(r) for r in results]
|
||||
with open(output_file, 'wb') as f:
|
||||
pickle.dump([asdict(r) for r in results], f)
|
||||
pickle.dump(serializable_results, f)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_results(result_file: str) -> List[StructureInfo]:
|
||||
"""
|
||||
从pickle文件加载结果
|
||||
"""
|
||||
with open(result_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
return [dict_to_structure_info(d) for d in data]
|
||||
|
||||
|
||||
def merge_results(result_files: List[str]) -> List[StructureInfo]:
|
||||
"""
|
||||
合并多个结果文件(用于汇总SLURM作业数组的输出)
|
||||
"""
|
||||
all_results = []
|
||||
for f in result_files:
|
||||
if os.path.exists(f):
|
||||
all_results.extend(load_results(f))
|
||||
return all_results
|
||||
|
||||
|
||||
# 用于SLURM作业数组的命令行入口
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
Reference in New Issue
Block a user