一阶段高筛制作完成

This commit is contained in:
koko
2025-12-16 11:36:49 +08:00
parent 6ea96c81d6
commit f78298e803
22 changed files with 3276 additions and 223 deletions

View File

@@ -0,0 +1,12 @@
"""
高通量筛选与扩胞项目 - 源代码包
"""
from . import analysis
from . import core
from . import preprocessing
from . import computation
from . import utils
__version__ = "2.2.0"
__all__ = ['analysis', 'core', 'preprocessing', 'computation', 'utils']

View File

@@ -71,6 +71,151 @@ class DatabaseReport:
})
expansion_factor_distribution: Dict[int, int] = field(default_factory=dict)
def to_dict(self) -> dict:
"""转换为可序列化的字典"""
from dataclasses import asdict, fields as dataclass_fields
def convert_value(val):
"""递归转换值为可序列化类型"""
if isinstance(val, set):
return list(val)
elif isinstance(val, dict):
return {k: convert_value(v) for k, v in val.items()}
elif isinstance(val, list):
return [convert_value(item) for item in val]
elif hasattr(val, '__dataclass_fields__'):
# 处理 dataclass 对象
return {k: convert_value(v) for k, v in asdict(val).items()}
else:
return val
result = {}
for f in dataclass_fields(self):
value = getattr(self, f.name)
result[f.name] = convert_value(value)
return result
def save(self, path: str):
"""保存报告到JSON文件"""
with open(path, 'w', encoding='utf-8') as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
print(f"✅ 报告已保存到: {path}")
@classmethod
def load(cls, path: str) -> 'DatabaseReport':
"""从JSON文件加载报告"""
with open(path, 'r', encoding='utf-8') as f:
d = json.load(f)
# 处理 set 类型
if 'target_anions' in d:
d['target_anions'] = set(d['target_anions'])
# 处理 StructureInfo 列表(简化处理,不恢复完整对象)
if 'all_structures' in d:
d['all_structures'] = []
return cls(**d)
def get_processable_files(self, include_needs_expansion: bool = True) -> List[StructureInfo]:
"""
获取可处理的文件列表
Args:
include_needs_expansion: 是否包含需要扩胞的文件
Returns:
可处理的 StructureInfo 列表
"""
result = []
for info in self.all_structures:
if info is None or not info.is_valid:
continue
if not info.contains_target_cation:
continue
if not info.can_process:
continue
if not include_needs_expansion and info.needs_expansion:
continue
result.append(info)
return result
def copy_processable_files(
self,
output_dir: str,
include_needs_expansion: bool = True,
organize_by_anion: bool = True
) -> Dict[str, int]:
"""
将可处理的CIF文件复制到工作区
Args:
output_dir: 输出目录(如 workspace/data)
include_needs_expansion: 是否包含需要扩胞的文件
organize_by_anion: 是否按阴离子类型组织子目录
Returns:
复制统计信息 {类别: 数量}
"""
import shutil
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 获取可处理文件
processable = self.get_processable_files(include_needs_expansion)
stats = {
'direct': 0, # 可直接处理
'needs_expansion': 0, # 需要扩胞
'total': 0
}
# 按类型创建子目录
if organize_by_anion:
anion_dirs = {}
for info in processable:
# 确定目标目录
if organize_by_anion and info.anion_types:
# 使用主要阴离子作为目录名
anion_key = '+'.join(sorted(info.anion_types))
if anion_key not in anion_dirs:
anion_dir = os.path.join(output_dir, anion_key)
os.makedirs(anion_dir, exist_ok=True)
anion_dirs[anion_key] = anion_dir
target_dir = anion_dirs[anion_key]
else:
target_dir = output_dir
# 进一步按处理类型分类
if info.needs_expansion:
sub_dir = os.path.join(target_dir, 'needs_expansion')
stats['needs_expansion'] += 1
else:
sub_dir = os.path.join(target_dir, 'direct')
stats['direct'] += 1
os.makedirs(sub_dir, exist_ok=True)
# 复制文件
src_path = info.file_path
dst_path = os.path.join(sub_dir, info.file_name)
try:
shutil.copy2(src_path, dst_path)
stats['total'] += 1
except Exception as e:
print(f"⚠️ 复制失败 {info.file_name}: {e}")
# 打印统计
print(f"\n📁 文件已复制到: {output_dir}")
print(f" 可直接处理: {stats['direct']}")
print(f" 需要扩胞: {stats['needs_expansion']}")
print(f" 总计: {stats['total']}")
return stats
class DatabaseAnalyzer:
"""数据库分析器 - 支持高性能并行"""
@@ -232,6 +377,10 @@ class DatabaseAnalyzer:
report.invalid_files += 1
continue # 无效文件不继续统计
# 关键修复:只有当结构确实含有目标阳离子时才计入统计
if not info.contains_target_cation:
continue # 不含目标阳离子的文件不继续统计
report.cation_containing_count += 1
for anion in info.anion_types:
@@ -317,45 +466,3 @@ class DatabaseAnalyzer:
for anion, count in report.anion_distribution.items():
report.anion_ratios[anion] = \
count / report.cation_containing_count
def to_dict(self) -> dict:
"""转换为可序列化的字典"""
import json
from dataclasses import asdict, fields
result = {}
for field in fields(self):
value = getattr(self, field.name)
# 处理 set 类型
if isinstance(value, set):
result[field.name] = list(value)
# 处理 StructureInfo 列表
elif field.name == 'all_structures':
result[field.name] = [] # 不保存详细结构信息,太大
else:
result[field.name] = value
return result
def save(self, path: str):
"""保存报告到JSON文件"""
import json
with open(path, 'w', encoding='utf-8') as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
print(f"✅ 报告已保存到: {path}")
@classmethod
def load(cls, path: str) -> 'DatabaseReport':
"""从JSON文件加载报告"""
import json
with open(path, 'r', encoding='utf-8') as f:
d = json.load(f)
# 处理 set 类型
if 'target_anions' in d:
d['target_anions'] = set(d['target_anions'])
if 'all_structures' not in d:
d['all_structures'] = []
return cls(**d)
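
A minimal usage sketch of the helpers added above; it assumes a populated DatabaseReport instance (produced elsewhere by DatabaseAnalyzer) and uses placeholder paths:

# Sketch only: `report` is assumed to come from a prior DatabaseAnalyzer run.
report.save("workspace/report.json")                      # serialized via to_dict()
restored = DatabaseReport.load("workspace/report.json")   # note: all_structures is not restored
candidates = report.get_processable_files(include_needs_expansion=True)
print(f"{len(candidates)} structures are processable")
stats = report.copy_processable_files("workspace/data", organize_by_anion=True)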

View File

@@ -145,7 +145,7 @@ class ReportGenerator:
info.anion_mode,
info.has_oxidation_states,
info.has_partial_occupancy,
info.cation_has_partial_occupancy,
info.cation_with_other_cation, # 修复:使用正确的属性名
info.anion_has_partial_occupancy,
info.needs_expansion,
info.is_binary_compound,
@@ -156,4 +156,4 @@ class ReportGenerator:
]
writer.writerow(row)
print(f"详细结果已导出到: {output_path}")
print(f"详细结果已导出到: {output_path}")

View File

@@ -441,43 +441,3 @@ class StructureInspector:
return False
except:
return False
def _evaluate_processability(self, info: StructureInfo):
"""评估可处理性"""
skip_reasons = []
if not info.is_valid:
skip_reasons.append("无法解析CIF文件")
if not info.contains_target_cation:
skip_reasons.append(f"不含{self.target_cation}")
if info.anion_mode == "none":
skip_reasons.append("不含目标阴离子")
if info.is_binary_compound:
skip_reasons.append("二元化合物")
if info.has_radioactive_elements:
skip_reasons.append("含放射性元素")
# 关键:目标阳离子共占位是不可处理的
if info.cation_has_partial_occupancy:
skip_reasons.append(f"{self.target_cation}存在共占位")
# 阴离子共占位通常也不处理
if info.anion_has_partial_occupancy:
skip_reasons.append("阴离子存在共占位")
if info.has_water_molecule:
skip_reasons.append("含水分子")
# 扩胞因子过大
if info.expansion_info.needs_expansion and not info.expansion_info.can_expand:
skip_reasons.append(info.expansion_info.skip_reason)
if skip_reasons:
info.can_process = False
info.skip_reason = "; ".join(skip_reasons)
else:
info.can_process = True

View File

@@ -0,0 +1,15 @@
"""
计算模块:Zeo++ Voronoi 分析
"""
from .workspace_manager import WorkspaceManager
from .zeo_executor import ZeoExecutor, ZeoConfig
from .result_processor import ResultProcessor, FilterCriteria, StructureResult
__all__ = [
'WorkspaceManager',
'ZeoExecutor',
'ZeoConfig',
'ResultProcessor',
'FilterCriteria',
'StructureResult'
]

View File

@@ -0,0 +1,426 @@
"""
Zeo++ 计算结果处理器:提取数据、筛选结构
"""
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import pandas as pd
@dataclass
class FilterCriteria:
"""筛选条件"""
min_percolation_diameter: float = 1.0 # 最小渗透直径 (Å),默认 1.0
min_d_value: float = 2.0 # 最小 d 值,默认 2.0
max_node_length: float = float('inf') # 最大节点长度 (Å)
@dataclass
class StructureResult:
"""单个结构的计算结果"""
structure_name: str
anion_type: str
work_dir: str
# 提取的参数
percolation_diameter: Optional[float] = None
min_d: Optional[float] = None
max_node_length: Optional[float] = None
# 筛选结果
passed_filter: bool = False
filter_reason: str = ""
class ResultProcessor:
"""
Zeo++ 计算结果处理器
功能:
1. 从每个结构目录的 log.txt 提取关键参数
2. 汇总所有结果到 CSV 文件
3. 根据筛选条件筛选结构
4. 将通过筛选的结构复制到新文件夹
"""
def __init__(
self,
workspace_path: str = "workspace",
data_dir: str = None,
output_dir: str = None
):
"""
初始化结果处理器
Args:
workspace_path: 工作区根目录
data_dir: 数据目录(默认 workspace/data)
output_dir: 输出目录(默认 workspace/results)
"""
self.workspace_path = os.path.abspath(workspace_path)
self.data_dir = data_dir or os.path.join(self.workspace_path, "data")
self.output_dir = output_dir or os.path.join(self.workspace_path, "results")
def extract_from_log(self, log_path: str) -> Tuple[Optional[float], Optional[float], Optional[float]]:
"""
从 log.txt 中提取三个关键参数
Args:
log_path: log.txt 文件路径
Returns:
(percolation_diameter, min_d, max_node_length)
"""
if not os.path.exists(log_path):
return None, None, None
try:
with open(log_path, 'r', encoding='utf-8') as f:
content = f.read()
except Exception:
return None, None, None
# 正则表达式 - 与 py/extract_data.py 保持一致
# 1. Percolation diameter: "# Percolation diameter (A): 1.06"
re_percolation = r"Percolation diameter \(A\):\s*([\d\.]+)"
# 2. Minimum of d: "the minium of d\n3.862140561244235"
# 注意:这是 Topological_Analysis 库输出的格式
re_min_d = r"the minium of d\s*\n\s*([\d\.]+)"
# 3. Maximum node length: "# Maximum node length detected: 1.332 A"
re_max_node = r"Maximum node length detected:\s*([\d\.]+)\s*A"
# 提取数据
match_perc = re.search(re_percolation, content)
match_d = re.search(re_min_d, content)
match_node = re.search(re_max_node, content)
val_perc = float(match_perc.group(1)) if match_perc else None
val_d = float(match_d.group(1)) if match_d else None
val_node = float(match_node.group(1)) if match_node else None
return val_perc, val_d, val_node
def process_all_structures(self) -> List[StructureResult]:
"""
处理所有结构,提取计算结果
Returns:
StructureResult 列表
"""
results = []
if not os.path.exists(self.data_dir):
print(f"⚠️ 数据目录不存在: {self.data_dir}")
return results
print("\n正在提取计算结果...")
# 遍历阴离子目录
for anion_key in os.listdir(self.data_dir):
anion_dir = os.path.join(self.data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
# 遍历结构目录
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 查找 log.txt
log_path = os.path.join(struct_dir, "log.txt")
# 提取参数
perc, min_d, max_node = self.extract_from_log(log_path)
result = StructureResult(
structure_name=struct_name,
anion_type=anion_key,
work_dir=struct_dir,
percolation_diameter=perc,
min_d=min_d,
max_node_length=max_node
)
results.append(result)
print(f" 共处理 {len(results)} 个结构")
return results
def apply_filter(
self,
results: List[StructureResult],
criteria: FilterCriteria
) -> List[StructureResult]:
"""
应用筛选条件
Args:
results: 结构结果列表
criteria: 筛选条件
Returns:
更新后的结果列表(包含筛选状态)
"""
print("\n应用筛选条件...")
print(f" 最小渗透直径: {criteria.min_percolation_diameter} Å")
print(f" 最小 d 值: {criteria.min_d_value}")
print(f" 最大节点长度: {criteria.max_node_length} Å")
passed_count = 0
for result in results:
# 检查是否有有效数据
if result.percolation_diameter is None or result.min_d is None:
result.passed_filter = False
result.filter_reason = "数据缺失"
continue
# 检查渗透直径
if result.percolation_diameter < criteria.min_percolation_diameter:
result.passed_filter = False
result.filter_reason = f"渗透直径 {result.percolation_diameter:.3f} < {criteria.min_percolation_diameter}"
continue
# 检查 d 值
if result.min_d < criteria.min_d_value:
result.passed_filter = False
result.filter_reason = f"d 值 {result.min_d:.3f} < {criteria.min_d_value}"
continue
# 检查节点长度(如果有数据)
if result.max_node_length is not None:
if result.max_node_length > criteria.max_node_length:
result.passed_filter = False
result.filter_reason = f"节点长度 {result.max_node_length:.3f} > {criteria.max_node_length}"
continue
# 通过所有筛选
result.passed_filter = True
result.filter_reason = "通过"
passed_count += 1
print(f" 通过筛选: {passed_count}/{len(results)}")
return results
def save_summary_csv(
self,
results: List[StructureResult],
output_path: str = None
) -> str:
"""
保存汇总 CSV 文件
Args:
results: 结构结果列表
output_path: 输出路径(默认 workspace/results/summary.csv)
Returns:
CSV 文件路径
"""
if output_path is None:
os.makedirs(self.output_dir, exist_ok=True)
output_path = os.path.join(self.output_dir, "summary.csv")
# 构建数据
data = []
for r in results:
data.append({
'Structure': r.structure_name,
'Anion_Type': r.anion_type,
'Percolation_Diameter_A': r.percolation_diameter,
'Min_d': r.min_d,
'Max_Node_Length_A': r.max_node_length,
'Passed_Filter': r.passed_filter,
'Filter_Reason': r.filter_reason
})
df = pd.DataFrame(data)
# 按阴离子类型和结构名排序
df = df.sort_values(['Anion_Type', 'Structure'])
# 保存
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_csv(output_path, index=False)
print(f"\n汇总 CSV 已保存: {output_path}")
return output_path
def save_anion_csv(
self,
results: List[StructureResult],
output_dir: str = None
) -> List[str]:
"""
按阴离子类型分别保存 CSV 文件
Args:
results: 结构结果列表
output_dir: 输出目录
Returns:
生成的 CSV 文件路径列表
"""
if output_dir is None:
output_dir = self.output_dir
# 按阴离子类型分组
anion_groups: Dict[str, List[StructureResult]] = {}
for r in results:
if r.anion_type not in anion_groups:
anion_groups[r.anion_type] = []
anion_groups[r.anion_type].append(r)
csv_files = []
for anion_type, group_results in anion_groups.items():
# 构建数据
data = []
for r in group_results:
data.append({
'Structure': r.structure_name,
'Percolation_Diameter_A': r.percolation_diameter,
'Min_d': r.min_d,
'Max_Node_Length_A': r.max_node_length,
'Passed_Filter': r.passed_filter,
'Filter_Reason': r.filter_reason
})
df = pd.DataFrame(data)
df = df.sort_values('Structure')
# 保存到对应目录
anion_output_dir = os.path.join(output_dir, anion_type)
os.makedirs(anion_output_dir, exist_ok=True)
csv_path = os.path.join(anion_output_dir, f"{anion_type}.csv")
df.to_csv(csv_path, index=False)
csv_files.append(csv_path)
print(f" {anion_type}: {len(group_results)} 个结构 -> {csv_path}")
return csv_files
def copy_passed_structures(
self,
results: List[StructureResult],
output_dir: str = None
) -> int:
"""
将通过筛选的结构复制到新文件夹
Args:
results: 结构结果列表
output_dir: 输出目录(默认 workspace/passed)
Returns:
复制的结构数量
"""
if output_dir is None:
output_dir = os.path.join(self.workspace_path, "passed")
passed_results = [r for r in results if r.passed_filter]
if not passed_results:
print("\n没有通过筛选的结构")
return 0
print(f"\n正在复制 {len(passed_results)} 个通过筛选的结构...")
copied = 0
for r in passed_results:
# 目标目录: passed/阴离子类型/结构名/
dst_dir = os.path.join(output_dir, r.anion_type, r.structure_name)
try:
# 如果目标已存在,先删除
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
# 复制整个目录
shutil.copytree(r.work_dir, dst_dir)
copied += 1
except Exception as e:
print(f" ⚠️ 复制失败 {r.structure_name}: {e}")
print(f" 已复制 {copied} 个结构到: {output_dir}")
return copied
def process_and_filter(
self,
criteria: FilterCriteria = None,
save_csv: bool = True,
copy_passed: bool = True
) -> Tuple[List[StructureResult], Dict]:
"""
完整的处理流程:提取数据 -> 筛选 -> 保存 CSV -> 复制通过的结构
Args:
criteria: 筛选条件(如果为 None 则不筛选)
save_csv: 是否保存 CSV
copy_passed: 是否复制通过筛选的结构
Returns:
(结果列表, 统计信息字典)
"""
# 1. 提取所有结构的计算结果
results = self.process_all_structures()
if not results:
return results, {'total': 0, 'passed': 0, 'failed': 0}
# 2. 应用筛选条件
if criteria is not None:
results = self.apply_filter(results, criteria)
# 3. 保存 CSV
if save_csv:
print("\n保存结果 CSV...")
self.save_summary_csv(results)
self.save_anion_csv(results)
# 4. 复制通过筛选的结构
if copy_passed and criteria is not None:
self.copy_passed_structures(results)
# 统计
stats = {
'total': len(results),
'passed': sum(1 for r in results if r.passed_filter),
'failed': sum(1 for r in results if not r.passed_filter),
'missing_data': sum(1 for r in results if r.filter_reason == "数据缺失")
}
return results, stats
def print_summary(self, results: List[StructureResult], stats: Dict):
"""打印结果摘要"""
print("\n" + "=" * 60)
print("【计算结果摘要】")
print("=" * 60)
print(f" 总结构数: {stats['total']}")
print(f" 通过筛选: {stats['passed']}")
print(f" 未通过筛选: {stats['failed']}")
print(f" 数据缺失: {stats.get('missing_data', 0)}")
# 按阴离子类型统计
anion_stats: Dict[str, Dict] = {}
for r in results:
if r.anion_type not in anion_stats:
anion_stats[r.anion_type] = {'total': 0, 'passed': 0}
anion_stats[r.anion_type]['total'] += 1
if r.passed_filter:
anion_stats[r.anion_type]['passed'] += 1
print("\n 按阴离子类型:")
for anion, s in sorted(anion_stats.items()):
print(f" {anion}: {s['passed']}/{s['total']} 通过")
print("=" * 60)

View File

@@ -0,0 +1,288 @@
"""
工作区管理器:管理计算工作区的创建和软链接
"""
import os
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
@dataclass
class WorkspaceInfo:
"""工作区信息"""
workspace_path: str
data_dir: str # workspace/data
tool_dir: str # tool 目录
target_cation: str
target_anions: Set[str]
# 统计信息
total_structures: int = 0
anion_counts: Dict[str, int] = field(default_factory=dict)
linked_structures: int = 0 # 已创建软链接的结构数
class WorkspaceManager:
"""
工作区管理器
负责:
1. 检测现有工作区
2. 创建软链接(yaml 配置文件和计算脚本放在每个结构目录下)
3. 准备计算任务
"""
# 支持的阴离子及其配置文件
SUPPORTED_ANIONS = {'O', 'S', 'Cl', 'Br'}
def __init__(
self,
workspace_path: str = "workspace",
tool_dir: str = "tool",
target_cation: str = "Li"
):
"""
初始化工作区管理器
Args:
workspace_path: 工作区根目录
tool_dir: 工具目录(包含 yaml 配置和计算脚本)
target_cation: 目标阳离子
"""
self.workspace_path = os.path.abspath(workspace_path)
self.tool_dir = os.path.abspath(tool_dir)
self.target_cation = target_cation
# 数据目录
self.data_dir = os.path.join(self.workspace_path, "data")
def check_existing_workspace(self) -> Optional[WorkspaceInfo]:
"""
检查现有工作区
Returns:
WorkspaceInfo 如果存在,否则 None
"""
if not os.path.exists(self.data_dir):
return None
# 扫描数据目录
anion_counts = {}
total = 0
linked = 0
for item in os.listdir(self.data_dir):
item_path = os.path.join(self.data_dir, item)
if os.path.isdir(item_path):
# 可能是阴离子目录(如 O, S, O+S)
# 统计其中的结构数量
count = 0
for sub_item in os.listdir(item_path):
sub_path = os.path.join(item_path, sub_item)
if os.path.isdir(sub_path):
# 检查是否包含 CIF 文件
cif_files = [f for f in os.listdir(sub_path) if f.endswith('.cif')]
if cif_files:
count += 1
# 检查是否已有软链接
yaml_files = [f for f in os.listdir(sub_path) if f.endswith('.yaml')]
if yaml_files:
linked += 1
if count > 0:
anion_counts[item] = count
total += count
if total == 0:
return None
return WorkspaceInfo(
workspace_path=self.workspace_path,
data_dir=self.data_dir,
tool_dir=self.tool_dir,
target_cation=self.target_cation,
target_anions=set(anion_counts.keys()),
total_structures=total,
anion_counts=anion_counts,
linked_structures=linked
)
def setup_workspace(
self,
target_anions: Set[str] = None,
force_relink: bool = False
) -> WorkspaceInfo:
"""
设置工作区:在每个结构目录下创建软链接
软链接规则:
- yaml 文件:使用与阴离子目录同名的 yaml(如 O 目录用 O.yaml,Cl+O 目录用 Cl+O.yaml)
- python 脚本analyze_voronoi_nodes.py
Args:
target_anions: 目标阴离子集合
force_relink: 是否强制重新创建软链接
Returns:
WorkspaceInfo
"""
if target_anions is None:
target_anions = self.SUPPORTED_ANIONS
# 确保数据目录存在
if not os.path.exists(self.data_dir):
raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
# 获取计算脚本路径
analyze_script = os.path.join(self.tool_dir, "analyze_voronoi_nodes.py")
if not os.path.exists(analyze_script):
raise FileNotFoundError(f"计算脚本不存在: {analyze_script}")
anion_counts = {}
total = 0
linked = 0
print("\n正在设置工作区软链接...")
# 遍历数据目录中的阴离子子目录
for anion_key in os.listdir(self.data_dir):
anion_dir = os.path.join(self.data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
# 确定使用哪个 yaml 配置文件
# 使用与阴离子目录同名的 yaml 文件(如 O.yaml, Cl+O.yaml)
yaml_name = f"{anion_key}.yaml"
yaml_source = os.path.join(self.tool_dir, self.target_cation, yaml_name)
if not os.path.exists(yaml_source):
print(f" ⚠️ 配置文件不存在: {yaml_source}")
continue
# 统计并处理该阴离子目录下的所有结构
count = 0
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 检查是否包含 CIF 文件
cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
if not cif_files:
continue
count += 1
# 在结构目录下创建软链接
yaml_link = os.path.join(struct_dir, yaml_name)
script_link = os.path.join(struct_dir, "analyze_voronoi_nodes.py")
# 创建 yaml 软链接
if os.path.exists(yaml_link) or os.path.islink(yaml_link):
if force_relink:
os.remove(yaml_link)
os.symlink(yaml_source, yaml_link)
linked += 1
else:
os.symlink(yaml_source, yaml_link)
linked += 1
# 创建计算脚本软链接
if os.path.exists(script_link) or os.path.islink(script_link):
if force_relink:
os.remove(script_link)
os.symlink(analyze_script, script_link)
else:
os.symlink(analyze_script, script_link)
if count > 0:
anion_counts[anion_key] = count
total += count
print(f"{anion_key}: {count} 个结构, 配置 -> {yaml_name}")
print(f"\n 总计: {total} 个结构, 新建软链接: {linked}")
return WorkspaceInfo(
workspace_path=self.workspace_path,
data_dir=self.data_dir,
tool_dir=self.tool_dir,
target_cation=self.target_cation,
target_anions=set(anion_counts.keys()),
total_structures=total,
anion_counts=anion_counts,
linked_structures=linked
)
def get_computation_tasks(
self,
workspace_info: WorkspaceInfo = None
) -> List[Dict]:
"""
获取所有计算任务
Returns:
任务列表,每个任务包含:
- cif_path: CIF 文件路径
- yaml_name: YAML 配置文件名(如 O.yaml)
- work_dir: 工作目录(结构目录)
- anion_type: 阴离子类型
- structure_name: 结构名称
"""
if workspace_info is None:
workspace_info = self.check_existing_workspace()
if workspace_info is None:
return []
tasks = []
for anion_key in workspace_info.anion_counts.keys():
anion_dir = os.path.join(self.data_dir, anion_key)
yaml_name = f"{anion_key}.yaml"
# 遍历该阴离子目录下的所有结构
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 查找 CIF 文件
cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
# 检查是否有 yaml 软链接
yaml_path = os.path.join(struct_dir, yaml_name)
if not os.path.exists(yaml_path):
continue
for cif_file in cif_files:
cif_path = os.path.join(struct_dir, cif_file)
tasks.append({
'cif_path': cif_path,
'yaml_name': yaml_name,
'work_dir': struct_dir,
'anion_type': anion_key,
'structure_name': struct_name,
'cif_name': cif_file
})
return tasks
def print_workspace_summary(self, workspace_info: WorkspaceInfo):
"""打印工作区摘要"""
print("\n" + "=" * 60)
print("【工作区摘要】")
print("=" * 60)
print(f" 工作区路径: {workspace_info.workspace_path}")
print(f" 数据目录: {workspace_info.data_dir}")
print(f" 目标阳离子: {workspace_info.target_cation}")
print(f" 总结构数: {workspace_info.total_structures}")
print(f" 已配置软链接: {workspace_info.linked_structures}")
print()
print(" 阴离子分布:")
for anion, count in sorted(workspace_info.anion_counts.items()):
print(f" - {anion}: {count} 个结构")
print("=" * 60)

View File

@@ -0,0 +1,446 @@
"""
Zeo++ 计算执行器:使用 SLURM 作业数组高效调度大量计算任务
"""
import os
import subprocess
import time
import json
import tempfile
from typing import List, Dict, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import threading
from ..core.progress import ProgressManager
@dataclass
class ZeoConfig:
"""Zeo++ 计算配置"""
# 环境配置
conda_env: str = "/cluster/home/koko125/anaconda3/envs/zeo"
# SLURM 配置
partition: str = "cpu"
time_limit: str = "2:00:00" # 单个任务时间限制
memory_per_task: str = "4G"
# 作业数组配置
max_array_size: int = 1000 # SLURM 作业数组最大大小
max_concurrent: int = 50 # 最大并发任务数
# 轮询配置
poll_interval: float = 5.0 # 状态检查间隔(秒)
# 过滤器配置
filters: List[str] = field(default_factory=lambda: [
"Ordered", "PropOxi", "VoroPerco", "Coulomb", "VoroBV", "VoroInfo", "MergeSite"
])
@dataclass
class ZeoTaskResult:
"""单个任务结果"""
task_id: int
structure_name: str
cif_path: str
success: bool
output_files: List[str] = field(default_factory=list)
error_message: str = ""
duration: float = 0.0
class ZeoExecutor:
"""
Zeo++ 计算执行器
使用 SLURM 作业数组高效调度大量 Voronoi 分析任务
"""
def __init__(self, config: ZeoConfig = None):
self.config = config or ZeoConfig()
self.progress_manager = None
self._stop_event = threading.Event()
def run_batch(
self,
tasks: List[Dict],
output_dir: str = None,
desc: str = "Zeo++ 计算"
) -> List[ZeoTaskResult]:
"""
批量执行 Zeo++ 计算
Args:
tasks: 任务列表,每个任务包含 cif_path, yaml_path, work_dir 等
output_dir: SLURM 日志输出目录
desc: 进度条描述
Returns:
ZeoTaskResult 列表
"""
if not tasks:
print("⚠️ 没有任务需要执行")
return []
total = len(tasks)
# 创建输出目录
if output_dir is None:
output_dir = os.path.join(os.getcwd(), "slurm_logs")
os.makedirs(output_dir, exist_ok=True)
print(f"\n{'='*60}")
print(f"【Zeo++ 批量计算】")
print(f"{'='*60}")
print(f" 总任务数: {total}")
print(f" Conda环境: {self.config.conda_env}")
print(f" SLURM分区: {self.config.partition}")
print(f" 最大并发: {self.config.max_concurrent}")
print(f" 日志目录: {output_dir}")
print(f"{'='*60}\n")
# 保存任务列表到文件
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump(tasks, f, indent=2)
# 生成并提交作业数组
if total <= self.config.max_array_size:
# 单个作业数组
return self._submit_array_job(tasks, output_dir, desc)
else:
# 分批提交多个作业数组
return self._submit_batched_arrays(tasks, output_dir, desc)
def _submit_array_job(
self,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""提交单个作业数组"""
total = len(tasks)
# 保存任务列表
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump(tasks, f, indent=2)
# 生成作业脚本
script_content = self._generate_array_script(
tasks_file=tasks_file,
output_dir=output_dir,
array_range=f"0-{total-1}%{self.config.max_concurrent}"
)
script_path = os.path.join(output_dir, "submit_array.sh")
with open(script_path, 'w') as f:
f.write(script_content)
os.chmod(script_path, 0o755)
print(f"生成作业脚本: {script_path}")
# 提交作业
result = subprocess.run(
['sbatch', script_path],
capture_output=True,
text=True
)
if result.returncode != 0:
print(f"❌ 作业提交失败: {result.stderr}")
return [ZeoTaskResult(
task_id=i,
structure_name=t.get('structure_name', ''),
cif_path=t.get('cif_path', ''),
success=False,
error_message=f"提交失败: {result.stderr}"
) for i, t in enumerate(tasks)]
# 提取作业 ID
job_id = result.stdout.strip().split()[-1]
print(f"✓ 作业已提交: {job_id}")
print(f" 作业数组范围: 0-{total-1}")
print(f" 最大并发: {self.config.max_concurrent}")
# 监控作业进度
return self._monitor_array_job(job_id, tasks, output_dir, desc)
def _submit_batched_arrays(
self,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""分批提交多个作业数组"""
total = len(tasks)
batch_size = self.config.max_array_size
num_batches = (total + batch_size - 1) // batch_size
print(f"任务数超过作业数组限制,分 {num_batches} 批提交")
all_results = []
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, total)
batch_tasks = tasks[start_idx:end_idx]
batch_output_dir = os.path.join(output_dir, f"batch_{batch_idx}")
os.makedirs(batch_output_dir, exist_ok=True)
print(f"\n--- 批次 {batch_idx + 1}/{num_batches} ---")
print(f"任务范围: {start_idx} - {end_idx - 1}")
batch_results = self._submit_array_job(
batch_tasks,
batch_output_dir,
f"{desc} (批次 {batch_idx + 1}/{num_batches})"
)
# 调整任务 ID
for r in batch_results:
r.task_id += start_idx
all_results.extend(batch_results)
return all_results
def _generate_array_script(
self,
tasks_file: str,
output_dir: str,
array_range: str
) -> str:
"""生成 SLURM 作业数组脚本"""
# 获取项目根目录
project_root = os.getcwd()
script = f"""#!/bin/bash
#SBATCH --job-name=zeo_array
#SBATCH --partition={self.config.partition}
#SBATCH --array={array_range}
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --mem={self.config.memory_per_task}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/task_%a.out
#SBATCH --error={output_dir}/task_%a.err
# ============================================
# Zeo++ Voronoi 分析 - 作业数组
# ============================================
echo "===== 任务信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "数组任务ID: $SLURM_ARRAY_TASK_ID"
echo "节点: $SLURM_NODELIST"
echo "开始时间: $(date)"
echo "===================="
# ============ 环境初始化 ============
# 加载 bashrc
if [ -f ~/.bashrc ]; then
source ~/.bashrc
fi
# 初始化 Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
source /opt/anaconda3/etc/profile.d/conda.sh
fi
# 激活 Zeo++ 环境
conda activate {self.config.conda_env}
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "===================="
echo ""
# ============ 读取任务信息 ============
TASKS_FILE="{tasks_file}"
TASK_ID=$SLURM_ARRAY_TASK_ID
# 使用 Python 解析任务
TASK_INFO=$(python3 -c "
import json
with open('$TASKS_FILE', 'r') as f:
tasks = json.load(f)
if $TASK_ID < len(tasks):
task = tasks[$TASK_ID]
print(task['work_dir'])
print(task['yaml_name'])
else:
print('ERROR')
")
WORK_DIR=$(echo "$TASK_INFO" | sed -n '1p')
YAML_NAME=$(echo "$TASK_INFO" | sed -n '2p')
if [ "$WORK_DIR" == "ERROR" ]; then
echo "错误: 任务ID $TASK_ID 超出范围"
exit 1
fi
echo "工作目录: $WORK_DIR"
echo "配置文件: $YAML_NAME"
echo ""
# ============ 执行计算 ============
cd "$WORK_DIR"
echo "开始 Voronoi 分析..."
# 软链接已在工作目录下,直接使用相对路径
# 将输出重定向到 log.txt 以便后续提取结果
python analyze_voronoi_nodes.py *.cif -i "$YAML_NAME" > log.txt 2>&1
EXIT_CODE=$?
# 显示日志内容(用于调试)
echo ""
echo "===== 计算日志 ====="
cat log.txt
echo "===================="
# ============ 完成 ============
echo ""
echo "===== 任务完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
# 写入状态文件
if [ $EXIT_CODE -eq 0 ]; then
echo "SUCCESS" > "{output_dir}/status_$TASK_ID.txt"
else
echo "FAILED" > "{output_dir}/status_$TASK_ID.txt"
fi
echo "===================="
exit $EXIT_CODE
"""
return script
def _monitor_array_job(
self,
job_id: str,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""监控作业数组进度"""
total = len(tasks)
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = [None] * total
completed = set()
print(f"\n监控作业进度 (每 {self.config.poll_interval} 秒检查一次)...")
print("按 Ctrl+C 可中断监控(作业将继续在后台运行)\n")
try:
while len(completed) < total:
time.sleep(self.config.poll_interval)
# 检查状态文件
for i in range(total):
if i in completed:
continue
status_file = os.path.join(output_dir, f"status_{i}.txt")
if os.path.exists(status_file):
with open(status_file, 'r') as f:
status = f.read().strip()
task = tasks[i]
success = (status == "SUCCESS")
# 收集输出文件
output_files = []
if success:
work_dir = task['work_dir']
for f in os.listdir(work_dir):
if f.endswith(('.cif', '.csv')) and f != task['cif_name']:
output_files.append(os.path.join(work_dir, f))
results[i] = ZeoTaskResult(
task_id=i,
structure_name=task.get('structure_name', ''),
cif_path=task.get('cif_path', ''),
success=success,
output_files=output_files
)
completed.add(i)
self.progress_manager.update(success=success)
self.progress_manager.display()
# 检查作业是否还在运行
if not self._is_job_running(job_id) and len(completed) < total:
# 作业已结束但有任务未完成
print(f"\n⚠️ 作业已结束,但有 {total - len(completed)} 个任务未完成")
break
except KeyboardInterrupt:
print("\n\n⚠️ 监控已中断,作业将继续在后台运行")
print(f" 可使用 'squeue -j {job_id}' 查看作业状态")
print(f" 可使用 'scancel {job_id}' 取消作业")
self.progress_manager.finish()
# 填充未完成的任务
for i in range(total):
if results[i] is None:
task = tasks[i]
results[i] = ZeoTaskResult(
task_id=i,
structure_name=task.get('structure_name', ''),
cif_path=task.get('cif_path', ''),
success=False,
error_message="任务未完成或状态未知"
)
return results
def _is_job_running(self, job_id: str) -> bool:
"""检查作业是否还在运行"""
try:
result = subprocess.run(
['squeue', '-j', job_id, '-h'],
capture_output=True,
text=True,
timeout=10
)
return bool(result.stdout.strip())
except Exception:
return False
def print_results_summary(self, results: List[ZeoTaskResult]):
"""打印结果摘要"""
total = len(results)
success = sum(1 for r in results if r.success)
failed = total - success
print("\n" + "=" * 60)
print("【计算结果摘要】")
print("=" * 60)
print(f" 总任务数: {total}")
print(f" 成功: {success} ({100*success/total:.1f}%)")
print(f" 失败: {failed} ({100*failed/total:.1f}%)")
if failed > 0 and failed <= 10:
print("\n 失败的任务:")
for r in results:
if not r.success:
print(f" - {r.structure_name}: {r.error_message}")
elif failed > 10:
print(f"\n 失败任务过多,请检查日志文件")
print("=" * 60)

View File

@@ -0,0 +1,18 @@
"""
核心模块:调度器、执行器和进度管理
"""
from .scheduler import ParallelScheduler, ResourceConfig, ExecutionMode as SchedulerMode
from .executor import TaskExecutor, ExecutorConfig, ExecutionMode, TaskResult, create_executor
from .progress import ProgressManager
__all__ = [
'ParallelScheduler',
'ResourceConfig',
'SchedulerMode',
'TaskExecutor',
'ExecutorConfig',
'ExecutionMode',
'TaskResult',
'create_executor',
'ProgressManager',
]

src/core/executor.py Normal file
View File

@@ -0,0 +1,431 @@
"""
任务执行器:支持本地执行和 SLURM 直接提交
不生成脚本文件,直接在 Python 中管理任务
"""
import os
import subprocess
import time
import json
from typing import List, Callable, Any, Optional, Dict, Tuple
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL = "local" # 本地多进程
SLURM_DIRECT = "slurm" # SLURM 直接提交(不生成脚本)
@dataclass
class ExecutorConfig:
"""执行器配置"""
mode: ExecutionMode = ExecutionMode.LOCAL
max_workers: int = 4
conda_env: str = "/cluster/home/koko125/anaconda3/envs/screen"
partition: str = "cpu"
time_limit: str = "7-00:00:00"
memory_per_task: str = "4G"
# SLURM 相关
poll_interval: float = 2.0 # 轮询间隔(秒)
max_concurrent_jobs: int = 50 # 最大并发作业数
@dataclass
class TaskResult:
"""任务结果"""
task_id: Any
success: bool
result: Any = None
error: str = None
duration: float = 0.0
class TaskExecutor:
"""
任务执行器
支持两种模式:
1. LOCAL: 本地多进程执行
2. SLURM_DIRECT: 直接提交 SLURM 作业,实时监控进度
"""
def __init__(self, config: ExecutorConfig = None):
self.config = config or ExecutorConfig()
self.progress_manager = None
self._stop_event = threading.Event()
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""检测运行环境"""
env_info = {
'hostname': os.uname().nodename,
'total_cores': cpu_count(),
'has_slurm': False,
'slurm_partitions': [],
'conda_env': os.environ.get('CONDA_PREFIX', ''),
}
# 检测 SLURM
try:
result = subprocess.run(
['sinfo', '-h', '-o', '%P %a %c %D'],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
env_info['has_slurm'] = True
lines = result.stdout.strip().split('\n')
for line in lines:
parts = line.split()
if len(parts) >= 4:
partition = parts[0].rstrip('*')
avail = parts[1]
if avail == 'up':
env_info['slurm_partitions'].append(partition)
except Exception:
pass
return env_info
def run(
self,
tasks: List[Any],
worker_func: Callable,
desc: str = "Processing"
) -> List[TaskResult]:
"""
执行任务
Args:
tasks: 任务列表
worker_func: 工作函数,接收单个任务,返回结果
desc: 进度条描述
Returns:
TaskResult 列表
"""
if self.config.mode == ExecutionMode.LOCAL:
return self._run_local(tasks, worker_func, desc)
elif self.config.mode == ExecutionMode.SLURM_DIRECT:
return self._run_slurm_direct(tasks, worker_func, desc)
else:
raise ValueError(f"不支持的执行模式: {self.config.mode}")
def _run_local(
self,
tasks: List[Any],
worker_func: Callable,
desc: str
) -> List[TaskResult]:
"""本地多进程执行"""
total = len(tasks)
num_workers = min(self.config.max_workers, total)
print(f"\n{'='*60}")
print(f"本地执行配置:")
print(f" 总任务数: {total}")
print(f" Worker数: {num_workers}")
print(f"{'='*60}\n")
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
# 单进程执行
for i, task in enumerate(tasks):
start_time = time.time()
try:
result = worker_func(task)
duration = time.time() - start_time
results.append(TaskResult(
task_id=i,
success=True,
result=result,
duration=duration
))
self.progress_manager.update(success=True)
except Exception as e:
duration = time.time() - start_time
results.append(TaskResult(
task_id=i,
success=False,
error=str(e),
duration=duration
))
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
# 多进程执行
with Pool(processes=num_workers) as pool:
# imap (ordered) keeps task_id aligned with the original task index
for i, result in enumerate(pool.imap(worker_func, tasks)):
if result is not None:
results.append(TaskResult(
task_id=i,
success=True,
result=result
))
self.progress_manager.update(success=True)
else:
results.append(TaskResult(
task_id=i,
success=False,
error="Worker returned None"
))
self.progress_manager.update(success=False)
self.progress_manager.display()
self.progress_manager.finish()
return results
def _run_slurm_direct(
self,
tasks: List[Any],
worker_func: Callable,
desc: str
) -> List[TaskResult]:
"""
SLURM 直接提交模式
注意:对于数据库分析这类快速任务,建议使用本地多进程模式
SLURM 模式更适合耗时的计算任务(如 Zeo++ 分析)
这里回退到本地模式,因为 srun 在登录节点直接调用效率不高
"""
print("\n⚠️ 注意:数据库分析阶段自动使用本地多进程模式")
print(" SLURM 模式将在后续耗时计算步骤中使用")
# 回退到本地模式
return self._run_local(tasks, worker_func, desc)
class SlurmJobManager:
"""
SLURM 作业管理器
用于批量提交和监控 SLURM 作业
"""
def __init__(self, config: ExecutorConfig):
self.config = config
self.active_jobs = {} # job_id -> task_info
def submit_batch(
self,
tasks: List[Tuple[str, str, set]], # (file_path, target_cation, target_anions)
output_dir: str,
desc: str = "Processing"
) -> List[TaskResult]:
"""
批量提交任务到 SLURM
使用 sbatch --wrap 直接提交,不生成脚本文件
"""
total = len(tasks)
os.makedirs(output_dir, exist_ok=True)
print(f"\n{'='*60}")
print(f"SLURM 批量提交:")
print(f" 总任务数: {total}")
print(f" 输出目录: {output_dir}")
print(f" Conda环境: {self.config.conda_env}")
print(f"{'='*60}\n")
progress = ProgressManager(total, desc)
progress.start()
results = []
job_ids = []
# 提交所有任务
for i, task in enumerate(tasks):
file_path, target_cation, target_anions = task
# 构建 Python 命令
anions_str = ','.join(target_anions)
python_cmd = (
f"python -c \""
f"import sys; sys.path.insert(0, '{os.getcwd()}'); "
f"from src.analysis.worker import analyze_single_file; "
f"result = analyze_single_file(('{file_path}', '{target_cation}', set('{anions_str}'.split(',')))); "
f"print('SUCCESS' if result and result.is_valid else 'FAILED')"
f"\""
)
# 构建完整的 bash 命令
bash_cmd = (
f"source {os.path.dirname(self.config.conda_env)}/../../etc/profile.d/conda.sh && "
f"conda activate {self.config.conda_env} && "
f"{python_cmd}"
)
# 使用 sbatch --wrap 提交
sbatch_cmd = [
'sbatch',
'--partition', self.config.partition,
'--ntasks', '1',
'--cpus-per-task', '1',
'--mem', self.config.memory_per_task,
'--time', '01:00:00',
'--output', os.path.join(output_dir, f'task_{i}.out'),
'--error', os.path.join(output_dir, f'task_{i}.err'),
'--wrap', bash_cmd
]
try:
result = subprocess.run(
sbatch_cmd,
capture_output=True,
text=True
)
if result.returncode == 0:
# 提取 job_id
job_id = result.stdout.strip().split()[-1]
job_ids.append((i, job_id, file_path))
self.active_jobs[job_id] = {
'task_index': i,
'file_path': file_path,
'status': 'PENDING'
}
else:
results.append(TaskResult(
task_id=i,
success=False,
error=f"提交失败: {result.stderr}"
))
progress.update(success=False)
progress.display()
except Exception as e:
results.append(TaskResult(
task_id=i,
success=False,
error=str(e)
))
progress.update(success=False)
progress.display()
print(f"\n已提交 {len(job_ids)} 个作业,等待完成...")
# 监控作业状态
while self.active_jobs:
time.sleep(self.config.poll_interval)
# 检查作业状态
completed_jobs = self._check_job_status()
for job_id, status in completed_jobs:
job_info = self.active_jobs.pop(job_id, None)
if job_info:
task_idx = job_info['task_index']
if status == 'COMPLETED':
# 检查输出文件
out_file = os.path.join(output_dir, f'task_{task_idx}.out')
success = False
if os.path.exists(out_file):
with open(out_file, 'r') as f:
content = f.read()
success = 'SUCCESS' in content
results.append(TaskResult(
task_id=task_idx,
success=success,
result=job_info['file_path']
))
progress.update(success=success)
else:
# 作业失败
err_file = os.path.join(output_dir, f'task_{task_idx}.err')
error_msg = status
if os.path.exists(err_file):
with open(err_file, 'r') as f:
error_msg = f.read()[:500] # 只取前500字符
results.append(TaskResult(
task_id=task_idx,
success=False,
error=error_msg
))
progress.update(success=False)
progress.display()
progress.finish()
return results
def _check_job_status(self) -> List[Tuple[str, str]]:
"""检查作业状态,返回已完成的作业列表"""
if not self.active_jobs:
return []
job_ids = list(self.active_jobs.keys())
try:
result = subprocess.run(
['sacct', '-j', ','.join(job_ids), '--format=JobID,State', '--noheader', '--parsable2'],
capture_output=True,
text=True,
timeout=30
)
completed = []
if result.returncode == 0:
for line in result.stdout.strip().split('\n'):
if line:
parts = line.split('|')
if len(parts) >= 2:
job_id = parts[0].split('.')[0] # 去掉 .batch 后缀
status = parts[1]
if job_id in self.active_jobs:
if status in ['COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL']:
completed.append((job_id, status))
return completed
except Exception:
return []
def create_executor(
mode: str = "local",
max_workers: int = None,
conda_env: str = None,
**kwargs
) -> TaskExecutor:
"""
创建任务执行器的便捷函数
Args:
mode: "local""slurm"
max_workers: 最大工作进程数
conda_env: Conda 环境路径
**kwargs: 其他配置参数
"""
env = TaskExecutor.detect_environment()
if max_workers is None:
max_workers = min(env['total_cores'], 32)
if conda_env is None:
conda_env = env.get('conda_env') or "/cluster/home/koko125/anaconda3/envs/screen"
exec_mode = ExecutionMode.SLURM_DIRECT if mode.lower() == "slurm" else ExecutionMode.LOCAL
config = ExecutorConfig(
mode=exec_mode,
max_workers=max_workers,
conda_env=conda_env,
**kwargs
)
return TaskExecutor(config)
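
A minimal sketch of the local execution path; demo_worker is a stand-in for the project's real per-file analysis worker:

def demo_worker(task):
    return task * 2   # placeholder for e.g. analyze_single_file(...)

executor = create_executor(mode="local", max_workers=4)
task_results = executor.run(list(range(100)), demo_worker, desc="demo")
print(sum(1 for r in task_results if r.success), "tasks succeeded")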

View File

@@ -0,0 +1,562 @@
"""
结构预处理器:扩胞和添加化合价
"""
import os
import re
import yaml
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import Specie
from pymatgen.core import Lattice, Species, PeriodicSite
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import random
import spglib
import numpy as np
@dataclass
class ProcessingResult:
"""处理结果"""
input_file: str
output_files: List[str] = field(default_factory=list)
success: bool = False
needs_expansion: bool = False
expansion_factor: int = 1
error_message: str = ""
class StructureProcessor:
"""结构预处理器"""
# 默认化合价配置
DEFAULT_VALENCE_PATH = os.path.join(
os.path.dirname(__file__), '..', '..', 'tool', 'valence_states.yaml'
)
def __init__(
self,
valence_yaml_path: str = None,
calculate_type: str = 'low',
max_expansion_factor: int = 64,
keep_number: int = 3,
target_cation: str = "Li"
):
"""
初始化处理器
Args:
valence_yaml_path: 化合价配置文件路径
calculate_type: 扩胞计算精度 ('high', 'normal', 'low', 'very_low')
max_expansion_factor: 最大扩胞因子
keep_number: 保留的扩胞结构数量
target_cation: 目标阳离子
"""
self.valence_yaml_path = valence_yaml_path or self.DEFAULT_VALENCE_PATH
self.calculate_type = calculate_type
self.max_expansion_factor = max_expansion_factor
self.keep_number = keep_number
self.target_cation = target_cation
self.explict_element = [target_cation, f"{target_cation}+"]
# 加载化合价配置
self.valences = self._load_valences()
def _load_valences(self) -> Dict[str, int]:
"""加载化合价配置"""
if os.path.exists(self.valence_yaml_path):
with open(self.valence_yaml_path, 'r') as f:
return yaml.safe_load(f)
return {}
def process_file(
self,
input_path: str,
output_dir: str,
needs_expansion: bool = False
) -> ProcessingResult:
"""
处理单个CIF文件
Args:
input_path: 输入文件路径
output_dir: 输出目录
needs_expansion: 是否需要扩胞
Returns:
ProcessingResult: 处理结果
"""
result = ProcessingResult(input_file=input_path)
try:
# 读取结构
structure = Structure.from_file(input_path)
base_name = os.path.splitext(os.path.basename(input_path))[0]
# 检查是否需要扩胞
occupation_list = self._process_cif_file(structure)
if occupation_list and needs_expansion:
# 需要扩胞处理
result.needs_expansion = True
output_files = self._expand_and_save(
structure, occupation_list, base_name, output_dir
)
result.output_files = output_files
result.expansion_factor = occupation_list[0].get('denominator', 1) if occupation_list else 1
else:
# 不需要扩胞,直接添加化合价
output_path = os.path.join(output_dir, f"{base_name}.cif")
self._add_oxidation_states(structure)
structure.to(filename=output_path)
result.output_files = [output_path]
result.success = True
except Exception as e:
result.success = False
result.error_message = str(e)
return result
def _process_cif_file(self, structure: Structure) -> List[Dict]:
"""
统计结构中各原子的occupation情况
"""
occupation_dict = defaultdict(list)
split_dict = {}
for i, site in enumerate(structure):
occu = self._get_occu(site.species_string)
if occu != 1.0:
if site.species.chemical_system not in self.explict_element:
occupation_dict[occu].append(i + 1)
# 提取元素名称列表
elements = []
if ':' in site.species_string:
parts = site.species_string.split(',')
for part in parts:
element_with_valence = part.strip().split(':')[0].strip()
element_match = re.match(r'([A-Z][a-z]?)', element_with_valence)
if element_match:
elements.append(element_match.group(1))
else:
element_match = re.match(r'([A-Z][a-z]?)', site.species_string)
if element_match:
elements = [element_match.group(1)]
split_dict[occu] = elements
# 转换为列表格式
occupation_list = [
{
"occupation": occu,
"atom_serial": serials,
"numerator": None,
"denominator": None,
"split": split_dict.get(occu, [])
}
for occu, serials in occupation_dict.items()
]
return occupation_list
def _get_occu(self, s_str: str) -> float:
"""从物种字符串获取占据率"""
if not s_str.strip():
return 1.0
pattern = r'([A-Za-z0-9+-]+):([0-9.]+)'
matches = re.findall(pattern, s_str)
for species, occu in matches:
if species not in self.explict_element:
try:
return float(occu)
except ValueError:
continue
return 1.0
def _calculate_expansion_factor(self, occupation_list: List[Dict]) -> Tuple[int, List[Dict]]:
"""计算扩胞因子"""
if not occupation_list:
return 1, []
precision_limits = {
'high': None,
'normal': 100,
'low': 10,
'very_low': 5
}
limit = precision_limits.get(self.calculate_type)
for entry in occupation_list:
occu = entry["occupation"]
if limit:
fraction = Fraction(occu).limit_denominator(limit)
else:
fraction = Fraction(occu).limit_denominator()
entry["numerator"] = fraction.numerator
entry["denominator"] = fraction.denominator
# 计算最小公倍数
denominators = [entry["denominator"] for entry in occupation_list]
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
# 统一分母
for entry in occupation_list:
denominator = entry["denominator"]
entry["numerator"] = entry["numerator"] * (lcm // denominator)
entry["denominator"] = lcm
return lcm, occupation_list
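# Illustrative note (not from the source): a worked example of the expansion-factor logic above.
# With calculate_type='low' (limit_denominator(10)), occupancies 0.5 and 0.75 become the
# fractions 1/2 and 3/4; the LCM of the denominators is 4, so the cell is expanded 4x and the
# numerators are rescaled to 2/4 and 3/4 before sites are assigned in _generate_structure_list.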
def _expand_and_save(
self,
structure: Structure,
occupation_list: List[Dict],
base_name: str,
output_dir: str
) -> List[str]:
"""扩胞并保存"""
lcm, oc_list = self._calculate_expansion_factor(occupation_list)
if lcm > self.max_expansion_factor:
raise ValueError(f"扩胞因子 {lcm} 超过最大限制 {self.max_expansion_factor}")
# 获取扩胞策略
strategies = self._strategy_divide(structure, lcm)
if not strategies:
raise ValueError("无法找到合适的扩胞策略")
# 生成结构列表
st_list = self._generate_structure_list(structure, oc_list)
output_files = []
keep_number = min(self.keep_number, len(strategies))
for index in range(keep_number):
merged = self._merge_structures(st_list, strategies[index])
# 添加化合价
self._add_oxidation_states(merged)
# 当只保存1个时不加后缀
if keep_number == 1:
output_filename = f"{base_name}.cif"
else:
suffix = "x{}y{}z{}".format(
strategies[index]["x"],
strategies[index]["y"],
strategies[index]["z"]
)
output_filename = f"{base_name}-{suffix}.cif"
output_path = os.path.join(output_dir, output_filename)
merged.to(filename=output_path, fmt="cif")
output_files.append(output_path)
return output_files
def _add_oxidation_states(self, structure: Structure):
"""添加化合价"""
# 检查是否已有化合价
has_oxidation = all(
all(isinstance(sp, Specie) for sp in site.species.keys())
for site in structure.sites
)
if not has_oxidation and self.valences:
structure.add_oxidation_state_by_element(self.valences)
def _strategy_divide(self, structure: Structure, total: int) -> List[Dict]:
"""根据晶体类型确定扩胞策略"""
try:
space_group_info = structure.get_space_group_info()
space_group_symbol = space_group_info[0]
# 获取空间群类型
all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)]
symbol = all_spacegroup_symbols[0]
for symbol_i in all_spacegroup_symbols:
if space_group_symbol == symbol_i.international_short:
symbol = symbol_i
break
space_type = self._typejudge(symbol.number)
if space_type == "Cubic":
return self._factorize_to_three_factors(total, "xyz")
else:
return self._factorize_to_three_factors(total)
except Exception:
return self._factorize_to_three_factors(total)
def _typejudge(self, number: int) -> str:
"""判断晶体类型"""
if number in [1, 2]:
return "Triclinic"
elif 3 <= number <= 15:
return "Monoclinic"
elif 16 <= number <= 74:
return "Orthorhombic"
elif 75 <= number <= 142:
return "Tetragonal"
elif 143 <= number <= 167:
return "Trigonal"
elif 168 <= number <= 194:
return "Hexagonal"
elif 195 <= number <= 230:
return "Cubic"
else:
return "Unknown"
def _factorize_to_three_factors(self, n: int, type_sym: str = None) -> List[Dict]:
"""分解为三个因子"""
factors = []
if type_sym == "xyz":
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
for y in range(1, remaining_n + 1):
if remaining_n % y == 0 and y <= x:
z = remaining_n // y
if z <= y:
factors.append({'x': x, 'y': y, 'z': z})
else:
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
for y in range(1, remaining_n + 1):
if remaining_n % y == 0:
z = remaining_n // y
factors.append({'x': x, 'y': y, 'z': z})
# 排序
def sort_key(item):
return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x'])
return sorted(factors, key=sort_key)
def _generate_structure_list(
self,
base_structure: Structure,
occupation_list: List[Dict]
) -> List[Structure]:
"""生成结构列表"""
if not occupation_list:
return [base_structure.copy()]
lcm = occupation_list[0]["denominator"]
structure_list = [base_structure.copy() for _ in range(lcm)]
for entry in occupation_list:
numerator = entry["numerator"]
denominator = entry["denominator"]
atom_indices = entry["atom_serial"]
for atom_idx in atom_indices:
occupancy_dict = self._mark_atoms_randomly(numerator, denominator)
original_site = base_structure.sites[atom_idx - 1]
element = self._get_first_non_explicit_element(original_site.species_string)
for copy_idx, occupy in occupancy_dict.items():
structure_list[copy_idx].remove_sites([atom_idx - 1])
oxi_state = self._extract_oxi_state(original_site.species_string, element)
if len(entry["split"]) == 1:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
species_dict = {Species(self.target_cation, 1.0): 0.0}
new_site = PeriodicSite(
species=species_dict,
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
new_site = PeriodicSite(
species=Species(entry['split'][1], oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
return structure_list
def _mark_atoms_randomly(self, numerator: int, denominator: int) -> Dict[int, int]:
"""随机标记原子"""
if numerator > denominator:
raise ValueError(f"numerator ({numerator}) 不能超过 denominator ({denominator})")
atom_dice = list(range(denominator))
selected_atoms = random.sample(atom_dice, numerator)
return {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice}
def _get_first_non_explicit_element(self, species_str: str) -> str:
"""获取第一个非目标元素"""
if not species_str.strip():
return ""
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
element_with_charge = part.split(":")[0].strip()
pure_element = ''.join([c for c in element_with_charge if c.isalpha()])
if pure_element not in self.explict_element:
return pure_element
return ""
def _extract_oxi_state(self, species_str: str, element: str) -> int:
"""提取氧化态"""
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
element_with_charge = part.split(":")[0].strip()
if element in element_with_charge:
charge_part = element_with_charge[len(element):]
if not any(c.isdigit() for c in charge_part):
if "+" in charge_part:
return 1
elif "-" in charge_part:
return -1
else:
return 0
sign = 1
if "-" in charge_part:
sign = -1
digits = ""
for c in charge_part:
if c.isdigit():
digits += c
if digits:
return sign * int(digits)
return 0
def _merge_structures(self, structure_list: List[Structure], merge_dict: Dict) -> Structure:
"""合并结构"""
if not structure_list:
raise ValueError("结构列表不能为空")
ref_lattice = structure_list[0].lattice
total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1)
if len(structure_list) != total_merge:
raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配")
a, b, c = ref_lattice.abc
alpha, beta, gamma = ref_lattice.angles
new_a = a * merge_dict.get("x", 1)
new_b = b * merge_dict.get("y", 1)
new_c = c * merge_dict.get("z", 1)
new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma)
all_sites = []
for i, structure in enumerate(structure_list):
x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1)
y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1)
z_offset = i % merge_dict.get("z", 1)
for site in structure:
coords = site.frac_coords.copy()
coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1)
coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1)
coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1)
all_sites.append({"species": site.species, "coords": coords})
return Structure(
new_lattice,
[site["species"] for site in all_sites],
[site["coords"] for site in all_sites]
)
def process_batch(
input_files: List[str],
output_dir: str,
needs_expansion_flags: List[bool] = None,
valence_yaml_path: str = None,
calculate_type: str = 'low',
target_cation: str = "Li",
show_progress: bool = True
) -> List[ProcessingResult]:
"""
批量处理CIF文件
Args:
input_files: 输入文件列表
output_dir: 输出目录
needs_expansion_flags: 是否需要扩胞的标记列表
valence_yaml_path: 化合价配置文件路径
calculate_type: 扩胞计算精度
target_cation: 目标阳离子
show_progress: 是否显示进度
Returns:
处理结果列表
"""
os.makedirs(output_dir, exist_ok=True)
processor = StructureProcessor(
valence_yaml_path=valence_yaml_path,
calculate_type=calculate_type,
target_cation=target_cation
)
if needs_expansion_flags is None:
needs_expansion_flags = [False] * len(input_files)
results = []
total = len(input_files)
for i, (input_file, needs_exp) in enumerate(zip(input_files, needs_expansion_flags)):
if show_progress:
print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
result = processor.process_file(input_file, output_dir, needs_exp)
results.append(result)
if show_progress:
print()
return results
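
A sketch of batch preprocessing with the helper above; the file names and expansion flags are placeholders:

cif_files = ["ICSD_0001.cif", "ICSD_0002.cif"]   # placeholder inputs
flags = [False, True]                            # second structure needs cell expansion
results = process_batch(cif_files, "workspace/prepared", needs_expansion_flags=flags,
                        calculate_type="low", target_cation="Li")
for r in results:
    print(r.input_file, "->", r.output_files if r.success else r.error_message)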