增加扩胞逻辑

This commit is contained in:
2025-12-14 17:57:42 +08:00
parent 2378a3f2a2
commit 9b36aa10ff
4 changed files with 310 additions and 153 deletions

View File

@@ -5,7 +5,7 @@
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict
from dataclasses import asdict, fields
from .structure_inspector import StructureInspector, StructureInfo
@@ -38,11 +38,68 @@ def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
)
def structure_info_to_dict(info: StructureInfo) -> dict:
    """
    Convert a StructureInfo into a plain, serializable dict.

    Handles the field types that do not round-trip well through
    pickle/JSON as-is: set values become lists, nested dataclasses
    (e.g. ExpansionInfo) become dicts via asdict, and dataclass items
    inside a list are converted element-wise.
    """
    serialized = {}
    for f in fields(info):
        val = getattr(info, f.name)
        if isinstance(val, set):
            # Store sets as lists for portability.
            converted = list(val)
        elif hasattr(val, '__dataclass_fields__'):
            # Nested dataclass field (e.g. ExpansionInfo).
            converted = asdict(val)
        elif isinstance(val, list):
            # A list may hold dataclass items; convert each one.
            converted = [
                asdict(item) if hasattr(item, '__dataclass_fields__') else item
                for item in val
            ]
        else:
            converted = val
        serialized[f.name] = converted
    return serialized
def dict_to_structure_info(d: dict) -> StructureInfo:
    """
    Rebuild a StructureInfo object from a serialized dict.

    Inverse of structure_info_to_dict: list-encoded set fields are
    restored to sets, and the nested ExpansionInfo / OccupancyInfo
    dataclasses are reconstructed from their dict form.

    The input dict is NOT modified: shallow copies are taken before any
    key is rewritten, so callers can safely reuse the loaded data
    (the original version mutated ``d`` and its nested
    ``expansion_info`` dict in place).
    """
    from .structure_inspector import ExpansionInfo, OccupancyInfo

    # Work on a copy so the caller's dict is left untouched.
    d = dict(d)

    # Restore set-typed fields that were serialized as lists.
    for set_field in ('elements', 'anion_types', 'target_anions'):
        if set_field in d and isinstance(d[set_field], list):
            d[set_field] = set(d[set_field])

    # Rebuild the nested ExpansionInfo dataclass.
    if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
        exp_dict = dict(d['expansion_info'])  # copy: don't mutate nested dict
        # occupancy_details entries are serialized OccupancyInfo dicts.
        if 'occupancy_details' in exp_dict:
            exp_dict['occupancy_details'] = [
                OccupancyInfo(**occ) if isinstance(occ, dict) else occ
                for occ in exp_dict['occupancy_details']
            ]
        d['expansion_info'] = ExpansionInfo(**exp_dict)

    return StructureInfo(**d)
def batch_analyze(
file_paths: List[str],
target_cation: str,
target_anions: set,
output_file: str = None
file_paths: List[str],
target_cation: str,
target_anions: set,
output_file: str = None
) -> List[StructureInfo]:
"""
批量分析文件用于SLURM子任务
@@ -77,12 +134,34 @@ def batch_analyze(
# 保存结果
if output_file:
serializable_results = [structure_info_to_dict(r) for r in results]
with open(output_file, 'wb') as f:
pickle.dump([asdict(r) for r in results], f)
pickle.dump(serializable_results, f)
return results
def load_results(result_file: str) -> List[StructureInfo]:
    """
    Load analysis results from a pickle file.

    Each pickled entry is a serialized dict produced by
    structure_info_to_dict; entries are rebuilt into StructureInfo
    objects via dict_to_structure_info.
    """
    with open(result_file, 'rb') as fh:
        serialized = pickle.load(fh)
    return [dict_to_structure_info(entry) for entry in serialized]
def merge_results(result_files: List[str]) -> List[StructureInfo]:
    """
    Merge multiple result files (used to aggregate the output of a
    SLURM job array).

    Missing files are silently skipped.
    """
    merged = []
    for path in result_files:
        if not os.path.exists(path):
            continue
        merged.extend(load_results(path))
    return merged
# 用于SLURM作业数组的命令行入口
if __name__ == "__main__":
import argparse