Compare commits: master...reconstruc (13 commits)
| Author | SHA1 | Date |
|---|---|---|
| | f78298e803 | |
| | 6ea96c81d6 | |
| | 9bde3e1229 | |
| | 83647c2218 | |
| | 9b36aa10ff | |
| | 2378a3f2a2 | |
| | da8c85b830 | |
| | 72cf0a79e1 | |
| | f27fd3e3ce | |
| | 1fee324c90 | |
| | ae4e7280b4 | |
| | c91998662a | |
| | 6eeb40d222 | |
.gitignore (vendored, new file, 21 lines)
@@ -0,0 +1,21 @@
# --- Python: common ignores ---
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.env
.venv/

# --- VS Code settings (optional, recommended to ignore) ---
.vscode/

# --- JetBrains (your previous settings) ---
.idea/
/shelf/
/workspace.xml
/httpRequests/
/dataSources/
/dataSources.local.xml
.idea/.gitignore (generated, vendored, deleted, 8 lines)
@@ -1,8 +0,0 @@
# Files ignored by default
/shelf/
/workspace.xml
# Editor-based HTTP client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/Screen.iml (generated, deleted, 8 lines)
@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.11" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
.idea/inspectionProfiles/Project_Default.xml (generated, deleted, 23 lines)
@@ -1,23 +0,0 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="9">
            <item index="0" class="java.lang.String" itemvalue="torchvision" />
            <item index="1" class="java.lang.String" itemvalue="torch" />
            <item index="2" class="java.lang.String" itemvalue="tqdm" />
            <item index="3" class="java.lang.String" itemvalue="scipy" />
            <item index="4" class="java.lang.String" itemvalue="h5py" />
            <item index="5" class="java.lang.String" itemvalue="matplotlib" />
            <item index="6" class="java.lang.String" itemvalue="numpy" />
            <item index="7" class="java.lang.String" itemvalue="opencv_python" />
            <item index="8" class="java.lang.String" itemvalue="Pillow" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
.idea/inspectionProfiles/profiles_settings.xml (generated, deleted, 6 lines)
@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/modules.xml (generated, deleted, 8 lines)
@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/Screen.iml" filepath="$PROJECT_DIR$/.idea/Screen.iml" />
    </modules>
  </component>
</project>
.idea/vcs.xml (generated, deleted, 6 lines)
@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
config/settings.yaml (new file, empty)
config/valence_states.yaml (new file, empty)
main.py (new file, 643 lines)
@@ -0,0 +1,643 @@
"""
High-throughput screening & supercell expansion project - main entry point (supports checkpoint/resume)
"""
import os
import sys
import json

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
from src.core.executor import TaskExecutor
from src.preprocessing.processor import StructureProcessor
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
from src.computation.result_processor import ResultProcessor, FilterCriteria


def print_banner():
    print("""
╔═══════════════════════════════════════════════════════════════════╗
║  High-Throughput Screening & Supercell Expansion - Analyzer v2.2  ║
║       Supports checkpoint/resume and parallel computation         ║
╚═══════════════════════════════════════════════════════════════════╝
""")


def detect_and_show_environment():
    """Detect and display environment information."""
    env = TaskExecutor.detect_environment()

    print("[Runtime Environment]")
    print(f"  Hostname: {env['hostname']}")
    print(f"  Local CPU cores: {env['total_cores']}")
    print(f"  SLURM cluster: {'✅ available' if env['has_slurm'] else '❌ unavailable'}")

    if env['has_slurm'] and env['slurm_partitions']:
        print(f"  Available partitions: {', '.join(env['slurm_partitions'])}")

    if env['conda_env']:
        print(f"  Current conda env: {env['conda_env']}")

    return env


def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict:
    """
    Detect the workflow status to determine which step can be resumed.

    Returns:
        dict: {
            'has_processed_data': bool,   # whether supercell-processed data exists
            'has_zeo_results': bool,      # whether Zeo++ results exist
            'total_structures': int,      # total number of structures
            'structures_with_log': int,   # number of structures with a log.txt
            'workspace_info': object,     # workspace info
            'ws_manager': object          # workspace manager
        }
    """
    status = {
        'has_processed_data': False,
        'has_zeo_results': False,
        'total_structures': 0,
        'structures_with_log': 0,
        'workspace_info': None,
        'ws_manager': None
    }

    # Create the workspace manager
    ws_manager = WorkspaceManager(
        workspace_path=workspace_path,
        tool_dir="tool",
        target_cation=target_cation
    )
    status['ws_manager'] = ws_manager

    # Check for processed data
    data_dir = os.path.join(workspace_path, "data")
    if os.path.exists(data_dir):
        existing = ws_manager.check_existing_workspace()
        if existing and existing.total_structures > 0:
            status['has_processed_data'] = True
            status['total_structures'] = existing.total_structures
            status['workspace_info'] = existing

            # Count how many structures have a log.txt (i.e., finished Zeo++ runs)
            log_count = 0
            for anion_key in os.listdir(data_dir):
                anion_dir = os.path.join(data_dir, anion_key)
                if not os.path.isdir(anion_dir):
                    continue
                for struct_name in os.listdir(anion_dir):
                    struct_dir = os.path.join(anion_dir, struct_name)
                    if os.path.isdir(struct_dir):
                        log_path = os.path.join(struct_dir, "log.txt")
                        if os.path.exists(log_path):
                            log_count += 1

            status['structures_with_log'] = log_count

            # If most structures have a log.txt, consider the Zeo++ stage complete
            if log_count > 0 and log_count >= status['total_structures'] * 0.5:
                status['has_zeo_results'] = True

    return status


def print_workflow_status(status: dict):
    """Print the workflow status."""
    print("\n" + "─" * 50)
    print("[Workflow Status]")
    print("─" * 50)

    if not status['has_processed_data']:
        print("  Step 1 (supercell + valences): ❌ not done")
        print("  Step 2-4 (Zeo++ runs):         ❌ not done")
        print("  Step 5 (result filtering):     ❌ not done")
    else:
        print(f"  Step 1 (supercell + valences): ✅ done ({status['total_structures']} structures)")

        if status['has_zeo_results']:
            print(f"  Step 2-4 (Zeo++ runs):         ✅ done ({status['structures_with_log']}/{status['total_structures']} with logs)")
            print("  Step 5 (result filtering):     ⏳ ready to run")
        else:
            print(f"  Step 2-4 (Zeo++ runs):         ⏳ ready to run ({status['structures_with_log']}/{status['total_structures']} with logs)")
            print("  Step 5 (result filtering):     ❌ requires the Zeo++ runs first")

    print("─" * 50)


def get_user_choice(status: dict) -> str:
    """
    Get the user's choice based on the workflow status.

    Returns:
        'step1': start from scratch (database analysis + supercell expansion)
        'step2': start from the Zeo++ runs
        'step5': go straight to result filtering
        'exit': quit
    """
    print("\nPlease choose an action:")

    options = []

    if status['has_zeo_results']:
        options.append(('5', 'Go straight to result filtering (Step 5)'))
        options.append(('2', 'Re-run the Zeo++ computation (Step 2-4)'))
        options.append(('1', 'Start from scratch (database analysis + supercell expansion)'))
    elif status['has_processed_data']:
        options.append(('2', 'Run the Zeo++ computation (Step 2-4)'))
        options.append(('1', 'Start from scratch (database analysis + supercell expansion)'))
    else:
        options.append(('1', 'Start from scratch (database analysis + supercell expansion)'))

    options.append(('q', 'Quit'))

    for key, desc in options:
        print(f"  {key}. {desc}")

    choice = input("\nSelect [default: " + options[0][0] + "]: ").strip().lower()

    if not choice:
        choice = options[0][0]

    if choice == 'q':
        return 'exit'
    elif choice == '5':
        return 'step5'
    elif choice == '2':
        return 'step2'
    else:
        return 'step1'


def run_step1_database_analysis(env: dict, cation: str) -> dict:
    """
    Step 1: database analysis and supercell processing.

    Returns:
        A dict of processing parameters, or None if the user cancels.
    """
    print("\n" + "═" * 60)
    print("[Step 1: Database Analysis & Supercell Processing]")
    print("═" * 60)

    # Database path
    while True:
        db_path = input("\n📂 Enter the database path: ").strip()
        if os.path.exists(db_path):
            break
        print(f"❌ Path does not exist: {db_path}")

    # Detect the current conda environment path
    default_conda = "/cluster/home/koko125/anaconda3/envs/screen"
    conda_env_path = env.get('conda_env', '') or default_conda

    print(f"\nDetected conda environment: {conda_env_path}")
    custom_env = input("Use this environment? [Y/n] or enter another path: ").strip()
    if custom_env.lower() == 'n':
        conda_env_path = input("Enter the full conda environment path: ").strip()
    elif custom_env and custom_env.lower() != 'y':
        conda_env_path = custom_env

    # Target anions
    anion_input = input("🎯 Enter the target anions (comma-separated) [default: O,S,Cl,Br]: ").strip()
    anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}

    # Anion mode
    print("\nAnion mode:")
    print("  1. Single-anion compounds only")
    print("  2. Mixed-anion compounds only")
    print("  3. All (default)")
    mode_choice = input("Select [1/2/3]: ").strip()
    anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all')

    # Parallel configuration
    print("\n" + "─" * 50)
    print("[Parallel Computing Configuration]")

    default_cores = min(env['total_cores'], 32)
    cores_input = input(f"💻 Max cores / workers [default: {default_cores}]: ").strip()
    max_workers = int(cores_input) if cores_input.isdigit() else default_cores

    params = {
        'database_path': db_path,
        'target_cation': cation,
        'target_anions': anions,
        'anion_mode': anion_mode,
        'max_workers': max_workers,
        'conda_env': conda_env_path,
    }

    print("\n" + "═" * 60)
    print("Starting database analysis...")
    print("═" * 60)

    # Create the analyzer
    analyzer = DatabaseAnalyzer(
        database_path=params['database_path'],
        target_cation=params['target_cation'],
        target_anions=params['target_anions'],
        anion_mode=params['anion_mode'],
        max_cores=params['max_workers'],
        task_complexity='medium'
    )

    print(f"\nFound {len(analyzer.cif_files)} CIF files")

    # Run the analysis
    report = analyzer.analyze(show_progress=True)

    # Print the report
    ReportGenerator.print_report(report, detailed=True)

    # Save options
    save_choice = input("\nSave the report? [y/N]: ").strip().lower()
    if save_choice == 'y':
        output_path = input("Report path [default: analysis_report.json]: ").strip()
        output_path = output_path or "analysis_report.json"
        report.save(output_path)
        print(f"✅ Report saved to: {output_path}")

    # CSV export
    csv_choice = input("Export a detailed CSV? [y/N]: ").strip().lower()
    if csv_choice == 'y':
        csv_path = input("CSV path [default: analysis_details.csv]: ").strip()
        csv_path = csv_path or "analysis_details.csv"
        ReportGenerator.export_to_csv(report, csv_path)

    # Build the final database
    process_choice = input("\nBuild the final usable database (supercell expansion + valence assignment)? [Y/n]: ").strip().lower()
    if process_choice == 'n':
        print("\nSupercell processing skipped; you can resume later")
        return params

    # Output directory settings
    print("\nOutput directory settings:")
    flat_dir = input("  Flat-format output directory [default: workspace/processed]: ").strip()
    flat_dir = flat_dir or "workspace/processed"

    analysis_dir = input("  Analysis-format output directory [default: workspace/data]: ").strip()
    analysis_dir = analysis_dir or "workspace/data"

    # Number of expanded structures to keep
    keep_input = input("\nNumber of expanded structures to keep [default: 1]: ").strip()
    keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1

    # Supercell precision choice
    print("\nSupercell computation precision:")
    print("  1. High (exact fractions)")
    print("  2. Normal (denominator ≤ 100)")
    print("  3. Low (denominator ≤ 10) [default]")
    print("  4. Very low (denominator ≤ 5)")
    precision_choice = input("Select [1/2/3/4]: ").strip()
    calculate_type = {
        '1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low'
    }.get(precision_choice, 'low')

    # Get the processable files
    processable = report.get_processable_files(include_needs_expansion=True)

    if not processable:
        print("⚠️ No processable files")
        return params

    print(f"\nFound {len(processable)} processable files")

    # Prepare the file list and expansion flags
    input_files = [info.file_path for info in processable]
    needs_expansion_flags = [info.needs_expansion for info in processable]
    anion_types_list = [info.anion_types for info in processable]

    direct_count = sum(1 for f in needs_expansion_flags if not f)
    expansion_count = sum(1 for f in needs_expansion_flags if f)
    print(f"  - Directly processable: {direct_count}")
    print(f"  - Need supercell expansion: {expansion_count}")
    print(f"  - Supercells kept per structure: {keep_number}")

    confirm = input("\nConfirm and start processing? [Y/n]: ").strip().lower()
    if confirm == 'n':
        print("\nProcessing cancelled")
        return params

    print("\n" + "═" * 60)
    print("Processing structure files...")
    print("═" * 60)

    # Create the processor
    processor = StructureProcessor(
        calculate_type=calculate_type,
        keep_number=keep_number,
        target_cation=params['target_cation']
    )

    # Create the output directories
    os.makedirs(flat_dir, exist_ok=True)
    os.makedirs(analysis_dir, exist_ok=True)

    results = []
    total = len(input_files)

    import shutil
    for i, (input_file, needs_exp, anion_types) in enumerate(
        zip(input_files, needs_expansion_flags, anion_types_list)
    ):
        print(f"\rProgress: {i+1}/{total} - {os.path.basename(input_file)}", end="")

        # Process the file into the flat-format directory
        result = processor.process_file(input_file, flat_dir, needs_exp)
        results.append(result)

        if result.success:
            # Also save into the analysis-format directory,
            # with subdirectories grouped by anion type
            anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other'
            anion_dir = os.path.join(analysis_dir, anion_key)
            os.makedirs(anion_dir, exist_ok=True)

            # Base file name
            base_name = os.path.splitext(os.path.basename(input_file))[0]

            # Create a subdirectory named after the file
            file_dir = os.path.join(anion_dir, base_name)
            os.makedirs(file_dir, exist_ok=True)

            # Copy the generated files into the analysis-format directory
            for output_file in result.output_files:
                dst_path = os.path.join(file_dir, os.path.basename(output_file))
                shutil.copy2(output_file, dst_path)

    print()

    # Tally the results
    success_count = sum(1 for r in results if r.success)
    fail_count = sum(1 for r in results if not r.success)
    total_output = sum(len(r.output_files) for r in results if r.success)

    print("\n" + "-" * 60)
    print("[Processing Summary]")
    print("-" * 60)
    print(f"  Succeeded: {success_count}")
    print(f"  Failed: {fail_count}")
    print(f"  Files generated: {total_output}")
    print(f"\n  Flat-format directory: {flat_dir}")
    print(f"  Analysis-format directory: {analysis_dir}")
    print("  └── Layout: data/<anion_type>/<file_name>/<file_name>.cif")

    # Show the failed files
    if fail_count > 0:
        print("\nFailed files:")
        for r in results:
            if not r.success:
                print(f"  - {os.path.basename(r.input_file)}: {r.error_message}")

    print("\n✅ Step 1 complete!")
    return params


def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info=None):
    """
    Step 2-4: Zeo++ Voronoi analysis.

    Args:
        params: parameter dict containing target_cation, etc.
        ws_manager: workspace manager (optional; reused if already available)
        workspace_info: workspace info (optional; reused if already available)
    """

    print("\n" + "═" * 60)
    print("[Step 2-4: Zeo++ Voronoi Analysis]")
    print("═" * 60)

    # Create a workspace manager if none was passed in
    if ws_manager is None:
        # Workspace paths
        workspace_path = input("\nWorkspace path [default: workspace]: ").strip() or "workspace"
        tool_dir = input("Tool directory path [default: tool]: ").strip() or "tool"

        # Create the workspace manager
        ws_manager = WorkspaceManager(
            workspace_path=workspace_path,
            tool_dir=tool_dir,
            target_cation=params.get('target_cation', 'Li')
        )

    # If no workspace_info was passed in, check the existing workspace
    if workspace_info is None:
        existing = ws_manager.check_existing_workspace()

        if existing and existing.total_structures > 0:
            ws_manager.print_workspace_summary(existing)
            workspace_info = existing
        else:
            print("⚠️ Workspace data directory does not exist or is empty")
            print(f"   Run Step 1 first to generate processed data into: {ws_manager.data_dir}")
            return

    # Check whether symlinks still need to be created
    if workspace_info.linked_structures < workspace_info.total_structures:
        print("\nCreating config-file symlinks...")
        workspace_info = ws_manager.setup_workspace(force_relink=False)

    # Collect the computation tasks
    tasks = ws_manager.get_computation_tasks(workspace_info)

    if not tasks:
        print("⚠️ No computable tasks found")
        return

    print(f"\nFound {len(tasks)} computation tasks")

    # Zeo++ environment configuration
    print("\n" + "-" * 50)
    print("[Zeo++ Computation Configuration]")

    default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo"
    zeo_env = input(f"Zeo++ conda environment [default: {default_zeo_env}]: ").strip()
    zeo_env = zeo_env or default_zeo_env

    # SLURM configuration
    partition = input("SLURM partition [default: cpu]: ").strip() or "cpu"

    max_concurrent_input = input("Max concurrent jobs [default: 50]: ").strip()
    max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50

    time_limit = input("Per-task time limit [default: 2:00:00]: ").strip() or "2:00:00"

    # Build the config
    zeo_config = ZeoConfig(
        conda_env=zeo_env,
        partition=partition,
        max_concurrent=max_concurrent,
        time_limit=time_limit
    )

    # Confirm execution
    print("\n" + "-" * 50)
    print("[Task Confirmation]")
    print(f"  Total tasks: {len(tasks)}")
    print(f"  Conda env: {zeo_config.conda_env}")
    print(f"  SLURM partition: {zeo_config.partition}")
    print(f"  Max concurrency: {zeo_config.max_concurrent}")
    print(f"  Time limit: {zeo_config.time_limit}")

    confirm = input("\nSubmit the computation jobs? [Y/n]: ").strip().lower()
    if confirm == 'n':
        print("Cancelled")
        return

    # Create the executor and run
    executor = ZeoExecutor(zeo_config)

    log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs")
    results = executor.run_batch(tasks, output_dir=log_dir)

    # Print a result summary
    executor.print_results_summary(results)

    # Save the results
    results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json")
    results_data = [
        {
            'task_id': r.task_id,
            'structure_name': r.structure_name,
            'cif_path': r.cif_path,
            'success': r.success,
            'output_files': r.output_files,
            'error_message': r.error_message
        }
        for r in results
    ]
    with open(results_file, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f"\nResults saved to: {results_file}")

    print("\n✅ Step 2-4 complete!")


def run_step5_result_processing(workspace_path: str = "workspace"):
    """
    Step 5: result processing and filtering.

    Args:
        workspace_path: workspace path
    """
    print("\n" + "═" * 60)
    print("[Step 5: Result Processing & Filtering]")
    print("═" * 60)

    # Create the result processor
    processor = ResultProcessor(workspace_path=workspace_path)

    # Collect the filter criteria
    print("\nSet the filter criteria:")
    print("  (press Enter for the default; enter 0 for no limit)")

    # Minimum percolation diameter (default changed to 1.0)
    perc_input = input("  Minimum percolation diameter (Å) [default: 1.0]: ").strip()
    min_percolation = float(perc_input) if perc_input else 1.0

    # Minimum d value
    d_input = input("  Minimum d value [default: 2.0]: ").strip()
    min_d = float(d_input) if d_input else 2.0

    # Maximum node length
    node_input = input("  Maximum node length (Å) [default: no limit]: ").strip()
    max_node = float(node_input) if node_input else float('inf')

    # Build the filter criteria
    criteria = FilterCriteria(
        min_percolation_diameter=min_percolation,
        min_d_value=min_d,
        max_node_length=max_node
    )

    # Run the processing
    results, stats = processor.process_and_filter(
        criteria=criteria,
        save_csv=True,
        copy_passed=True
    )

    # Print a summary
    processor.print_summary(results, stats)

    # Show the output locations
    print("\nOutput file locations:")
    print(f"  Summary CSV: {workspace_path}/results/summary.csv")
    print(f"  Per-category CSVs: {workspace_path}/results/<anion_type>/<anion_type>.csv")
    print(f"  Passed structures: {workspace_path}/passed/<anion_type>/<structure_name>/")

    print("\n✅ Step 5 complete!")


def main():
    """Main entry point."""
    print_banner()

    # Environment detection
    env = detect_and_show_environment()

    # Ask for the target cation
    cation = input("\n🎯 Enter the target cation [default: Li]: ").strip() or "Li"

    # Detect the workflow status
    status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
    print_workflow_status(status)

    # Get the user's choice
    choice = get_user_choice(status)

    if choice == 'exit':
        print("\n👋 Goodbye!")
        return

    # Run the chosen steps
    params = {'target_cation': cation}

    if choice == 'step1':
        # Start from scratch
        params = run_step1_database_analysis(env, cation)
        if params is None:
            return

        # Ask whether to continue
        continue_choice = input("\nContinue with the Zeo++ computation? [Y/n]: ").strip().lower()
        if continue_choice == 'n':
            print("\nYou can rerun the program later to continue the Zeo++ computation")
            return

        # Re-detect the status
        status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
        run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])

        # Ask whether to continue with filtering
        continue_choice = input("\nContinue with result filtering? [Y/n]: ").strip().lower()
        if continue_choice == 'n':
            print("\nYou can rerun the program later to continue result filtering")
            return

        run_step5_result_processing("workspace")

    elif choice == 'step2':
        # Start from the Zeo++ computation
        run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])

        # Ask whether to continue with filtering
        continue_choice = input("\nContinue with result filtering? [Y/n]: ").strip().lower()
        if continue_choice == 'n':
            print("\nYou can rerun the program later to continue result filtering")
            return

        run_step5_result_processing("workspace")

    elif choice == 'step5':
        # Go straight to result filtering
        run_step5_result_processing("workspace")

    print("\n✅ All done!")


if __name__ == "__main__":
    main()
readme.md (349 lines)
@@ -1,4 +1,4 @@
# High-Throughput Screening & Supercell Expansion Project
# High-Throughput Screening & Supercell Expansion Project v2.3

## Environment Requirements

@@ -12,9 +12,29 @@
### 2. The screen environment (for logical screening and data processing)
* **Python**: 3.11.4
* **Core libraries**: `pymatgen==2024.11.13`, `pandas` (new; used for CSV handling)
* **Path**: `/cluster/home/koko125/anaconda3/envs/screen`

## Quick Start

### Option 1: the new main program (recommended)

```bash
# Activate the screen environment
conda activate /cluster/home/koko125/anaconda3/envs/screen

# Run the main program
python main.py
```

The main program provides an interactive interface that supports:
- Database analysis and screening
- Local multiprocess parallelism
- Direct SLURM submission (no script files generated)
- Live progress bars
- Supercell expansion and valence assignment

### Option 2: the legacy workflow

1. **Data preparation**:
   * If the data comes from **Materials Project (MP)**, put the CIF files in `data/input_pre`.
   * If the data comes from **ICSD**, put the CIF files directly in `data/input`.
@@ -25,6 +45,52 @@
bash main.sh
```

## New Features (v2.1)

### Execution modes

1. **Local multiprocess mode** (`local`)
   - Runs in parallel locally via Python multiprocessing
   - Suited to small jobs and testing

2. **Direct SLURM submission mode** (`slurm`)
   - Submits SLURM jobs directly from Python
   - No script files generated
   - Live monitoring of job status and progress
   - Suited to large-scale high-throughput computing

### Progress display

```
Analyzing CIF files: |████████████████████░░░░░░░░░░░░░░░░░░░░| 500/1000 (50.0%) [0:05:23<0:05:20, 1.6it/s] ✓480 ✗20
```

### Output formats

Processed files can be written in two layouts:

1. **Flat format**
   ```
   workspace/processed/
   ├── 1514027.cif
   ├── 1514072.cif
   └── ...
   ```

2. **Analysis format** (grouped by anion type)
   ```
   workspace/data/
   ├── O/
   │   ├── 1514027/
   │   │   └── 1514027.cif
   │   └── 1514072/
   │       └── 1514072.cif
   ├── S/
   │   └── ...
   └── Cl+O/
       └── ...
   ```

## Processing Pipeline in Detail

### Stage 1: preprocessing and basic screening (Step 1)

@@ -51,9 +117,12 @@

---

## Supercell Expansion Logic (Step 5 - deferred)
## Supercell Expansion Logic (Step 5)

The supercell logic is currently unchanged and operates on the screened structures.
Supercell processing is now integrated into the new main program and supports:
- Automatic computation of the expansion factor (a sketch of the idea follows the "Assumptions" list below)
- A configurable number of saved structures (no suffix when it is 1)
- Automatic valence assignment

### Algorithm breakdown
1. **Read the structure**: parse the CIF file.
@@ -71,4 +140,276 @@

### Assumptions
* Only co-occupancy of two atoms on the same site is considered.
* Co-occupancy of Li atoms is not considered; Li atoms are left untouched.
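The expansion factor referred to above is, per `structure_inspector.py`, the least common multiple of the co-occupancy denominators. A minimal sketch of that idea, assuming occupancies are first snapped to fractions under the chosen precision limit (`estimate_expansion_factor` is a hypothetical helper, not the project's API):

```python
from fractions import Fraction
from math import lcm

def estimate_expansion_factor(occupancies, max_denominator=10):
    """Hypothetical helper: smallest supercell factor that turns every
    partial occupancy into a whole number of atoms."""
    factor = 1
    for occ in occupancies:
        # Snap the occupancy to a fraction, e.g. 0.75 -> 3/4 at 'low' precision.
        frac = Fraction(occ).limit_denominator(max_denominator)
        factor = lcm(factor, frac.denominator)
    return factor

# Occupancies 1/2 and 3/4 need a 4x supercell to become whole atom counts.
print(estimate_expansion_factor([0.5, 0.75]))  # -> 4
```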
## Project Layout

```
screen/
├── main.py                     # Main entry point (new)
├── main.sh                     # Legacy script entry point
├── readme.md                   # This document
├── config/                     # Configuration files
│   ├── settings.yaml
│   └── valence_states.yaml
├── src/                        # Source code
│   ├── analysis/               # Analysis modules
│   │   ├── database_analyzer.py
│   │   ├── report_generator.py
│   │   ├── structure_inspector.py
│   │   └── worker.py
│   ├── core/                   # Core modules
│   │   ├── executor.py         # Task executor (new)
│   │   ├── scheduler.py        # Scheduler
│   │   └── progress.py         # Progress management
│   ├── preprocessing/          # Preprocessing modules
│   │   ├── processor.py        # Structure processor
│   │   └── ...
│   └── utils/                  # Utility functions
├── py/                         # Legacy scripts
├── tool/                       # Tools and configuration
│   ├── analyze_voronoi_nodes.py
│   └── Li/                     # Valence configuration
└── workspace/                  # Workspace
    ├── data/                   # Analysis-format output
    └── processed/              # Flat-format output
```

## API Usage Examples

### Using the executor

```python
from src.core.executor import create_executor, TaskExecutor

# Create an executor
executor = create_executor(
    mode="slurm",  # or "local"
    max_workers=32,
    conda_env="/cluster/home/koko125/anaconda3/envs/screen"
)

# Define the tasks
tasks = [
    (file_path, "Li", {"O", "S"})
    for file_path in cif_files
]

# Run
from src.analysis.worker import analyze_single_file
results = executor.run(tasks, analyze_single_file, desc="Analyzing CIF files")
```

### Using the database analyzer

```python
from src.analysis.database_analyzer import DatabaseAnalyzer

analyzer = DatabaseAnalyzer(
    database_path="/path/to/cif/files",
    target_cation="Li",
    target_anions={"O", "S", "Cl", "Br"},
    anion_mode="all"
)

report = analyzer.analyze(show_progress=True)
report.save("analysis_report.json")
```

### Using the Zeo++ executor

```python
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig

# Set up the workspace
ws_manager = WorkspaceManager(
    workspace_path="workspace",
    tool_dir="tool",
    target_cation="Li"
)

# Create the symlinks
workspace_info = ws_manager.setup_workspace()

# Collect the computation tasks
tasks = ws_manager.get_computation_tasks(workspace_info)

# Configure the Zeo++ executor
config = ZeoConfig(
    conda_env="/cluster/home/koko125/anaconda3/envs/zeo",
    partition="cpu",
    max_concurrent=50,
    time_limit="2:00:00"
)

# Run the computation
executor = ZeoExecutor(config)
results = executor.run_batch(tasks, output_dir="slurm_logs")
executor.print_results_summary(results)
```

## New Features (v2.2)

### Zeo++ Voronoi analysis

SLURM job-array support was added to schedule large numbers of Zeo++ jobs efficiently:

1. **Automatic workspace management**
   - Detects existing workspace data
   - Creates config-file symlinks automatically
   - Organizes the directory tree by anion type

2. **SLURM job arrays** (see the sketch after this list)
   - Submits batches with the `--array` flag
   - Supports a cap on concurrent tasks (e.g. `%50`)
   - Automatically splits oversized task sets into batches

3. **Live progress monitoring**
   - Tracks task completion via status files
   - Ctrl+C interrupts the monitor (the jobs keep running)
   - Collects output files automatically
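As a rough illustration of the array-submission pattern only (not the project's actual `ZeoExecutor` internals; the script path below is a placeholder):

```python
import subprocess

def submit_array(script_path: str, num_tasks: int, max_concurrent: int = 50) -> str:
    """Submit a SLURM job array capped at max_concurrent running tasks."""
    # --array=0-(N-1)%50 schedules N tasks but never runs more than 50 at once.
    array_spec = f"0-{num_tasks - 1}%{max_concurrent}"
    out = subprocess.run(
        ["sbatch", "--array", array_spec, script_path],
        capture_output=True, text=True, check=True,
    )
    # sbatch prints e.g. "Submitted batch job 123456"; return the job id.
    return out.stdout.strip().split()[-1]

# job_id = submit_array("workspace/slurm_logs/submit_array.sh", num_tasks=1234)
```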
### Workflow

```
Step 1: database analysis
   ↓
Step 1.5: supercell expansion + valence assignment
   ↓
Step 2-4: Zeo++ Voronoi analysis
   ├── create symlinks (yaml configs + compute script)
   ├── submit the SLURM job array
   └── monitor progress and collect results
   ↓
Step 5: result processing and filtering
   ├── extract key parameters from log.txt
   ├── aggregate into CSV files
   ├── apply the filter criteria
   └── copy passing structures into the passed/ directory
```

### Filter criteria

Step 5 supports the following filter criteria (a usage sketch follows the list):
- **Minimum percolation diameter**: default 1.0 Å
- **Minimum d value**: default 2.0
- **Maximum node length**: default unlimited
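These criteria map directly onto the `FilterCriteria` dataclass that `main.py` builds; for example, to run the filtering non-interactively with the defaults above:

```python
from src.computation.result_processor import ResultProcessor, FilterCriteria

processor = ResultProcessor(workspace_path="workspace")
criteria = FilterCriteria(
    min_percolation_diameter=1.0,   # Å
    min_d_value=2.0,
    max_node_length=float("inf"),   # no limit
)
results, stats = processor.process_and_filter(
    criteria=criteria, save_csv=True, copy_passed=True
)
processor.print_summary(results, stats)
```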
### Log output

The output of each Zeo++ run is redirected into a `log.txt` in that structure's directory:
```bash
python analyze_voronoi_nodes.py *.cif -i O.yaml > log.txt 2>&1
```

Key lines in the log (a parsing sketch follows):
- `Percolation diameter (A): X.XX` - the percolation diameter
- `the minium of d\nX.XX` - the minimum d value
- `Maximum node length detected: X.XX A` - the maximum node length
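A minimal sketch of extracting these three values with regular expressions (the patterns assume the exact line formats listed above; the project's own Step 5 extraction lives in `ResultProcessor`):

```python
import re

def parse_zeo_log(text: str) -> dict:
    """Pull the three screening values out of a Zeo++ log.txt."""
    patterns = {
        # "Percolation diameter (A): X.XX"
        "percolation_diameter": r"Percolation diameter \(A\):\s*([\d.]+)",
        # "the minium of d" (sic) with the value on the following line
        "min_d": r"the minium of d\s*\n\s*([\d.]+)",
        # "Maximum node length detected: X.XX A"
        "max_node_length": r"Maximum node length detected:\s*([\d.]+)\s*A",
    }
    values = {}
    for key, pat in patterns.items():
        m = re.search(pat, text)
        values[key] = float(m.group(1)) if m else None
    return values

# with open("workspace/data/O/1514027/log.txt") as f:
#     print(parse_zeo_log(f.read()))
```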
### Directory layout

```
workspace/
├── data/                                   # Analysis-format data
│   ├── O/                                  # Oxides
│   │   ├── O.yaml -> tool/Li/O.yaml
│   │   ├── analyze_voronoi_nodes.py -> tool/analyze_voronoi_nodes.py
│   │   ├── 1514027/
│   │   │   ├── 1514027.cif
│   │   │   ├── 1514027_all_accessed_node.cif    # Zeo++ output
│   │   │   ├── 1514027_bond_valence_filtered.cif
│   │   │   └── 1514027_bv_info.csv
│   │   └── ...
│   ├── S/                                  # Sulfides
│   └── Cl+O/                               # Mixed anions
├── processed/                              # Flat-format data
├── slurm_logs/                             # SLURM logs
│   ├── tasks.json
│   ├── submit_array.sh
│   ├── task_0.out
│   ├── task_0.err
│   ├── status_0.txt
│   └── ...
└── zeo_results.json                        # Aggregated results
```

### Configuration files

Example `tool/Li/O.yaml`:
```yaml
SPECIE: Li+
ANION: O
PERCO_R: 0.5
NEIGHBOR: 1.8
LONG: 2.2
```

Parameter reference (a loading sketch follows the list):
- `SPECIE`: the target diffusing ion (with charge)
- `ANION`: the anion type
- `PERCO_R`: percolation radius threshold
- `NEIGHBOR`: neighbor distance threshold
- `LONG`: long-node detection threshold
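For reference, a config like this can be read back with a standard YAML parser (PyYAML here is an assumption; the project may use a different loader):

```python
import yaml  # assumption: any YAML parser works; PyYAML shown for illustration

with open("tool/Li/O.yaml") as f:
    cfg = yaml.safe_load(f)

# cfg == {'SPECIE': 'Li+', 'ANION': 'O', 'PERCO_R': 0.5, 'NEIGHBOR': 1.8, 'LONG': 2.2}
print(cfg["SPECIE"], cfg["PERCO_R"])
```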
## New Features (v2.3)

### Checkpoint/resume

v2.3 adds smart checkpoint/resume, so execution can continue from any step:

1. **Automatic status detection**
   - Detects the workflow status at startup
   - Checks whether `workspace/data/` exists and contains structures → has Step 1 finished?
   - Checks whether structure directories contain a `log.txt` → have the Zeo++ runs finished?
   - If more than 50% of structures have a log.txt, the Zeo++ stage is treated as complete

2. **Smart step skipping**
   - If the Zeo++ runs are done → filtering can start immediately
   - If supercell processing is done → the Zeo++ runs can start immediately
   - Detection works back-to-front and skips completed steps automatically

3. **Stepwise execution and interruption**
   - After each major step the program asks whether to continue
   - You can quit midway; progress is re-detected on the next run
   - The three major steps: supercell + valences → Zeo++ runs → filtering

### Example workflow status

```
Workflow status detection

Existing workspace data detected:
- Total structures: 1234
- Zeo++ runs finished: 1200 (97.2%)
- Zeo++ runs missing: 34

Current status: Zeo++ computation complete

Available actions:
[1] Go straight to result filtering (Step 5)
[2] Re-run the Zeo++ computation (Step 2-4)
[3] Start from scratch (Step 1)
[0] Quit

Select [1]:
```

### Use cases

1. **First run**
   - Executes the full pipeline from Step 1
   - After each step you may continue or quit

2. **Resuming after an interruption**
   - Completed steps are detected automatically
   - Offers to continue from the current progress

3. **Re-filtering**
   - After changing the filter criteria
   - Step 5 can run directly without recomputation

4. **Partial recomputation**
   - If some structures need to be recomputed
   - Choose to re-run the Zeo++ computation
src/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
"""
High-throughput screening & supercell expansion project - source package
"""

from . import analysis
from . import core
from . import preprocessing
from . import computation
from . import utils

__version__ = "2.2.0"
__all__ = ['analysis', 'core', 'preprocessing', 'computation', 'utils']
src/analysis/__init__.py (new file, empty)
src/analysis/database_analyzer.py (new file, 468 lines)
@@ -0,0 +1,468 @@
"""
Database analyzer with support for high-performance parallel analysis
"""
import os
import pickle
import json
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Set, Optional
from pathlib import Path

from .structure_inspector import StructureInspector, StructureInfo
from .worker import analyze_single_file
from ..core.scheduler import ParallelScheduler, ResourceConfig


# Missing fields were added to the DatabaseReport class below

@dataclass
class DatabaseReport:
    """Database analysis report"""

    # Basic statistics
    database_path: str = ""
    total_files: int = 0
    valid_files: int = 0
    invalid_files: int = 0

    # Target element settings
    target_cation: str = ""
    target_anions: Set[str] = field(default_factory=set)
    anion_mode: str = ""

    # Statistics for structures containing the target cation
    cation_containing_count: int = 0
    cation_containing_ratio: float = 0.0

    # Anion distribution
    anion_distribution: Dict[str, int] = field(default_factory=dict)
    anion_ratios: Dict[str, float] = field(default_factory=dict)
    single_anion_count: int = 0
    mixed_anion_count: int = 0

    # Data quality statistics
    with_oxidation_states: int = 0
    without_oxidation_states: int = 0
    needs_expansion_count: int = 0
    cation_with_vacancy_count: int = 0        # target cation co-occupying with vacancies (new)
    cation_with_other_cation_count: int = 0   # target cation co-occupying with other cations (new)
    anion_partial_occupancy_count: int = 0
    binary_compound_count: int = 0
    has_water_count: int = 0
    has_radioactive_count: int = 0

    # Processability statistics
    directly_processable: int = 0
    needs_preprocessing: int = 0
    cannot_process: int = 0

    # Details
    all_structures: List[StructureInfo] = field(default_factory=list)
    skip_reasons_summary: Dict[str, int] = field(default_factory=dict)

    # Supercell statistics (new)
    expansion_stats: Dict[str, int] = field(default_factory=lambda: {
        'no_expansion_needed': 0,
        'expansion_factor_2': 0,
        'expansion_factor_3': 0,
        'expansion_factor_4_8': 0,
        'expansion_factor_large': 0,
        'cannot_expand': 0,
    })
    expansion_factor_distribution: Dict[int, int] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dict."""
        from dataclasses import fields as dataclass_fields

        def convert_value(val):
            """Recursively convert a value into serializable types."""
            if isinstance(val, set):
                return list(val)
            elif isinstance(val, dict):
                return {k: convert_value(v) for k, v in val.items()}
            elif isinstance(val, list):
                return [convert_value(item) for item in val]
            elif hasattr(val, '__dataclass_fields__'):
                # Handle dataclass objects
                return {k: convert_value(v) for k, v in asdict(val).items()}
            else:
                return val

        result = {}
        for f in dataclass_fields(self):
            value = getattr(self, f.name)
            result[f.name] = convert_value(value)

        return result

    def save(self, path: str):
        """Save the report to a JSON file."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
        print(f"✅ Report saved to: {path}")

    @classmethod
    def load(cls, path: str) -> 'DatabaseReport':
        """Load a report from a JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            d = json.load(f)

        # Restore set-typed fields
        if 'target_anions' in d:
            d['target_anions'] = set(d['target_anions'])
        # Handle the StructureInfo list (simplified: full objects are not restored)
        if 'all_structures' in d:
            d['all_structures'] = []

        return cls(**d)

    def get_processable_files(self, include_needs_expansion: bool = True) -> List[StructureInfo]:
        """
        Get the list of processable files.

        Args:
            include_needs_expansion: whether to include files that need supercell expansion

        Returns:
            A list of processable StructureInfo objects
        """
        result = []
        for info in self.all_structures:
            if info is None or not info.is_valid:
                continue
            if not info.contains_target_cation:
                continue
            if not info.can_process:
                continue
            if not include_needs_expansion and info.needs_expansion:
                continue
            result.append(info)
        return result

    def copy_processable_files(
        self,
        output_dir: str,
        include_needs_expansion: bool = True,
        organize_by_anion: bool = True
    ) -> Dict[str, int]:
        """
        Copy the processable CIF files into the workspace.

        Args:
            output_dir: output directory (e.g. workspace/data)
            include_needs_expansion: whether to include files that need supercell expansion
            organize_by_anion: whether to group subdirectories by anion type

        Returns:
            Copy statistics {category: count}
        """
        import shutil

        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Get the processable files
        processable = self.get_processable_files(include_needs_expansion)

        stats = {
            'direct': 0,           # directly processable
            'needs_expansion': 0,  # needs supercell expansion
            'total': 0
        }

        # Create per-type subdirectories
        if organize_by_anion:
            anion_dirs = {}

        for info in processable:
            # Determine the target directory
            if organize_by_anion and info.anion_types:
                # Use the main anions as the directory name
                anion_key = '+'.join(sorted(info.anion_types))
                if anion_key not in anion_dirs:
                    anion_dir = os.path.join(output_dir, anion_key)
                    os.makedirs(anion_dir, exist_ok=True)
                    anion_dirs[anion_key] = anion_dir
                target_dir = anion_dirs[anion_key]
            else:
                target_dir = output_dir

            # Further split by processing type
            if info.needs_expansion:
                sub_dir = os.path.join(target_dir, 'needs_expansion')
                stats['needs_expansion'] += 1
            else:
                sub_dir = os.path.join(target_dir, 'direct')
                stats['direct'] += 1

            os.makedirs(sub_dir, exist_ok=True)

            # Copy the file
            src_path = info.file_path
            dst_path = os.path.join(sub_dir, info.file_name)

            try:
                shutil.copy2(src_path, dst_path)
                stats['total'] += 1
            except Exception as e:
                print(f"⚠️ Copy failed {info.file_name}: {e}")

        # Print the statistics
        print(f"\n📁 Files copied to: {output_dir}")
        print(f"   Directly processable: {stats['direct']}")
        print(f"   Needs expansion: {stats['needs_expansion']}")
        print(f"   Total: {stats['total']}")

        return stats


class DatabaseAnalyzer:
    """Database analyzer with high-performance parallelism."""

    def __init__(
        self,
        database_path: str,
        target_cation: str = "Li",
        target_anions: Set[str] = None,
        anion_mode: str = "all",
        max_cores: int = 4,
        task_complexity: str = "medium"
    ):
        """
        Initialize the analyzer.

        Args:
            database_path: path to the database
            target_cation: target cation
            target_anions: set of target anions
            anion_mode: anion mode
            max_cores: maximum number of cores to use
            task_complexity: task complexity ('low', 'medium', 'high')
        """
        self.database_path = database_path
        self.target_cation = target_cation
        self.target_anions = target_anions or {'O', 'S', 'Cl', 'Br'}
        self.anion_mode = anion_mode
        self.max_cores = max_cores
        self.task_complexity = task_complexity

        # Collect the file list
        self.cif_files = self._get_cif_files()

        # Configure the scheduler
        self.resource_config = ParallelScheduler.recommend_config(
            num_tasks=len(self.cif_files),
            task_complexity=task_complexity,
            max_cores=max_cores
        )
        self.scheduler = ParallelScheduler(self.resource_config)

    def _get_cif_files(self) -> List[str]:
        """Collect all CIF file paths."""
        cif_files = []

        if os.path.isfile(self.database_path):
            if self.database_path.endswith('.cif'):
                cif_files.append(self.database_path)
        else:
            for root, dirs, files in os.walk(self.database_path):
                for f in files:
                    if f.endswith('.cif'):
                        cif_files.append(os.path.join(root, f))

        return sorted(cif_files)

    def analyze(self, show_progress: bool = True) -> DatabaseReport:
        """
        Run the parallel analysis.

        Args:
            show_progress: whether to show progress

        Returns:
            DatabaseReport: the analysis report
        """
        report = DatabaseReport(
            database_path=self.database_path,
            target_cation=self.target_cation,
            target_anions=self.target_anions,
            anion_mode=self.anion_mode,
            total_files=len(self.cif_files)
        )

        if report.total_files == 0:
            print(f"⚠️ Warning: no CIF files found in {self.database_path}")
            return report

        # Prepare the tasks
        tasks = [
            (f, self.target_cation, self.target_anions)
            for f in self.cif_files
        ]

        # Run the parallel analysis
        results = self.scheduler.run_local(
            tasks=tasks,
            worker_func=analyze_single_file,
            desc="Analyzing CIF files"
        )

        # Keep the valid results
        report.all_structures = [r for r in results if r is not None]

        # Compute the statistics
        self._compute_statistics(report)

        return report

    def analyze_slurm(
        self,
        output_dir: str,
        job_name: str = "cif_analysis"
    ) -> str:
        """
        Submit a SLURM job for the analysis.

        Args:
            output_dir: output directory
            job_name: job name

        Returns:
            The job ID
        """
        os.makedirs(output_dir, exist_ok=True)

        # Save the task configuration
        tasks_file = os.path.join(output_dir, "tasks.json")
        with open(tasks_file, 'w') as f:
            json.dump({
                'files': self.cif_files,
                'target_cation': self.target_cation,
                'target_anions': list(self.target_anions),
                'anion_mode': self.anion_mode
            }, f)

        # Generate the SLURM script
        worker_script = os.path.join(
            os.path.dirname(__file__), 'worker.py'
        )
        script = self.scheduler.generate_slurm_script(
            tasks_file=tasks_file,
            worker_script=worker_script,
            output_dir=output_dir,
            job_name=job_name
        )

        # Save and submit
        script_path = os.path.join(output_dir, "submit.sh")
        return self.scheduler.submit_slurm_job(script, script_path)

    # Updated _compute_statistics method

    def _compute_statistics(self, report: DatabaseReport):
        """Compute aggregate statistics."""

        for info in report.all_structures:
            # Guard against None entries
            if info is None:
                report.invalid_files += 1
                continue

            # Validity check
            if info.is_valid:
                report.valid_files += 1
            else:
                report.invalid_files += 1
                continue  # do not keep counting invalid files

            # Key fix: only count structures that actually contain the target cation
            if not info.contains_target_cation:
                continue  # do not keep counting files without the target cation

            report.cation_containing_count += 1

            for anion in info.anion_types:
                report.anion_distribution[anion] = \
                    report.anion_distribution.get(anion, 0) + 1

            if info.anion_mode == "single":
                report.single_anion_count += 1
            elif info.anion_mode == "mixed":
                report.mixed_anion_count += 1

            # Filter by anion mode
            if self.anion_mode == "single" and info.anion_mode != "single":
                continue
            if self.anion_mode == "mixed" and info.anion_mode != "mixed":
                continue
            if info.anion_mode == "none":
                continue

            # Per-flag statistics
            if info.has_oxidation_states:
                report.with_oxidation_states += 1
            else:
                report.without_oxidation_states += 1

            # Target-cation co-occupancy statistics (revised)
            if info.cation_with_vacancy:
                report.cation_with_vacancy_count += 1
            if info.cation_with_other_cation:
                report.cation_with_other_cation_count += 1

            if info.anion_has_partial_occupancy:
                report.anion_partial_occupancy_count += 1
            if info.is_binary_compound:
                report.binary_compound_count += 1
            if info.has_water_molecule:
                report.has_water_count += 1
            if info.has_radioactive_elements:
                report.has_radioactive_count += 1

            # Processability
            if info.can_process:
                if info.needs_expansion:
                    report.needs_preprocessing += 1
                else:
                    report.directly_processable += 1
            else:
                report.cannot_process += 1
                if info.skip_reason:
                    for reason in info.skip_reason.split("; "):
                        report.skip_reasons_summary[reason] = \
                            report.skip_reasons_summary.get(reason, 0) + 1

            # Supercell statistics (new)
            exp_info = info.expansion_info
            factor = exp_info.expansion_factor

            if not exp_info.needs_expansion:
                report.expansion_stats['no_expansion_needed'] += 1
            elif not exp_info.can_expand:
                report.expansion_stats['cannot_expand'] += 1
            elif factor == 2:
                report.expansion_stats['expansion_factor_2'] += 1
            elif factor == 3:
                report.expansion_stats['expansion_factor_3'] += 1
            elif 4 <= factor <= 8:
                report.expansion_stats['expansion_factor_4_8'] += 1
            else:
                report.expansion_stats['expansion_factor_large'] += 1

            # Detailed distribution
            if exp_info.needs_expansion and exp_info.can_expand:
                report.expansion_factor_distribution[factor] = \
                    report.expansion_factor_distribution.get(factor, 0) + 1
                report.needs_expansion_count += 1

        # Compute ratios
        if report.valid_files > 0:
            report.cation_containing_ratio = \
                report.cation_containing_count / report.valid_files

        if report.cation_containing_count > 0:
            for anion, count in report.anion_distribution.items():
                report.anion_ratios[anion] = \
                    count / report.cation_containing_count
src/analysis/report_generator.py (new file, 159 lines)
@@ -0,0 +1,159 @@
"""
Report generator: produces formatted analysis reports
"""
from typing import Optional
from .database_analyzer import DatabaseReport


class ReportGenerator:
    """Report generator"""

    @staticmethod
    def print_report(report: DatabaseReport, detailed: bool = False):
        """Print the analysis report."""

        print("\n" + "=" * 70)
        print("  Database Analysis Report")
        print("=" * 70)

        # Basic info
        print(f"\n📁 Database path: {report.database_path}")
        print(f"🎯 Target cation: {report.target_cation}")
        print(f"🎯 Target anions: {', '.join(sorted(report.target_anions))}")
        print(f"🎯 Anion mode: {report.anion_mode}")

        # Basic statistics
        print("\n" + "-" * 70)
        print("[1. Basic Statistics]")
        print("-" * 70)
        print(f"  Total CIF files: {report.total_files}")
        print(f"  Valid files: {report.valid_files}")
        print(f"  Invalid files: {report.invalid_files}")
        print(f"  Compounds containing {report.target_cation}: {report.cation_containing_count}")
        print(f"  Fraction containing {report.target_cation}: {report.cation_containing_ratio:.1%}")

        # Anion distribution
        print("\n" + "-" * 70)
        print(f"[2. Anion Distribution] (among compounds containing {report.target_cation})")
        print("-" * 70)

        if report.anion_distribution:
            for anion in sorted(report.anion_distribution.keys()):
                count = report.anion_distribution[anion]
                ratio = report.anion_ratios.get(anion, 0)
                bar = "█" * int(ratio * 30)
                print(f"  {anion:5s}: {count:6d} ({ratio:6.1%}) {bar}")

        print(f"\n  Single-anion compounds: {report.single_anion_count}")
        print(f"  Mixed-anion compounds: {report.mixed_anion_count}")

        # Data quality
        print("\n" + "-" * 70)
        print("[3. Data Quality Checks]")
        print("-" * 70)

        total_target = report.cation_containing_count
        if total_target > 0:
            print(f"  With valence info: {report.with_oxidation_states:6d} "
                  f"({report.with_oxidation_states / total_target:.1%})")
            print(f"  Missing valence info: {report.without_oxidation_states:6d} "
                  f"({report.without_oxidation_states / total_target:.1%})")
            print()
            print(f"  {report.target_cation} co-occupying with vacancies (no action needed): {report.cation_with_vacancy_count:6d}")
            print(f"  {report.target_cation} co-occupying with other cations (needs expansion): {report.cation_with_other_cation_count:6d}")
            print(f"  Anion co-occupancy: {report.anion_partial_occupancy_count:6d}")
            print(f"  Needs supercell expansion (total): {report.needs_expansion_count:6d}")
            print()
            print(f"  Binary compounds: {report.binary_compound_count:6d}")
            print(f"  Contains water: {report.has_water_count:6d}")
            print(f"  Contains radioactive elements: {report.has_radioactive_count:6d}")

        # Processability assessment
        print("\n" + "-" * 70)
        print("[4. Processability Assessment]")
        print("-" * 70)

        total_processable = report.directly_processable + report.needs_preprocessing
        print(f"  ✅ Directly processable: {report.directly_processable:6d}")
        print(f"  ⚠️ Needs preprocessing (expansion): {report.needs_preprocessing:6d}")
        print(f"  ❌ Cannot process: {report.cannot_process:6d}")
        print(f"  ─────────────────────────────")
        print(f"  📊 Total processable: {total_processable:6d}")

        # Skip reason summary
        if report.skip_reasons_summary:
            print("\n" + "-" * 70)
            print("[5. Skip Reason Summary]")
            print("-" * 70)
            sorted_reasons = sorted(
                report.skip_reasons_summary.items(),
                key=lambda x: x[1],
                reverse=True
            )
            for reason, count in sorted_reasons:
                print(f"  {reason:35s}: {count:6d}")

        # Supercell analysis
        print("\n" + "-" * 70)
        print("[6. Supercell Requirement Analysis]")
        print("-" * 70)

        exp = report.expansion_stats
        if total_processable > 0:
            print(f"  No expansion needed: {exp['no_expansion_needed']:6d}")
            print(f"  Expansion factor = 2: {exp['expansion_factor_2']:6d}")
            print(f"  Expansion factor = 3: {exp['expansion_factor_3']:6d}")
            print(f"  Expansion factor = 4-8: {exp['expansion_factor_4_8']:6d}")
            print(f"  Expansion factor > 8: {exp['expansion_factor_large']:6d}")
            print(f"  Cannot expand (factor too large): {exp['cannot_expand']:6d}")

            # Detailed distribution
            if detailed and report.expansion_factor_distribution:
                print("\n  Expansion factor distribution:")
                for factor in sorted(report.expansion_factor_distribution.keys()):
                    count = report.expansion_factor_distribution[factor]
                    bar = "█" * min(count, 30)
                    print(f"    {factor:3d}x: {count:5d} {bar}")

        print("\n" + "=" * 70)

    @staticmethod
    def export_to_csv(report: DatabaseReport, output_path: str):
        """Export detailed results to CSV."""
        import csv

        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)

            # Write the header row
            headers = [
                'file_name', 'is_valid', 'contains_target_cation',
                'anion_types', 'anion_mode', 'has_oxidation_states',
                'has_partial_occupancy', 'cation_partial_occupancy',
                'anion_partial_occupancy', 'needs_expansion',
                'is_binary', 'has_water', 'has_radioactive',
                'can_process', 'skip_reason'
            ]
            writer.writerow(headers)

            # Write the data rows
            for info in report.all_structures:
                row = [
                    info.file_name,
                    info.is_valid,
                    info.contains_target_cation,
                    '+'.join(sorted(info.anion_types)) if info.anion_types else '',
                    info.anion_mode,
                    info.has_oxidation_states,
                    info.has_partial_occupancy,
                    info.cation_with_other_cation,  # fix: use the correct attribute name
                    info.anion_has_partial_occupancy,
                    info.needs_expansion,
                    info.is_binary_compound,
                    info.has_water_molecule,
                    info.has_radioactive_elements,
                    info.can_process,
                    info.skip_reason
                ]
                writer.writerow(row)

        print(f"Detailed results exported to: {output_path}")
src/analysis/structure_inspector.py (new file, 443 lines)
@@ -0,0 +1,443 @@
"""
Structure inspector: deep analysis of a single CIF file (including supercell-need detection)
"""
from dataclasses import dataclass, field
from typing import Set, Dict, List, Optional, Tuple
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element, Specie
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import re
import os


@dataclass
class OccupancyInfo:
    """Co-occupancy information"""
    occupation: float                                       # occupancy
    atom_serials: List[int] = field(default_factory=list)   # atom serial numbers
    elements: List[str] = field(default_factory=list)       # elements involved
    numerator: int = 0                                      # numerator
    denominator: int = 1                                    # denominator
    involves_target_cation: bool = False                    # whether the target cation is involved
    involves_anion: bool = False                            # whether an anion is involved


@dataclass
class ExpansionInfo:
    """Supercell expansion information"""
    needs_expansion: bool = False   # whether expansion is needed
    expansion_factor: int = 1       # expansion factor (least common multiple)
    occupancy_details: List[OccupancyInfo] = field(default_factory=list)  # co-occupancy details
    problematic_sites: int = 0      # number of problematic sites
    can_expand: bool = True         # whether expansion is feasible
    skip_reason: str = ""           # reason expansion is not possible


@dataclass
class StructureInfo:
    """Analysis result for a single structure"""
    file_path: str
    file_name: str

    # Basic info
    is_valid: bool = False
    error_message: str = ""

    # Elemental composition
    elements: Set[str] = field(default_factory=set)
    num_sites: int = 0
    formula: str = ""

    # Cation/anion info
    contains_target_cation: bool = False
    anion_types: Set[str] = field(default_factory=set)
    anion_mode: str = ""  # "single", "mixed", "none"

    # Data quality flags
    has_oxidation_states: bool = False
    has_partial_occupancy: bool = False  # whether any site is co-occupied
    has_water_molecule: bool = False
    has_radioactive_elements: bool = False
    is_binary_compound: bool = False

    # Detailed co-occupancy analysis (new)
    cation_with_vacancy: bool = False          # target cation co-occupying with vacancies (no action needed)
    cation_with_other_cation: bool = False     # target cation co-occupying with other cations (needs expansion)
    anion_has_partial_occupancy: bool = False  # anion co-occupancy
    other_has_partial_occupancy: bool = False  # co-occupancy of other elements (needs expansion)
    expansion_info: ExpansionInfo = field(default_factory=ExpansionInfo)

    # Processability
    needs_expansion: bool = False
    can_process: bool = False
    skip_reason: str = ""


class StructureInspector:
    """Structure inspector (with supercell analysis)"""

    # Predefined anion set
    VALID_ANIONS = {'O', 'S', 'Cl', 'Br'}

    # Radioactive elements
    RADIOACTIVE_ELEMENTS = {
        'U', 'Th', 'Pu', 'Ra', 'Rn', 'Po', 'Np', 'Am',
        'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'
    }

    # Expansion precision modes
    PRECISION_LIMITS = {
        'high': None,    # exact fractions
        'normal': 100,   # denominator ≤ 100
        'low': 10,       # denominator ≤ 10
        'very_low': 5    # denominator ≤ 5
    }

    def __init__(
        self,
        target_cation: str = "Li",
        target_anions: Set[str] = None,
        expansion_precision: str = "low"
    ):
        """
        Initialize the inspector.

        Args:
            target_cation: target cation (e.g. "Li", "Na")
            target_anions: set of target anions (e.g. {"O", "S"})
            expansion_precision: expansion computation precision ('high', 'normal', 'low', 'very_low')
        """
        self.target_cation = target_cation
        self.target_anions = target_anions or self.VALID_ANIONS
        self.expansion_precision = expansion_precision

        # Possible textual representations of the target cation
        self.target_cation_variants = {
            target_cation,
            f"{target_cation}+",
            f"{target_cation}1+",
        }

    def inspect(self, file_path: str) -> StructureInfo:
        """Analyze a single CIF file."""
        info = StructureInfo(
            file_path=file_path,
            file_name=os.path.basename(file_path)
        )

        # Try to read the structure
        try:
            structure = Structure.from_file(file_path)
        except Exception as e:
            info.is_valid = False
            info.error_message = f"Failed to read CIF: {str(e)}"
            return info

        info.is_valid = True

        try:
            # ===== Key fix: extract element symbols correctly =====
            # structure.composition.elements returns a list of Element objects;
            # the string symbol comes from the .symbol attribute
            element_symbols = set()
            for el in structure.composition.elements:
                # el is an Element object; el.symbol is a string like "Li"
                element_symbols.add(el.symbol)

            info.elements = element_symbols
            info.num_sites = structure.num_sites
            info.formula = structure.composition.reduced_formula

            # Binary compound check
            info.is_binary_compound = len(element_symbols) == 2

            # ===== Key fix: compare plain strings =====
            info.contains_target_cation = self.target_cation in element_symbols

            # Determine the anion types
            info.anion_types = element_symbols.intersection(self.target_anions)
            if len(info.anion_types) == 0:
                info.anion_mode = "none"
            elif len(info.anion_types) == 1:
                info.anion_mode = "single"
            else:
                info.anion_mode = "mixed"

            # Oxidation state check
            info.has_oxidation_states = self._check_oxidation_states(structure)

            # Co-occupancy analysis (the core check)
            try:
                self._analyze_partial_occupancy(structure, info)
            except Exception as e:
                # Co-occupancy analysis failed; note it and continue
                pass

            # Water molecule check
            try:
                info.has_water_molecule = self._check_water_molecule(structure)
            except Exception:
                info.has_water_molecule = False

            # Radioactive element check
            info.has_radioactive_elements = bool(
                info.elements.intersection(self.RADIOACTIVE_ELEMENTS)
            )

            # Evaluate processability
            self._evaluate_processability(info)

        except Exception as e:
            # The analysis failed, but the file itself is valid:
            # keep is_valid = True and record the error
            info.error_message = f"Analysis error: {str(e)}"

        return info

    def _check_oxidation_states(self, structure: Structure) -> bool:
        """Check whether the structure carries oxidation-state information."""
        try:
            for site in structure.sites:
                for specie in site.species.keys():
                    if isinstance(specie, Specie):
                        return True
            return False
        except Exception:
            return False

    def _get_element_from_species_string(self, species_str: str) -> str:
        """Extract the bare element symbol from a species string."""
        match = re.match(r'([A-Z][a-z]?)', species_str)
        return match.group(1) if match else ""

    def _get_occupancy_from_species_string(self, species_str: str, exclude_elements: Set[str]) -> Optional[float]:
        """
        Get the occupancy of a non-excluded element from a species string.
        Format example: "Li+:0.689, Sc3+:0.311"
        """
        if ':' not in species_str:
            return None

        parts = [p.strip() for p in species_str.split(',')]
        for part in parts:
            if ':' in part:
                element_part, occu_part = part.split(':')
                element = self._get_element_from_species_string(element_part.strip())
                if element and element not in exclude_elements:
                    try:
                        return float(occu_part.strip())
                    except ValueError:
                        continue
        return None

    # Replacement for the _analyze_partial_occupancy method of StructureInspector

    def _analyze_partial_occupancy(self, structure: Structure, info: StructureInfo):
        """
        Analyze co-occupancy (revised version).

        Key rules:
        - Target cation co-occupying with a vacancy → no action needed (cation_with_vacancy)
        - Target cation co-occupying with another cation → needs expansion (cation_with_other_cation
|
||||
- 阴离子共占位 → 通常不处理
|
||||
- 其他阳离子共占位 → 需要扩胞
|
||||
"""
|
||||
occupancy_dict = defaultdict(list) # {occupation: [site_indices]}
|
||||
occupancy_elements = {} # {occupation: [elements]}
|
||||
|
||||
for i, site in enumerate(structure.sites):
|
||||
site_species = site.species
|
||||
species_string = str(site.species)
|
||||
|
||||
# 提取各元素及其占据率
|
||||
species_occu = {} # {element: occupancy}
|
||||
for sp, occu in site_species.items():
|
||||
elem = sp.symbol if hasattr(sp, 'symbol') else str(sp)
|
||||
elem = self._get_element_from_species_string(elem)
|
||||
if elem:
|
||||
species_occu[elem] = occu
|
||||
|
||||
total_occupancy = sum(species_occu.values())
|
||||
elements_at_site = list(species_occu.keys())
|
||||
|
||||
# 检查是否有部分占据
|
||||
has_partial = any(occu < 1.0 for occu in species_occu.values()) or len(species_occu) > 1
|
||||
|
||||
if not has_partial:
|
||||
continue
|
||||
|
||||
info.has_partial_occupancy = True
|
||||
|
||||
# 判断Li的共占位情况
|
||||
if self.target_cation in elements_at_site:
|
||||
li_occu = species_occu.get(self.target_cation, 0)
|
||||
other_elements = [e for e in elements_at_site if e != self.target_cation]
|
||||
|
||||
if not other_elements and li_occu < 1.0:
|
||||
# Li与空位共占位(Li占据率<1,但没有其他元素)
|
||||
info.cation_with_vacancy = True
|
||||
elif other_elements:
|
||||
# Li与其他元素共占位
|
||||
other_are_anions = all(e in self.target_anions for e in other_elements)
|
||||
if other_are_anions:
|
||||
# Li与阴离子共占位(罕见,标记为阴离子共占位)
|
||||
info.anion_has_partial_occupancy = True
|
||||
else:
|
||||
# Li与其他阳离子共占位 → 需要扩胞
|
||||
info.cation_with_other_cation = True
|
||||
|
||||
# 记录需要扩胞的占据率(取非Li元素的占据率)
|
||||
for elem in other_elements:
|
||||
if elem not in self.target_anions:
|
||||
occu = species_occu.get(elem, 0)
|
||||
if occu > 0 and occu < 1.0:
|
||||
occupancy_dict[occu].append(i)
|
||||
occupancy_elements[occu] = elements_at_site
|
||||
else:
|
||||
# 不涉及Li的位点
|
||||
# 判断是否涉及阴离子
|
||||
if any(elem in self.target_anions for elem in elements_at_site):
|
||||
info.anion_has_partial_occupancy = True
|
||||
else:
|
||||
# 其他阳离子的共占位 → 需要扩胞
|
||||
info.other_has_partial_occupancy = True
|
||||
|
||||
# 获取占据率
|
||||
for elem, occu in species_occu.items():
|
||||
if occu > 0 and occu < 1.0:
|
||||
occupancy_dict[occu].append(i)
|
||||
occupancy_elements[occu] = elements_at_site
|
||||
break # 只记录一次
|
||||
|
||||
# 计算扩胞信息
|
||||
self._calculate_expansion_info(info, occupancy_dict, occupancy_elements)
|
||||
|
||||
def _evaluate_processability(self, info: StructureInfo):
|
||||
"""评估可处理性(修正版)"""
|
||||
skip_reasons = []
|
||||
|
||||
if not info.is_valid:
|
||||
skip_reasons.append("无法解析CIF文件")
|
||||
|
||||
if not info.contains_target_cation:
|
||||
skip_reasons.append(f"不含{self.target_cation}")
|
||||
|
||||
if info.anion_mode == "none":
|
||||
skip_reasons.append("不含目标阴离子")
|
||||
|
||||
if info.is_binary_compound:
|
||||
skip_reasons.append("二元化合物")
|
||||
|
||||
if info.has_radioactive_elements:
|
||||
skip_reasons.append("含放射性元素")
|
||||
|
||||
# Li与空位共占位 → 不需要处理(不加入skip_reasons)
|
||||
# info.cation_with_vacancy 不影响可处理性
|
||||
|
||||
# Li与其他阳离子共占位 → 需要扩胞(如果扩胞因子合理则可处理)
|
||||
if info.cation_with_other_cation:
|
||||
if info.expansion_info.can_expand:
|
||||
info.needs_expansion = True
|
||||
else:
|
||||
skip_reasons.append(f"{self.target_cation}与其他阳离子共占位且{info.expansion_info.skip_reason}")
|
||||
|
||||
# 阴离子共占位 → 不处理
|
||||
if info.anion_has_partial_occupancy:
|
||||
skip_reasons.append("阴离子存在共占位")
|
||||
|
||||
if info.has_water_molecule:
|
||||
skip_reasons.append("含水分子")
|
||||
|
||||
# 其他阳离子共占位(不涉及Li)→ 需要扩胞
|
||||
if info.other_has_partial_occupancy:
|
||||
if info.expansion_info.can_expand:
|
||||
info.needs_expansion = True
|
||||
else:
|
||||
skip_reasons.append(info.expansion_info.skip_reason)
|
||||
|
||||
if skip_reasons:
|
||||
info.can_process = False
|
||||
info.skip_reason = "; ".join(skip_reasons)
|
||||
else:
|
||||
info.can_process = True
|
||||
|
||||
def _calculate_expansion_info(
|
||||
self,
|
||||
info: StructureInfo,
|
||||
occupancy_dict: Dict[float, List[int]],
|
||||
occupancy_elements: Dict[float, List[str]]
|
||||
):
|
||||
"""计算扩胞相关信息"""
|
||||
expansion_info = ExpansionInfo()
|
||||
|
||||
if not occupancy_dict:
|
||||
info.expansion_info = expansion_info
|
||||
return
|
||||
|
||||
# 需要扩胞(有非目标阳离子的共占位)
|
||||
expansion_info.needs_expansion = True
|
||||
expansion_info.problematic_sites = sum(len(v) for v in occupancy_dict.values())
|
||||
|
||||
# 转换为OccupancyInfo列表
|
||||
occupancy_list = []
|
||||
for occu, serials in occupancy_dict.items():
|
||||
elements = occupancy_elements.get(occu, [])
|
||||
|
||||
# 根据精度计算分数
|
||||
limit = self.PRECISION_LIMITS.get(self.expansion_precision)
|
||||
if limit:
|
||||
fraction = Fraction(occu).limit_denominator(limit)
|
||||
else:
|
||||
fraction = Fraction(occu).limit_denominator()
|
||||
|
||||
occ_info = OccupancyInfo(
|
||||
occupation=occu,
|
||||
atom_serials=[s + 1 for s in serials], # 转为1-based
|
||||
elements=elements,
|
||||
numerator=fraction.numerator,
|
||||
denominator=fraction.denominator,
|
||||
involves_target_cation=self.target_cation in elements,
|
||||
involves_anion=any(e in self.target_anions for e in elements)
|
||||
)
|
||||
occupancy_list.append(occ_info)
|
||||
|
||||
expansion_info.occupancy_details = occupancy_list
|
||||
|
||||
# 计算最小公倍数(扩胞因子)
|
||||
denominators = [occ.denominator for occ in occupancy_list]
|
||||
if denominators:
|
||||
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
|
||||
expansion_info.expansion_factor = lcm
|
||||
|
||||
# 判断是否可以扩胞(因子过大则不可处理)
|
||||
if lcm > 64: # 扩胞超过64倍通常不可行
|
||||
expansion_info.can_expand = False
|
||||
expansion_info.skip_reason = f"扩胞因子过大({lcm})"
|
||||
|
||||
info.expansion_info = expansion_info
|
||||
info.needs_expansion = expansion_info.needs_expansion and expansion_info.can_expand
|
||||
|
||||
def _check_water_molecule(self, structure: Structure) -> bool:
|
||||
"""检查是否含有水分子"""
|
||||
try:
|
||||
oxygen_sites = []
|
||||
hydrogen_sites = []
|
||||
|
||||
for site in structure.sites:
|
||||
species_str = str(site.species)
|
||||
if 'O' in species_str:
|
||||
oxygen_sites.append(site)
|
||||
if 'H' in species_str:
|
||||
hydrogen_sites.append(site)
|
||||
|
||||
for o_site in oxygen_sites:
|
||||
nearby_h = [h for h in hydrogen_sites if o_site.distance(h) < 1.2]
|
||||
if len(nearby_h) >= 2:
|
||||
return True
|
||||
return False
|
||||
except:
|
||||
return False
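
A minimal usage sketch of the inspector above; the CIF path is hypothetical and the import root follows the src/ layout in this diff. It also illustrates the expansion-factor mechanics (limit_denominator plus LCM) that _calculate_expansion_info relies on:

# Minimal usage sketch; "some_structure.cif" is a hypothetical path.
from fractions import Fraction
from src.analysis.structure_inspector import StructureInspector

inspector = StructureInspector(target_cation="Li", target_anions={"O", "S"})
info = inspector.inspect("some_structure.cif")
print(info.formula, info.can_process, info.skip_reason)

# Occupancies are rationalized with limit_denominator, and the expansion
# factor is the LCM of the resulting denominators. For example, a Sc3+
# occupancy of 0.311 at 'low' precision (denominator <= 10):
print(Fraction(0.311).limit_denominator(10))  # -> 3/10, so a 10x supercell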
199
src/analysis/worker.py
Normal file
@@ -0,0 +1,199 @@
"""
Worker process: handles a single analysis task.
Designed to run standalone (for SLURM job arrays).
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict, fields

from .structure_inspector import StructureInspector, StructureInfo


def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
    """
    Analyze a single CIF file (worker function).

    Args:
        args: (file_path, target_cation, target_anions)

    Returns:
        StructureInfo, or a failure-marked StructureInfo if the analysis raised
    """
    file_path, target_cation, target_anions = args

    try:
        inspector = StructureInspector(
            target_cation=target_cation,
            target_anions=target_anions
        )
        result = inspector.inspect(file_path)
        return result
    except Exception as e:
        # Return a failure-marked result (rather than None)
        return StructureInfo(
            file_path=file_path,
            file_name=os.path.basename(file_path),
            is_valid=False,
            error_message=f"Worker exception: {str(e)}"
        )


def structure_info_to_dict(info: StructureInfo) -> dict:
    """
    Convert a StructureInfo into a serializable dict,
    handling sets, nested dataclasses, and similar types.
    """
    result = {}
    for fld in fields(info):
        value = getattr(info, fld.name)

        # sets
        if isinstance(value, set):
            result[fld.name] = list(value)
        # nested dataclasses (e.g. ExpansionInfo)
        elif hasattr(value, '__dataclass_fields__'):
            result[fld.name] = asdict(value)
        # lists that may contain dataclasses
        elif isinstance(value, list):
            result[fld.name] = [
                asdict(item) if hasattr(item, '__dataclass_fields__') else item
                for item in value
            ]
        else:
            result[fld.name] = value

    return result


def dict_to_structure_info(d: dict) -> StructureInfo:
    """
    Rebuild a StructureInfo object from a dict.
    """
    from .structure_inspector import ExpansionInfo, OccupancyInfo

    # set-valued fields
    if 'elements' in d and isinstance(d['elements'], list):
        d['elements'] = set(d['elements'])
    if 'anion_types' in d and isinstance(d['anion_types'], list):
        d['anion_types'] = set(d['anion_types'])
    if 'target_anions' in d and isinstance(d['target_anions'], list):
        d['target_anions'] = set(d['target_anions'])

    # ExpansionInfo
    if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
        exp_dict = d['expansion_info']

        # OccupancyInfo list
        if 'occupancy_details' in exp_dict:
            exp_dict['occupancy_details'] = [
                OccupancyInfo(**occ) if isinstance(occ, dict) else occ
                for occ in exp_dict['occupancy_details']
            ]

        d['expansion_info'] = ExpansionInfo(**exp_dict)

    return StructureInfo(**d)


def batch_analyze(
    file_paths: List[str],
    target_cation: str,
    target_anions: set,
    output_file: Optional[str] = None
) -> List[StructureInfo]:
    """
    Analyze files in batch (for SLURM subtasks).

    Args:
        file_paths: list of CIF file paths
        target_cation: target cation
        target_anions: target anion set
        output_file: output file path (pickle format)

    Returns:
        list of StructureInfo
    """
    results = []

    inspector = StructureInspector(
        target_cation=target_cation,
        target_anions=target_anions
    )

    for file_path in file_paths:
        try:
            info = inspector.inspect(file_path)
            results.append(info)
        except Exception as e:
            results.append(StructureInfo(
                file_path=file_path,
                file_name=os.path.basename(file_path),
                is_valid=False,
                error_message=str(e)
            ))

    # Save the results
    if output_file:
        serializable_results = [structure_info_to_dict(r) for r in results]
        with open(output_file, 'wb') as f:
            pickle.dump(serializable_results, f)

    return results


def load_results(result_file: str) -> List[StructureInfo]:
    """
    Load results from a pickle file.
    """
    with open(result_file, 'rb') as f:
        data = pickle.load(f)

    return [dict_to_structure_info(d) for d in data]


def merge_results(result_files: List[str]) -> List[StructureInfo]:
    """
    Merge multiple result files (to aggregate SLURM job-array output).
    """
    all_results = []
    for f in result_files:
        if os.path.exists(f):
            all_results.extend(load_results(f))
    return all_results


# Command-line entry point for SLURM job arrays
if __name__ == "__main__":
    import argparse
    import json

    parser = argparse.ArgumentParser(description="CIF Analysis Worker")
    parser.add_argument("--tasks-file", required=True, help="task file path (JSON)")
    parser.add_argument("--output-dir", required=True, help="output directory")
    parser.add_argument("--task-id", type=int, default=0, help="task ID (for array jobs)")
    parser.add_argument("--num-workers", type=int, default=1, help="number of parallel workers")

    args = parser.parse_args()

    # Load the tasks
    with open(args.tasks_file, 'r') as f:
        task_config = json.load(f)

    file_paths = task_config['files']
    target_cation = task_config['target_cation']
    target_anions = set(task_config['target_anions'])

    # For array jobs, only process the assigned slice
    if args.task_id >= 0:
        chunk_size = len(file_paths) // args.num_workers + 1
        start_idx = args.task_id * chunk_size
        end_idx = min(start_idx + chunk_size, len(file_paths))
        file_paths = file_paths[start_idx:end_idx]

    # Output file
    output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")

    # Run the analysis
    print(f"Worker {args.task_id}: processing {len(file_paths)} files")
    results = batch_analyze(file_paths, target_cation, target_anions, output_file)
    print(f"Worker {args.task_id}: done, results saved to {output_file}")
15
src/computation/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""
Computation module: Zeo++ Voronoi analysis.
"""
from .workspace_manager import WorkspaceManager
from .zeo_executor import ZeoExecutor, ZeoConfig
from .result_processor import ResultProcessor, FilterCriteria, StructureResult

__all__ = [
    'WorkspaceManager',
    'ZeoExecutor',
    'ZeoConfig',
    'ResultProcessor',
    'FilterCriteria',
    'StructureResult'
]
426
src/computation/result_processor.py
Normal file
@@ -0,0 +1,426 @@
"""
Zeo++ result processor: extract data and filter structures.
"""
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import pandas as pd


@dataclass
class FilterCriteria:
    """Filter criteria."""
    min_percolation_diameter: float = 1.0   # minimum percolation diameter (Å), default 1.0
    min_d_value: float = 2.0                # minimum d value, default 2.0
    max_node_length: float = float('inf')   # maximum node length (Å)


@dataclass
class StructureResult:
    """Computation result for one structure."""
    structure_name: str
    anion_type: str
    work_dir: str

    # Extracted parameters
    percolation_diameter: Optional[float] = None
    min_d: Optional[float] = None
    max_node_length: Optional[float] = None

    # Filter outcome
    passed_filter: bool = False
    filter_reason: str = ""


class ResultProcessor:
    """
    Zeo++ computation result processor.

    Responsibilities:
    1. Extract key parameters from each structure directory's log.txt
    2. Aggregate all results into a CSV file
    3. Filter structures against the criteria
    4. Copy structures that pass the filter into a new folder
    """

    def __init__(
        self,
        workspace_path: str = "workspace",
        data_dir: Optional[str] = None,
        output_dir: Optional[str] = None
    ):
        """
        Initialize the result processor.

        Args:
            workspace_path: workspace root directory
            data_dir: data directory (default workspace/data)
            output_dir: output directory (default workspace/results)
        """
        self.workspace_path = os.path.abspath(workspace_path)
        self.data_dir = data_dir or os.path.join(self.workspace_path, "data")
        self.output_dir = output_dir or os.path.join(self.workspace_path, "results")

    def extract_from_log(self, log_path: str) -> Tuple[Optional[float], Optional[float], Optional[float]]:
        """
        Extract the three key parameters from log.txt.

        Args:
            log_path: path to log.txt

        Returns:
            (percolation_diameter, min_d, max_node_length)
        """
        if not os.path.exists(log_path):
            return None, None, None

        try:
            with open(log_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception:
            return None, None, None

        # Regular expressions - kept consistent with py/extract_data.py

        # 1. Percolation diameter: "# Percolation diameter (A): 1.06"
        re_percolation = r"Percolation diameter \(A\):\s*([\d\.]+)"

        # 2. Minimum of d: "the minium of d\n3.862140561244235"
        # Note: "minium" is the literal (misspelled) output of the
        # Topological_Analysis library, so the pattern must match it as-is.
        re_min_d = r"the minium of d\s*\n\s*([\d\.]+)"

        # 3. Maximum node length: "# Maximum node length detected: 1.332 A"
        re_max_node = r"Maximum node length detected:\s*([\d\.]+)\s*A"

        # Extract the values
        match_perc = re.search(re_percolation, content)
        match_d = re.search(re_min_d, content)
        match_node = re.search(re_max_node, content)

        val_perc = float(match_perc.group(1)) if match_perc else None
        val_d = float(match_d.group(1)) if match_d else None
        val_node = float(match_node.group(1)) if match_node else None

        return val_perc, val_d, val_node

    def process_all_structures(self) -> List[StructureResult]:
        """
        Process all structures and extract their computation results.

        Returns:
            list of StructureResult
        """
        results = []

        if not os.path.exists(self.data_dir):
            print(f"⚠️ Data directory does not exist: {self.data_dir}")
            return results

        print("\nExtracting computation results...")

        # Walk the anion directories
        for anion_key in os.listdir(self.data_dir):
            anion_dir = os.path.join(self.data_dir, anion_key)
            if not os.path.isdir(anion_dir):
                continue

            # Walk the structure directories
            for struct_name in os.listdir(anion_dir):
                struct_dir = os.path.join(anion_dir, struct_name)
                if not os.path.isdir(struct_dir):
                    continue

                # Locate log.txt
                log_path = os.path.join(struct_dir, "log.txt")

                # Extract the parameters
                perc, min_d, max_node = self.extract_from_log(log_path)

                result = StructureResult(
                    structure_name=struct_name,
                    anion_type=anion_key,
                    work_dir=struct_dir,
                    percolation_diameter=perc,
                    min_d=min_d,
                    max_node_length=max_node
                )

                results.append(result)

        print(f"  Processed {len(results)} structures in total")
        return results

    def apply_filter(
        self,
        results: List[StructureResult],
        criteria: FilterCriteria
    ) -> List[StructureResult]:
        """
        Apply the filter criteria.

        Args:
            results: list of structure results
            criteria: filter criteria

        Returns:
            updated result list (with filter status filled in)
        """
        print("\nApplying filter criteria...")
        print(f"  Minimum percolation diameter: {criteria.min_percolation_diameter} Å")
        print(f"  Minimum d value: {criteria.min_d_value}")
        print(f"  Maximum node length: {criteria.max_node_length} Å")

        passed_count = 0

        for result in results:
            # Is valid data present?
            if result.percolation_diameter is None or result.min_d is None:
                result.passed_filter = False
                result.filter_reason = "missing data"
                continue

            # Percolation diameter
            if result.percolation_diameter < criteria.min_percolation_diameter:
                result.passed_filter = False
                result.filter_reason = f"percolation diameter {result.percolation_diameter:.3f} < {criteria.min_percolation_diameter}"
                continue

            # d value
            if result.min_d < criteria.min_d_value:
                result.passed_filter = False
                result.filter_reason = f"d value {result.min_d:.3f} < {criteria.min_d_value}"
                continue

            # Node length (if available)
            if result.max_node_length is not None:
                if result.max_node_length > criteria.max_node_length:
                    result.passed_filter = False
                    result.filter_reason = f"node length {result.max_node_length:.3f} > {criteria.max_node_length}"
                    continue

            # Passed all checks
            result.passed_filter = True
            result.filter_reason = "passed"
            passed_count += 1

        print(f"  Passed filter: {passed_count}/{len(results)}")
        return results

    def save_summary_csv(
        self,
        results: List[StructureResult],
        output_path: Optional[str] = None
    ) -> str:
        """
        Save the summary CSV file.

        Args:
            results: list of structure results
            output_path: output path (default workspace/results/summary.csv)

        Returns:
            CSV file path
        """
        if output_path is None:
            os.makedirs(self.output_dir, exist_ok=True)
            output_path = os.path.join(self.output_dir, "summary.csv")

        # Build the rows
        data = []
        for r in results:
            data.append({
                'Structure': r.structure_name,
                'Anion_Type': r.anion_type,
                'Percolation_Diameter_A': r.percolation_diameter,
                'Min_d': r.min_d,
                'Max_Node_Length_A': r.max_node_length,
                'Passed_Filter': r.passed_filter,
                'Filter_Reason': r.filter_reason
            })

        df = pd.DataFrame(data)

        # Sort by anion type and structure name
        df = df.sort_values(['Anion_Type', 'Structure'])

        # Save
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        df.to_csv(output_path, index=False)

        print(f"\nSummary CSV saved: {output_path}")
        return output_path

    def save_anion_csv(
        self,
        results: List[StructureResult],
        output_dir: Optional[str] = None
    ) -> List[str]:
        """
        Save one CSV file per anion type.

        Args:
            results: list of structure results
            output_dir: output directory

        Returns:
            list of generated CSV file paths
        """
        if output_dir is None:
            output_dir = self.output_dir

        # Group by anion type
        anion_groups: Dict[str, List[StructureResult]] = {}
        for r in results:
            if r.anion_type not in anion_groups:
                anion_groups[r.anion_type] = []
            anion_groups[r.anion_type].append(r)

        csv_files = []

        for anion_type, group_results in anion_groups.items():
            # Build the rows
            data = []
            for r in group_results:
                data.append({
                    'Structure': r.structure_name,
                    'Percolation_Diameter_A': r.percolation_diameter,
                    'Min_d': r.min_d,
                    'Max_Node_Length_A': r.max_node_length,
                    'Passed_Filter': r.passed_filter,
                    'Filter_Reason': r.filter_reason
                })

            df = pd.DataFrame(data)
            df = df.sort_values('Structure')

            # Save into the matching directory
            anion_output_dir = os.path.join(output_dir, anion_type)
            os.makedirs(anion_output_dir, exist_ok=True)

            csv_path = os.path.join(anion_output_dir, f"{anion_type}.csv")
            df.to_csv(csv_path, index=False)
            csv_files.append(csv_path)

            print(f"  {anion_type}: {len(group_results)} structures -> {csv_path}")

        return csv_files

    def copy_passed_structures(
        self,
        results: List[StructureResult],
        output_dir: Optional[str] = None
    ) -> int:
        """
        Copy structures that pass the filter into a new folder.

        Args:
            results: list of structure results
            output_dir: output directory (default workspace/passed)

        Returns:
            number of structures copied
        """
        if output_dir is None:
            output_dir = os.path.join(self.workspace_path, "passed")

        passed_results = [r for r in results if r.passed_filter]

        if not passed_results:
            print("\nNo structures passed the filter")
            return 0

        print(f"\nCopying {len(passed_results)} structures that passed the filter...")

        copied = 0
        for r in passed_results:
            # Destination: passed/<anion type>/<structure name>/
            dst_dir = os.path.join(output_dir, r.anion_type, r.structure_name)

            try:
                # Remove the destination first if it already exists
                if os.path.exists(dst_dir):
                    shutil.rmtree(dst_dir)

                # Copy the whole directory
                shutil.copytree(r.work_dir, dst_dir)
                copied += 1

            except Exception as e:
                print(f"  ⚠️ Copy failed {r.structure_name}: {e}")

        print(f"  Copied {copied} structures to: {output_dir}")
        return copied

    def process_and_filter(
        self,
        criteria: Optional[FilterCriteria] = None,
        save_csv: bool = True,
        copy_passed: bool = True
    ) -> Tuple[List[StructureResult], Dict]:
        """
        Full pipeline: extract data -> filter -> save CSVs -> copy passing structures.

        Args:
            criteria: filter criteria (if None, no filtering)
            save_csv: whether to save CSVs
            copy_passed: whether to copy structures that passed the filter

        Returns:
            (result list, statistics dict)
        """
        # 1. Extract every structure's computation results
        results = self.process_all_structures()

        if not results:
            return results, {'total': 0, 'passed': 0, 'failed': 0}

        # 2. Apply the filter
        if criteria is not None:
            results = self.apply_filter(results, criteria)

        # 3. Save CSVs
        if save_csv:
            print("\nSaving result CSVs...")
            self.save_summary_csv(results)
            self.save_anion_csv(results)

        # 4. Copy structures that passed
        if copy_passed and criteria is not None:
            self.copy_passed_structures(results)

        # Statistics
        stats = {
            'total': len(results),
            'passed': sum(1 for r in results if r.passed_filter),
            'failed': sum(1 for r in results if not r.passed_filter),
            'missing_data': sum(1 for r in results if r.filter_reason == "missing data")
        }

        return results, stats

    def print_summary(self, results: List[StructureResult], stats: Dict):
        """Print a result summary."""
        print("\n" + "=" * 60)
        print("[Computation Result Summary]")
        print("=" * 60)
        print(f"  Total structures: {stats['total']}")
        print(f"  Passed filter: {stats['passed']}")
        print(f"  Failed filter: {stats['failed']}")
        print(f"  Missing data: {stats.get('missing_data', 0)}")

        # Per-anion statistics
        anion_stats: Dict[str, Dict] = {}
        for r in results:
            if r.anion_type not in anion_stats:
                anion_stats[r.anion_type] = {'total': 0, 'passed': 0}
            anion_stats[r.anion_type]['total'] += 1
            if r.passed_filter:
                anion_stats[r.anion_type]['passed'] += 1

        print("\n  By anion type:")
        for anion, s in sorted(anion_stats.items()):
            print(f"    {anion}: {s['passed']}/{s['total']} passed")

        print("=" * 60)
288
src/computation/workspace_manager.py
Normal file
@@ -0,0 +1,288 @@
"""
Workspace manager: creates the computation workspace and its symlinks.
"""
import os
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field


@dataclass
class WorkspaceInfo:
    """Workspace information."""
    workspace_path: str
    data_dir: str    # workspace/data
    tool_dir: str    # tool directory
    target_cation: str
    target_anions: Set[str]

    # Statistics
    total_structures: int = 0
    anion_counts: Dict[str, int] = field(default_factory=dict)
    linked_structures: int = 0  # number of structures with symlinks created


class WorkspaceManager:
    """
    Workspace manager.

    Responsibilities:
    1. Detect an existing workspace
    2. Create symlinks (the yaml config and compute script go into each structure directory)
    3. Prepare computation tasks
    """

    # Supported anions and their config files
    SUPPORTED_ANIONS = {'O', 'S', 'Cl', 'Br'}

    def __init__(
        self,
        workspace_path: str = "workspace",
        tool_dir: str = "tool",
        target_cation: str = "Li"
    ):
        """
        Initialize the workspace manager.

        Args:
            workspace_path: workspace root directory
            tool_dir: tool directory (holds yaml configs and the compute script)
            target_cation: target cation
        """
        self.workspace_path = os.path.abspath(workspace_path)
        self.tool_dir = os.path.abspath(tool_dir)
        self.target_cation = target_cation

        # Data directory
        self.data_dir = os.path.join(self.workspace_path, "data")

    def check_existing_workspace(self) -> Optional[WorkspaceInfo]:
        """
        Check for an existing workspace.

        Returns:
            WorkspaceInfo if one exists, otherwise None
        """
        if not os.path.exists(self.data_dir):
            return None

        # Scan the data directory
        anion_counts = {}
        total = 0
        linked = 0

        for item in os.listdir(self.data_dir):
            item_path = os.path.join(self.data_dir, item)
            if os.path.isdir(item_path):
                # Possibly an anion directory (e.g. O, S, O+S):
                # count the structures inside it
                count = 0
                for sub_item in os.listdir(item_path):
                    sub_path = os.path.join(item_path, sub_item)
                    if os.path.isdir(sub_path):
                        # Does it contain a CIF file?
                        cif_files = [f for f in os.listdir(sub_path) if f.endswith('.cif')]
                        if cif_files:
                            count += 1
                            # Does it already have a yaml symlink?
                            yaml_files = [f for f in os.listdir(sub_path) if f.endswith('.yaml')]
                            if yaml_files:
                                linked += 1

                if count > 0:
                    anion_counts[item] = count
                    total += count

        if total == 0:
            return None

        return WorkspaceInfo(
            workspace_path=self.workspace_path,
            data_dir=self.data_dir,
            tool_dir=self.tool_dir,
            target_cation=self.target_cation,
            target_anions=set(anion_counts.keys()),
            total_structures=total,
            anion_counts=anion_counts,
            linked_structures=linked
        )

    def setup_workspace(
        self,
        target_anions: Optional[Set[str]] = None,
        force_relink: bool = False
    ) -> WorkspaceInfo:
        """
        Set up the workspace: create symlinks inside every structure directory.

        Symlink rules:
        - yaml file: use the yaml named after the anion directory
          (e.g. the O directory uses O.yaml, the Cl+O directory uses Cl+O.yaml)
        - python script: analyze_voronoi_nodes.py

        Args:
            target_anions: target anion set
            force_relink: whether to recreate existing symlinks

        Returns:
            WorkspaceInfo
        """
        if target_anions is None:
            target_anions = self.SUPPORTED_ANIONS

        # The data directory must exist
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Data directory does not exist: {self.data_dir}")

        # Resolve the compute-script path
        analyze_script = os.path.join(self.tool_dir, "analyze_voronoi_nodes.py")
        if not os.path.exists(analyze_script):
            raise FileNotFoundError(f"Compute script does not exist: {analyze_script}")

        anion_counts = {}
        total = 0
        linked = 0

        print("\nSetting up workspace symlinks...")

        # Walk the anion subdirectories of the data directory
        for anion_key in os.listdir(self.data_dir):
            anion_dir = os.path.join(self.data_dir, anion_key)
            if not os.path.isdir(anion_dir):
                continue

            # Decide which yaml config to use:
            # the yaml named after the anion directory (e.g. O.yaml, Cl+O.yaml)
            yaml_name = f"{anion_key}.yaml"
            yaml_source = os.path.join(self.tool_dir, self.target_cation, yaml_name)

            if not os.path.exists(yaml_source):
                print(f"  ⚠️ Config file does not exist: {yaml_source}")
                continue

            # Count and process every structure under this anion directory
            count = 0
            for struct_name in os.listdir(anion_dir):
                struct_dir = os.path.join(anion_dir, struct_name)

                if not os.path.isdir(struct_dir):
                    continue

                # Does it contain a CIF file?
                cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
                if not cif_files:
                    continue

                count += 1

                # Create the symlinks inside the structure directory
                yaml_link = os.path.join(struct_dir, yaml_name)
                script_link = os.path.join(struct_dir, "analyze_voronoi_nodes.py")

                # yaml symlink
                if os.path.exists(yaml_link) or os.path.islink(yaml_link):
                    if force_relink:
                        os.remove(yaml_link)
                        os.symlink(yaml_source, yaml_link)
                        linked += 1
                else:
                    os.symlink(yaml_source, yaml_link)
                    linked += 1

                # compute-script symlink
                if os.path.exists(script_link) or os.path.islink(script_link):
                    if force_relink:
                        os.remove(script_link)
                        os.symlink(analyze_script, script_link)
                else:
                    os.symlink(analyze_script, script_link)

            if count > 0:
                anion_counts[anion_key] = count
                total += count
                print(f"  ✓ {anion_key}: {count} structures, config -> {yaml_name}")

        print(f"\n  Total: {total} structures, new symlinks: {linked}")

        return WorkspaceInfo(
            workspace_path=self.workspace_path,
            data_dir=self.data_dir,
            tool_dir=self.tool_dir,
            target_cation=self.target_cation,
            target_anions=set(anion_counts.keys()),
            total_structures=total,
            anion_counts=anion_counts,
            linked_structures=linked
        )

    def get_computation_tasks(
        self,
        workspace_info: Optional[WorkspaceInfo] = None
    ) -> List[Dict]:
        """
        Collect all computation tasks.

        Returns:
            task list; each task holds:
            - cif_path: CIF file path
            - yaml_name: YAML config file name (e.g. O.yaml)
            - work_dir: working directory (the structure directory)
            - anion_type: anion type
            - structure_name: structure name
        """
        if workspace_info is None:
            workspace_info = self.check_existing_workspace()

        if workspace_info is None:
            return []

        tasks = []

        for anion_key in workspace_info.anion_counts.keys():
            anion_dir = os.path.join(self.data_dir, anion_key)
            yaml_name = f"{anion_key}.yaml"

            # Walk every structure under this anion directory
            for struct_name in os.listdir(anion_dir):
                struct_dir = os.path.join(anion_dir, struct_name)

                if not os.path.isdir(struct_dir):
                    continue

                # Find the CIF files
                cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]

                # Is the yaml symlink present?
                yaml_path = os.path.join(struct_dir, yaml_name)
                if not os.path.exists(yaml_path):
                    continue

                for cif_file in cif_files:
                    cif_path = os.path.join(struct_dir, cif_file)

                    tasks.append({
                        'cif_path': cif_path,
                        'yaml_name': yaml_name,
                        'work_dir': struct_dir,
                        'anion_type': anion_key,
                        'structure_name': struct_name,
                        'cif_name': cif_file
                    })

        return tasks

    def print_workspace_summary(self, workspace_info: WorkspaceInfo):
        """Print a workspace summary."""
        print("\n" + "=" * 60)
        print("[Workspace Summary]")
        print("=" * 60)
        print(f"  Workspace path: {workspace_info.workspace_path}")
        print(f"  Data directory: {workspace_info.data_dir}")
        print(f"  Target cation: {workspace_info.target_cation}")
        print(f"  Total structures: {workspace_info.total_structures}")
        print(f"  Symlinked structures: {workspace_info.linked_structures}")
        print()
        print("  Anion distribution:")
        for anion, count in sorted(workspace_info.anion_counts.items()):
            print(f"    - {anion}: {count} structures")
        print("=" * 60)
446
src/computation/zeo_executor.py
Normal file
@@ -0,0 +1,446 @@
"""
Zeo++ execution engine: schedules large batches of jobs efficiently with SLURM job arrays.
"""
import os
import subprocess
import time
import json
import tempfile
from typing import List, Dict, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import threading

from ..core.progress import ProgressManager


@dataclass
class ZeoConfig:
    """Zeo++ computation configuration."""
    # Environment
    conda_env: str = "/cluster/home/koko125/anaconda3/envs/zeo"

    # SLURM
    partition: str = "cpu"
    time_limit: str = "2:00:00"    # per-task time limit
    memory_per_task: str = "4G"

    # Job array
    max_array_size: int = 1000     # SLURM job-array size limit
    max_concurrent: int = 50       # maximum concurrent tasks

    # Polling
    poll_interval: float = 5.0     # status-check interval (seconds)

    # Filters
    filters: List[str] = field(default_factory=lambda: [
        "Ordered", "PropOxi", "VoroPerco", "Coulomb", "VoroBV", "VoroInfo", "MergeSite"
    ])


@dataclass
class ZeoTaskResult:
    """Result of a single task."""
    task_id: int
    structure_name: str
    cif_path: str
    success: bool
    output_files: List[str] = field(default_factory=list)
    error_message: str = ""
    duration: float = 0.0


class ZeoExecutor:
    """
    Zeo++ execution engine.

    Uses SLURM job arrays to schedule large batches of Voronoi analysis tasks.
    """

    def __init__(self, config: Optional[ZeoConfig] = None):
        self.config = config or ZeoConfig()
        self.progress_manager = None
        self._stop_event = threading.Event()

    def run_batch(
        self,
        tasks: List[Dict],
        output_dir: Optional[str] = None,
        desc: str = "Zeo++ computation"
    ) -> List[ZeoTaskResult]:
        """
        Run Zeo++ computations in batch.

        Args:
            tasks: task list; each task holds cif_path, yaml_path, work_dir, etc.
            output_dir: SLURM log output directory
            desc: progress-bar description

        Returns:
            list of ZeoTaskResult
        """
        if not tasks:
            print("⚠️ No tasks to run")
            return []

        total = len(tasks)

        # Create the output directory
        if output_dir is None:
            output_dir = os.path.join(os.getcwd(), "slurm_logs")
        os.makedirs(output_dir, exist_ok=True)

        print(f"\n{'='*60}")
        print("[Zeo++ Batch Computation]")
        print(f"{'='*60}")
        print(f"  Total tasks: {total}")
        print(f"  Conda env: {self.config.conda_env}")
        print(f"  SLURM partition: {self.config.partition}")
        print(f"  Max concurrency: {self.config.max_concurrent}")
        print(f"  Log directory: {output_dir}")
        print(f"{'='*60}\n")

        # Save the task list to a file
        tasks_file = os.path.join(output_dir, "tasks.json")
        with open(tasks_file, 'w') as f:
            json.dump(tasks, f, indent=2)

        # Generate and submit the job array(s)
        if total <= self.config.max_array_size:
            # Single job array
            return self._submit_array_job(tasks, output_dir, desc)
        else:
            # Multiple job arrays, in batches
            return self._submit_batched_arrays(tasks, output_dir, desc)

    def _submit_array_job(
        self,
        tasks: List[Dict],
        output_dir: str,
        desc: str
    ) -> List[ZeoTaskResult]:
        """Submit a single job array."""
        total = len(tasks)

        # Save the task list
        tasks_file = os.path.join(output_dir, "tasks.json")
        with open(tasks_file, 'w') as f:
            json.dump(tasks, f, indent=2)

        # Generate the job script
        script_content = self._generate_array_script(
            tasks_file=tasks_file,
            output_dir=output_dir,
            array_range=f"0-{total-1}%{self.config.max_concurrent}"
        )

        script_path = os.path.join(output_dir, "submit_array.sh")
        with open(script_path, 'w') as f:
            f.write(script_content)
        os.chmod(script_path, 0o755)

        print(f"Job script written: {script_path}")

        # Submit the job
        result = subprocess.run(
            ['sbatch', script_path],
            capture_output=True,
            text=True
        )

        if result.returncode != 0:
            print(f"❌ Job submission failed: {result.stderr}")
            return [ZeoTaskResult(
                task_id=i,
                structure_name=t.get('structure_name', ''),
                cif_path=t.get('cif_path', ''),
                success=False,
                error_message=f"Submission failed: {result.stderr}"
            ) for i, t in enumerate(tasks)]

        # Extract the job ID
        job_id = result.stdout.strip().split()[-1]
        print(f"✓ Job submitted: {job_id}")
        print(f"  Array range: 0-{total-1}")
        print(f"  Max concurrency: {self.config.max_concurrent}")

        # Monitor the job's progress
        return self._monitor_array_job(job_id, tasks, output_dir, desc)

    def _submit_batched_arrays(
        self,
        tasks: List[Dict],
        output_dir: str,
        desc: str
    ) -> List[ZeoTaskResult]:
        """Submit several job arrays in batches."""
        total = len(tasks)
        batch_size = self.config.max_array_size
        num_batches = (total + batch_size - 1) // batch_size

        print(f"Task count exceeds the job-array limit; submitting in {num_batches} batches")

        all_results = []

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, total)
            batch_tasks = tasks[start_idx:end_idx]

            batch_output_dir = os.path.join(output_dir, f"batch_{batch_idx}")
            os.makedirs(batch_output_dir, exist_ok=True)

            print(f"\n--- Batch {batch_idx + 1}/{num_batches} ---")
            print(f"Task range: {start_idx} - {end_idx - 1}")

            batch_results = self._submit_array_job(
                batch_tasks,
                batch_output_dir,
                f"{desc} (batch {batch_idx + 1}/{num_batches})"
            )

            # Shift the task IDs back to the global index
            for r in batch_results:
                r.task_id += start_idx

            all_results.extend(batch_results)

        return all_results

    def _generate_array_script(
        self,
        tasks_file: str,
        output_dir: str,
        array_range: str
    ) -> str:
        """Generate the SLURM job-array script."""

        script = f"""#!/bin/bash
#SBATCH --job-name=zeo_array
#SBATCH --partition={self.config.partition}
#SBATCH --array={array_range}
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --mem={self.config.memory_per_task}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/task_%a.out
#SBATCH --error={output_dir}/task_%a.err

# ============================================
# Zeo++ Voronoi analysis - job array
# ============================================

echo "===== Task info ====="
echo "Job ID: $SLURM_JOB_ID"
echo "Array task ID: $SLURM_ARRAY_TASK_ID"
echo "Node: $SLURM_NODELIST"
echo "Start time: $(date)"
echo "====================="

# ============ Environment setup ============
# Load bashrc
if [ -f ~/.bashrc ]; then
    source ~/.bashrc
fi

# Initialize Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
    source /opt/anaconda3/etc/profile.d/conda.sh
fi

# Activate the Zeo++ environment
conda activate {self.config.conda_env}

echo ""
echo "===== Environment check ====="
echo "Conda env: $CONDA_DEFAULT_ENV"
echo "Python path: $(which python)"
echo "============================="
echo ""

# ============ Read the task info ============
TASKS_FILE="{tasks_file}"
TASK_ID=$SLURM_ARRAY_TASK_ID

# Parse the task with Python
TASK_INFO=$(python3 -c "
import json
with open('$TASKS_FILE', 'r') as f:
    tasks = json.load(f)
if $TASK_ID < len(tasks):
    task = tasks[$TASK_ID]
    print(task['work_dir'])
    print(task['yaml_name'])
else:
    print('ERROR')
")

WORK_DIR=$(echo "$TASK_INFO" | sed -n '1p')
YAML_NAME=$(echo "$TASK_INFO" | sed -n '2p')

if [ "$WORK_DIR" == "ERROR" ]; then
    echo "Error: task ID $TASK_ID out of range"
    exit 1
fi

echo "Work dir: $WORK_DIR"
echo "Config file: $YAML_NAME"
echo ""

# ============ Run the computation ============
cd "$WORK_DIR"

echo "Starting Voronoi analysis..."
# The symlinks already live in the work dir, so relative paths suffice.
# Redirect output to log.txt so the results can be extracted later.
python analyze_voronoi_nodes.py *.cif -i "$YAML_NAME" > log.txt 2>&1

EXIT_CODE=$?

# Show the log (for debugging)
echo ""
echo "===== Computation log ====="
cat log.txt
echo "==========================="

# ============ Wrap up ============
echo ""
echo "===== Task finished ====="
echo "End time: $(date)"
echo "Exit code: $EXIT_CODE"

# Write the status file
if [ $EXIT_CODE -eq 0 ]; then
    echo "SUCCESS" > "{output_dir}/status_$TASK_ID.txt"
else
    echo "FAILED" > "{output_dir}/status_$TASK_ID.txt"
fi

echo "========================="
exit $EXIT_CODE
"""
        return script

    def _monitor_array_job(
        self,
        job_id: str,
        tasks: List[Dict],
        output_dir: str,
        desc: str
    ) -> List[ZeoTaskResult]:
        """Monitor job-array progress."""
        total = len(tasks)

        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        results = [None] * total
        completed = set()

        print(f"\nMonitoring job progress (checking every {self.config.poll_interval} s)...")
        print("Press Ctrl+C to stop monitoring (the job keeps running in the background)\n")

        try:
            while len(completed) < total:
                time.sleep(self.config.poll_interval)

                # Check the status files
                for i in range(total):
                    if i in completed:
                        continue

                    status_file = os.path.join(output_dir, f"status_{i}.txt")
                    if os.path.exists(status_file):
                        with open(status_file, 'r') as f:
                            status = f.read().strip()

                        task = tasks[i]
                        success = (status == "SUCCESS")

                        # Collect the output files
                        output_files = []
                        if success:
                            work_dir = task['work_dir']
                            for f in os.listdir(work_dir):
                                if f.endswith(('.cif', '.csv')) and f != task['cif_name']:
                                    output_files.append(os.path.join(work_dir, f))

                        results[i] = ZeoTaskResult(
                            task_id=i,
                            structure_name=task.get('structure_name', ''),
                            cif_path=task.get('cif_path', ''),
                            success=success,
                            output_files=output_files
                        )

                        completed.add(i)
                        self.progress_manager.update(success=success)
                        self.progress_manager.display()

                # Is the job still running?
                if not self._is_job_running(job_id) and len(completed) < total:
                    # The job ended with unfinished tasks
                    print(f"\n⚠️ Job ended, but {total - len(completed)} tasks did not finish")
                    break

        except KeyboardInterrupt:
            print("\n\n⚠️ Monitoring interrupted; the job keeps running in the background")
            print(f"  Check job status with 'squeue -j {job_id}'")
            print(f"  Cancel the job with 'scancel {job_id}'")

        self.progress_manager.finish()

        # Fill in the unfinished tasks
        for i in range(total):
            if results[i] is None:
                task = tasks[i]
                results[i] = ZeoTaskResult(
                    task_id=i,
                    structure_name=task.get('structure_name', ''),
                    cif_path=task.get('cif_path', ''),
                    success=False,
                    error_message="task unfinished or status unknown"
                )

        return results

    def _is_job_running(self, job_id: str) -> bool:
        """Check whether the job is still running."""
        try:
            result = subprocess.run(
                ['squeue', '-j', job_id, '-h'],
                capture_output=True,
                text=True,
                timeout=10
            )
            return bool(result.stdout.strip())
        except Exception:
            return False

    def print_results_summary(self, results: List[ZeoTaskResult]):
        """Print a results summary."""
        total = len(results)
        if total == 0:
            print("\nNo results to summarize")
            return
        success = sum(1 for r in results if r.success)
        failed = total - success

        print("\n" + "=" * 60)
        print("[Computation Result Summary]")
        print("=" * 60)
        print(f"  Total tasks: {total}")
        print(f"  Succeeded: {success} ({100*success/total:.1f}%)")
        print(f"  Failed: {failed} ({100*failed/total:.1f}%)")

        if 0 < failed <= 10:
            print("\n  Failed tasks:")
            for r in results:
                if not r.success:
                    print(f"    - {r.structure_name}: {r.error_message}")
        elif failed > 10:
            print("\n  Too many failed tasks; please check the log files")

        print("=" * 60)
18
src/core/__init__.py
Normal file
@@ -0,0 +1,18 @@
"""
Core module: scheduler, executor, and progress management.
"""
from .scheduler import ParallelScheduler, ResourceConfig, ExecutionMode as SchedulerMode
from .executor import TaskExecutor, ExecutorConfig, ExecutionMode, TaskResult, create_executor
from .progress import ProgressManager

__all__ = [
    'ParallelScheduler',
    'ResourceConfig',
    'SchedulerMode',
    'TaskExecutor',
    'ExecutorConfig',
    'ExecutionMode',
    'TaskResult',
    'create_executor',
    'ProgressManager',
]
0
src/core/controller.py
Normal file
431
src/core/executor.py
Normal file
@@ -0,0 +1,431 @@
"""
Task executor: supports local execution and direct SLURM submission.
No script files are generated; tasks are managed directly in Python.
"""
import os
import subprocess
import time
import json
from typing import List, Callable, Any, Optional, Dict, Tuple
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

from .progress import ProgressManager


class ExecutionMode(Enum):
    """Execution mode."""
    LOCAL = "local"          # local multiprocessing
    SLURM_DIRECT = "slurm"   # direct SLURM submission (no script files)


@dataclass
class ExecutorConfig:
    """Executor configuration."""
    mode: ExecutionMode = ExecutionMode.LOCAL
    max_workers: int = 4
    conda_env: str = "/cluster/home/koko125/anaconda3/envs/screen"
    partition: str = "cpu"
    time_limit: str = "7-00:00:00"
    memory_per_task: str = "4G"

    # SLURM-specific
    poll_interval: float = 2.0       # polling interval (seconds)
    max_concurrent_jobs: int = 50    # maximum concurrent jobs


@dataclass
class TaskResult:
    """Task result."""
    task_id: Any
    success: bool
    result: Any = None
    error: Optional[str] = None
    duration: float = 0.0


class TaskExecutor:
    """
    Task executor.

    Two modes:
    1. LOCAL: local multiprocessing
    2. SLURM_DIRECT: submit SLURM jobs directly and monitor progress live
    """

    def __init__(self, config: Optional[ExecutorConfig] = None):
        self.config = config or ExecutorConfig()
        self.progress_manager = None
        self._stop_event = threading.Event()

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect the runtime environment."""
        env_info = {
            'hostname': os.uname().nodename,
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'conda_env': os.environ.get('CONDA_PREFIX', ''),
        }

        # Detect SLURM
        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True, text=True, timeout=5
            )
            if result.returncode == 0:
                env_info['has_slurm'] = True
                lines = result.stdout.strip().split('\n')
                for line in lines:
                    parts = line.split()
                    if len(parts) >= 4:
                        partition = parts[0].rstrip('*')
                        avail = parts[1]
                        if avail == 'up':
                            env_info['slurm_partitions'].append(partition)
        except Exception:
            pass

        return env_info

    def run(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing"
    ) -> List[TaskResult]:
        """
        Execute the tasks.

        Args:
            tasks: task list
            worker_func: worker function; takes one task, returns its result
            desc: progress-bar description

        Returns:
            list of TaskResult
        """
        if self.config.mode == ExecutionMode.LOCAL:
            return self._run_local(tasks, worker_func, desc)
        elif self.config.mode == ExecutionMode.SLURM_DIRECT:
            return self._run_slurm_direct(tasks, worker_func, desc)
        else:
            raise ValueError(f"Unsupported execution mode: {self.config.mode}")

    def _run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str
    ) -> List[TaskResult]:
        """Run locally with multiprocessing."""
        total = len(tasks)
        num_workers = min(self.config.max_workers, total)

        print(f"\n{'='*60}")
        print("Local execution configuration:")
        print(f"  Total tasks: {total}")
        print(f"  Workers: {num_workers}")
        print(f"{'='*60}\n")

        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        results = []

        if num_workers == 1:
            # Single-process execution
            for i, task in enumerate(tasks):
                start_time = time.time()
                try:
                    result = worker_func(task)
                    duration = time.time() - start_time
                    results.append(TaskResult(
                        task_id=i,
                        success=True,
                        result=result,
                        duration=duration
                    ))
                    self.progress_manager.update(success=True)
                except Exception as e:
                    duration = time.time() - start_time
                    results.append(TaskResult(
                        task_id=i,
                        success=False,
                        error=str(e),
                        duration=duration
                    ))
                    self.progress_manager.update(success=False)
                self.progress_manager.display()
        else:
            # Multiprocess execution.
            # imap (rather than imap_unordered) keeps results in input order,
            # so task_id matches the original task index.
            with Pool(processes=num_workers) as pool:
                for i, result in enumerate(pool.imap(worker_func, tasks)):
                    if result is not None:
                        results.append(TaskResult(
                            task_id=i,
                            success=True,
                            result=result
                        ))
                        self.progress_manager.update(success=True)
                    else:
                        results.append(TaskResult(
                            task_id=i,
                            success=False,
                            error="Worker returned None"
                        ))
                        self.progress_manager.update(success=False)
                    self.progress_manager.display()

        self.progress_manager.finish()
        return results

    def _run_slurm_direct(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str
    ) -> List[TaskResult]:
        """
        Direct SLURM submission mode.

        Note: for fast tasks such as database analysis, local multiprocessing
        is recommended; SLURM mode suits long-running computations
        (e.g. Zeo++ analysis).

        This falls back to local mode, because calling srun directly
        from the login node is inefficient.
        """
        print("\n⚠️ Note: the database-analysis stage automatically uses local multiprocessing")
        print("  SLURM mode is used by the later, long-running computation steps")

        # Fall back to local mode
        return self._run_local(tasks, worker_func, desc)


class SlurmJobManager:
    """
    SLURM job manager.

    Submits and monitors SLURM jobs in batch.
    """

    def __init__(self, config: ExecutorConfig):
        self.config = config
        self.active_jobs = {}  # job_id -> task_info

    def submit_batch(
        self,
        tasks: List[Tuple[str, str, set]],  # (file_path, target_cation, target_anions)
        output_dir: str,
        desc: str = "Processing"
    ) -> List[TaskResult]:
        """
        Submit tasks to SLURM in batch.

        Uses sbatch --wrap for direct submission; no script files are generated.
        """
        total = len(tasks)
        os.makedirs(output_dir, exist_ok=True)

        print(f"\n{'='*60}")
        print("SLURM batch submission:")
        print(f"  Total tasks: {total}")
        print(f"  Output directory: {output_dir}")
        print(f"  Conda env: {self.config.conda_env}")
        print(f"{'='*60}\n")

        progress = ProgressManager(total, desc)
        progress.start()

        results = []
        job_ids = []

        # Submit all tasks
        for i, task in enumerate(tasks):
            file_path, target_cation, target_anions = task

            # Build the Python command
            anions_str = ','.join(target_anions)
            python_cmd = (
                f"python -c \""
                f"import sys; sys.path.insert(0, '{os.getcwd()}'); "
                f"from src.analysis.worker import analyze_single_file; "
                f"result = analyze_single_file(('{file_path}', '{target_cation}', set('{anions_str}'.split(',')))); "
                f"print('SUCCESS' if result and result.is_valid else 'FAILED')"
                f"\""
            )

            # Build the full bash command.
            # conda_env is <prefix>/envs/<name>, so conda.sh lives at
            # <prefix>/etc/profile.d/conda.sh, i.e. one level above envs/.
            bash_cmd = (
                f"source {os.path.dirname(self.config.conda_env)}/../etc/profile.d/conda.sh && "
                f"conda activate {self.config.conda_env} && "
                f"{python_cmd}"
            )

            # Submit with sbatch --wrap
            sbatch_cmd = [
                'sbatch',
                '--partition', self.config.partition,
                '--ntasks', '1',
                '--cpus-per-task', '1',
                '--mem', self.config.memory_per_task,
                '--time', '01:00:00',
                '--output', os.path.join(output_dir, f'task_{i}.out'),
                '--error', os.path.join(output_dir, f'task_{i}.err'),
                '--wrap', bash_cmd
            ]

            try:
                result = subprocess.run(
                    sbatch_cmd,
                    capture_output=True,
                    text=True
                )

                if result.returncode == 0:
                    # Extract the job_id
                    job_id = result.stdout.strip().split()[-1]
                    job_ids.append((i, job_id, file_path))
                    self.active_jobs[job_id] = {
                        'task_index': i,
                        'file_path': file_path,
                        'status': 'PENDING'
                    }
                else:
                    results.append(TaskResult(
                        task_id=i,
                        success=False,
                        error=f"Submission failed: {result.stderr}"
                    ))
                    progress.update(success=False)
                    progress.display()

            except Exception as e:
                results.append(TaskResult(
                    task_id=i,
                    success=False,
                    error=str(e)
                ))
                progress.update(success=False)
                progress.display()

        print(f"\nSubmitted {len(job_ids)} jobs, waiting for completion...")

        # Monitor the job states
        while self.active_jobs:
            time.sleep(self.config.poll_interval)

            # Check job states
            completed_jobs = self._check_job_status()

            for job_id, status in completed_jobs:
                job_info = self.active_jobs.pop(job_id, None)
                if job_info:
                    task_idx = job_info['task_index']

                    if status == 'COMPLETED':
                        # Inspect the output file
                        out_file = os.path.join(output_dir, f'task_{task_idx}.out')
                        success = False
                        if os.path.exists(out_file):
                            with open(out_file, 'r') as f:
                                content = f.read()
                            success = 'SUCCESS' in content

                        results.append(TaskResult(
                            task_id=task_idx,
                            success=success,
                            result=job_info['file_path']
                        ))
                        progress.update(success=success)
                    else:
                        # The job failed
                        err_file = os.path.join(output_dir, f'task_{task_idx}.err')
                        error_msg = status
                        if os.path.exists(err_file):
                            with open(err_file, 'r') as f:
                                error_msg = f.read()[:500]  # keep only the first 500 characters

                        results.append(TaskResult(
                            task_id=task_idx,
                            success=False,
                            error=error_msg
                        ))
                        progress.update(success=False)

                progress.display()

        progress.finish()
        return results

    def _check_job_status(self) -> List[Tuple[str, str]]:
        """Check job states; return the list of completed jobs."""
        if not self.active_jobs:
            return []

        job_ids = list(self.active_jobs.keys())

        try:
            result = subprocess.run(
                ['sacct', '-j', ','.join(job_ids), '--format=JobID,State', '--noheader', '--parsable2'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
completed = []
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if line:
|
||||
parts = line.split('|')
|
||||
if len(parts) >= 2:
|
||||
job_id = parts[0].split('.')[0] # 去掉 .batch 后缀
|
||||
status = parts[1]
|
||||
|
||||
if job_id in self.active_jobs:
|
||||
if status in ['COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL']:
|
||||
completed.append((job_id, status))
|
||||
|
||||
return completed
|
||||
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def create_executor(
|
||||
mode: str = "local",
|
||||
max_workers: int = None,
|
||||
conda_env: str = None,
|
||||
**kwargs
|
||||
) -> TaskExecutor:
|
||||
"""
|
||||
创建任务执行器的便捷函数
|
||||
|
||||
Args:
|
||||
mode: "local" 或 "slurm"
|
||||
max_workers: 最大工作进程数
|
||||
conda_env: Conda 环境路径
|
||||
**kwargs: 其他配置参数
|
||||
"""
|
||||
env = TaskExecutor.detect_environment()
|
||||
|
||||
if max_workers is None:
|
||||
max_workers = min(env['total_cores'], 32)
|
||||
|
||||
if conda_env is None:
|
||||
conda_env = env.get('conda_env') or "/cluster/home/koko125/anaconda3/envs/screen"
|
||||
|
||||
exec_mode = ExecutionMode.SLURM_DIRECT if mode.lower() == "slurm" else ExecutionMode.LOCAL
|
||||
|
||||
config = ExecutorConfig(
|
||||
mode=exec_mode,
|
||||
max_workers=max_workers,
|
||||
conda_env=conda_env,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return TaskExecutor(config)
|
||||
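
A minimal usage sketch of the factory above. The worker function is hypothetical, and the name of the dispatching method (`run`) is an assumption inferred from the mode branches at the top of this file:

from src.core.executor import create_executor  # module path assumed from this diff

def square(x):  # hypothetical worker
    return x * x

if __name__ == "__main__":
    executor = create_executor(mode="local", max_workers=4)
    results = executor.run(list(range(100)), square, desc="Squaring")  # `run` is assumed
    print(sum(1 for r in results if r.success), "tasks succeeded")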
115 src/core/progress.py Normal file
@@ -0,0 +1,115 @@
"""
Progress manager: real-time progress display with multiprocessing support.
"""
import os
import sys
import time
from multiprocessing import Manager, Value
from typing import Optional
from datetime import datetime, timedelta


class ProgressManager:
    """Multiprocessing-safe progress manager."""

    def __init__(self, total: int, desc: str = "Processing"):
        self.total = total
        self.desc = desc
        self.manager = Manager()
        self.completed = self.manager.Value('i', 0)
        self.failed = self.manager.Value('i', 0)
        self.start_time = None
        self._lock = self.manager.Lock()

    def start(self):
        """Start the timer."""
        self.start_time = time.time()

    def update(self, success: bool = True):
        """Update the progress (process-safe)."""
        with self._lock:
            if success:
                self.completed.value += 1
            else:
                self.failed.value += 1

    def get_progress(self) -> dict:
        """Get the current progress."""
        completed = self.completed.value
        failed = self.failed.value
        total_done = completed + failed

        elapsed = time.time() - self.start_time if self.start_time else 0

        if total_done > 0:
            # Guard against a zero elapsed time on very fast runs
            speed = total_done / elapsed if elapsed > 0 else 0  # items/sec
            remaining = (self.total - total_done) / speed if speed > 0 else 0
        else:
            speed = 0
            remaining = 0

        return {
            'total': self.total,
            'completed': completed,
            'failed': failed,
            'total_done': total_done,
            'percent': total_done / self.total * 100 if self.total > 0 else 0,
            'elapsed': elapsed,
            'remaining': remaining,
            'speed': speed
        }

    def display(self):
        """Render the progress bar."""
        p = self.get_progress()

        # Progress bar
        bar_width = 40
        filled = int(bar_width * p['total_done'] / p['total']) if p['total'] > 0 else 0
        bar = '█' * filled + '░' * (bar_width - filled)

        # Time formatting
        elapsed_str = str(timedelta(seconds=int(p['elapsed'])))
        remaining_str = str(timedelta(seconds=int(p['remaining'])))

        # Build the display string
        status = (
            f"\r{self.desc}: |{bar}| "
            f"{p['total_done']}/{p['total']} ({p['percent']:.1f}%) "
            f"[{elapsed_str}<{remaining_str}, {p['speed']:.1f}it/s] "
            f"✓{p['completed']} ✗{p['failed']}"
        )

        sys.stdout.write(status)
        sys.stdout.flush()

    def finish(self):
        """Final display."""
        self.display()
        print()  # newline


class ProgressReporter:
    """Progress reporter (for background monitoring)."""

    def __init__(self, progress_file: str = ".progress"):
        self.progress_file = progress_file

    def save(self, progress: dict):
        """Save progress to file."""
        import json
        with open(self.progress_file, 'w') as f:
            json.dump(progress, f)

    def load(self) -> Optional[dict]:
        """Load progress from file."""
        import json
        if os.path.exists(self.progress_file):
            with open(self.progress_file, 'r') as f:
                return json.load(f)
        return None

    def cleanup(self):
        """Remove the progress file."""
        if os.path.exists(self.progress_file):
            os.remove(self.progress_file)
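
A minimal sketch of driving ProgressManager from a parent process around a worker pool (the worker function here is hypothetical):

from multiprocessing import Pool
from src.core.progress import ProgressManager  # module path per this diff

def work(n):  # hypothetical worker
    return n * n

if __name__ == "__main__":
    tasks = list(range(50))
    pm = ProgressManager(len(tasks), desc="Demo")
    pm.start()
    with Pool(4) as pool:
        for result in pool.imap_unordered(work, tasks):
            pm.update(success=result is not None)
            pm.display()
    pm.finish()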
291 src/core/scheduler.py Normal file
@@ -0,0 +1,291 @@
"""
Parallel scheduler: supports local multiprocessing and SLURM cluster scheduling.
"""
import os
import subprocess
import tempfile
from typing import List, Callable, Any, Optional, Dict
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass
from enum import Enum
import math

from .progress import ProgressManager


class ExecutionMode(Enum):
    """Execution mode"""
    LOCAL_SINGLE = "local_single"
    LOCAL_MULTI = "local_multi"
    SLURM_SINGLE = "slurm_single"
    SLURM_ARRAY = "slurm_array"


@dataclass
class ResourceConfig:
    """Resource configuration"""
    max_cores: int = 4
    cores_per_worker: int = 1
    memory_per_core: str = "4G"
    partition: str = "cpu"
    time_limit: str = "7-00:00:00"
    conda_env_path: str = "~/anaconda3/envs/screen"  # New: path to the Conda environment

    @property
    def num_workers(self) -> int:
        return max(1, self.max_cores // self.cores_per_worker)


class ParallelScheduler:
    """Parallel scheduler"""

    COMPLEXITY_CORES = {
        'low': 1,
        'medium': 2,
        'high': 4,
    }

    def __init__(self, resource_config: ResourceConfig = None):
        self.config = resource_config or ResourceConfig()
        self.progress_manager = None

    @staticmethod
    def detect_environment() -> Dict[str, Any]:
        """Detect the runtime environment."""
        env_info = {
            'hostname': os.uname().nodename,
            'total_cores': cpu_count(),
            'has_slurm': False,
            'slurm_partitions': [],
            'available_nodes': 0,
            'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
        }

        # Detect SLURM
        try:
            result = subprocess.run(
                ['sinfo', '-h', '-o', '%P %a %c %D'],
                capture_output=True, text=True, timeout=5
            )
            if result.returncode == 0:
                env_info['has_slurm'] = True
                lines = result.stdout.strip().split('\n')
                for line in lines:
                    parts = line.split()
                    if len(parts) >= 4:
                        partition = parts[0].rstrip('*')
                        avail = parts[1]
                        cores = int(parts[2])
                        nodes = int(parts[3])
                        if avail == 'up':
                            env_info['slurm_partitions'].append({
                                'name': partition,
                                'cores_per_node': cores,
                                'nodes': nodes
                            })
                            env_info['available_nodes'] += nodes
        except Exception:
            pass

        return env_info

    @staticmethod
    def recommend_config(
        num_tasks: int,
        task_complexity: str = 'medium',
        max_cores: int = None,
        conda_env_path: str = None
    ) -> ResourceConfig:
        """Recommend a configuration based on task count and complexity."""

        env = ParallelScheduler.detect_environment()

        if max_cores is None:
            max_cores = min(env['total_cores'], 32)

        cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
        max_workers = max_cores // cores_per_worker
        optimal_workers = min(num_tasks, max_workers)
        actual_cores = optimal_workers * cores_per_worker

        # Auto-detect the Conda environment path. detect_environment() always
        # sets 'conda_prefix' (possibly to ''), so fall back on empty too.
        if conda_env_path is None:
            conda_env_path = env.get('conda_prefix') or '~/anaconda3/envs/screen'

        config = ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
            conda_env_path=conda_env_path,
        )

        return config
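
    # Worked example of the sizing above (illustrative numbers, not from the
    # source): with task_complexity='medium' (2 cores per worker),
    # max_cores=32 and num_tasks=100: max_workers = 32 // 2 = 16,
    # optimal_workers = min(100, 16) = 16, actual_cores = 16 * 2 = 32.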

    def run_local(
        self,
        tasks: List[Any],
        worker_func: Callable,
        desc: str = "Processing"
    ) -> List[Any]:
        """Local multiprocessing execution."""

        num_workers = self.config.num_workers
        total = len(tasks)

        print(f"\n{'='*60}")
        print("Parallel configuration:")
        print(f"  Total tasks: {total}")
        print(f"  Workers: {num_workers}")
        print(f"  Cores per worker: {self.config.cores_per_worker}")
        print(f"  Total cores used: {self.config.max_cores}")
        print(f"{'='*60}\n")

        self.progress_manager = ProgressManager(total, desc)
        self.progress_manager.start()

        results = []

        if num_workers == 1:
            for task in tasks:
                try:
                    result = worker_func(task)
                    results.append(result)
                    self.progress_manager.update(success=True)
                except Exception:
                    results.append(None)
                    self.progress_manager.update(success=False)
                self.progress_manager.display()
        else:
            with Pool(processes=num_workers) as pool:
                for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
                    if result is not None:
                        results.append(result)
                        self.progress_manager.update(success=True)
                    else:
                        self.progress_manager.update(success=False)
                    self.progress_manager.display()

        self.progress_manager.finish()

        return results

    def generate_slurm_script(
        self,
        tasks_file: str,
        worker_script: str,
        output_dir: str,
        job_name: str = "analysis"
    ) -> str:
        """Generate the SLURM job script (fixed version)."""

        # Absolute path of the current working directory
        submit_dir = os.getcwd()

        # Convert to absolute paths
        abs_tasks_file = os.path.abspath(tasks_file)
        abs_worker_script = os.path.abspath(worker_script)
        abs_output_dir = os.path.abspath(output_dir)

        script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err

# ============================================
# SLURM job script - auto-generated
# ============================================

echo "===== Job info ====="
echo "Job ID: $SLURM_JOB_ID"
echo "Nodes: $SLURM_NODELIST"
echo "CPUs: $SLURM_CPUS_PER_TASK"
echo "Start time: $(date)"
echo "===================="

# ============ Environment setup ============
# Important: make sure bashrc is loaded
if [ -f ~/.bashrc ]; then
    source ~/.bashrc
fi

# Initialize Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
    source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
    source /opt/anaconda3/etc/profile.d/conda.sh
else
    echo "Error: conda.sh not found"
    exit 1
fi

# Activate the environment (using the full path)
conda activate {self.config.conda_env_path}

# Verify the environment
echo ""
echo "===== Environment check ====="
echo "Conda env: $CONDA_DEFAULT_ENV"
echo "Python path: $(which python)"
echo "Python version: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen version: {{pymatgen.__version__}}')" 2>/dev/null || echo "Warning: pymatgen not installed"
echo "============================="
echo ""

# Set the working directory
cd {submit_dir}
export PYTHONPATH={submit_dir}:$PYTHONPATH

echo "Working directory: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""

# ============ Run the analysis ============
echo "Starting analysis tasks..."
python {abs_worker_script} \\
    --tasks-file {abs_tasks_file} \\
    --output-dir {abs_output_dir} \\
    --num-workers {self.config.num_workers}

EXIT_CODE=$?

# ============ Done ============
echo ""
echo "===== Job finished ====="
echo "End time: $(date)"
echo "Exit code: $EXIT_CODE"
echo "========================"

exit $EXIT_CODE
"""
        return script

    def submit_slurm_job(self, script_content: str, script_path: str = None) -> str:
        """Submit a SLURM job."""

        if script_path is None:
            fd, script_path = tempfile.mkstemp(suffix='.sh')
            os.close(fd)

        with open(script_path, 'w') as f:
            f.write(script_content)

        os.chmod(script_path, 0o755)

        # Print the script location for debugging
        print(f"\nGenerated SLURM script saved to: {script_path}")

        result = subprocess.run(
            ['sbatch', script_path],
            capture_output=True, text=True
        )

        if result.returncode == 0:
            job_id = result.stdout.strip().split()[-1]
            return job_id
        else:
            raise RuntimeError(f"SLURM submission failed:\nstdout: {result.stdout}\nstderr: {result.stderr}")
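
A minimal end-to-end sketch of the scheduler API above (all file names are hypothetical):

from src.core.scheduler import ParallelScheduler

config = ParallelScheduler.recommend_config(num_tasks=500, task_complexity='medium')
scheduler = ParallelScheduler(config)

script = scheduler.generate_slurm_script(
    tasks_file="tasks.json",            # hypothetical
    worker_script="scripts/worker.py",  # hypothetical
    output_dir="results",
    job_name="screen",
)
job_id = scheduler.submit_slurm_job(script, script_path="results/job.sh")
print(f"Submitted SLURM job {job_id}")

For small task sets, scheduler.run_local(tasks, worker_func) runs the same work in-process instead of going through SLURM.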
0 src/preprocessing/__init__.py Normal file
0 src/preprocessing/classifier.py Normal file
0 src/preprocessing/cleaner.py Normal file
562 src/preprocessing/processor.py Normal file
@@ -0,0 +1,562 @@
"""
Structure preprocessor: supercell expansion and oxidation-state assignment.
"""
import os
import re
import yaml
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import Specie
from pymatgen.core import Lattice, Species, PeriodicSite
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import random
import spglib
import numpy as np


@dataclass
class ProcessingResult:
    """Processing result"""
    input_file: str
    output_files: List[str] = field(default_factory=list)
    success: bool = False
    needs_expansion: bool = False
    expansion_factor: int = 1
    error_message: str = ""


class StructureProcessor:
    """Structure preprocessor"""

    # Default oxidation-state configuration
    DEFAULT_VALENCE_PATH = os.path.join(
        os.path.dirname(__file__), '..', '..', 'tool', 'valence_states.yaml'
    )

    def __init__(
        self,
        valence_yaml_path: str = None,
        calculate_type: str = 'low',
        max_expansion_factor: int = 64,
        keep_number: int = 3,
        target_cation: str = "Li"
    ):
        """
        Initialize the processor.

        Args:
            valence_yaml_path: path to the oxidation-state configuration file
            calculate_type: supercell precision ('high', 'normal', 'low', 'very_low')
            max_expansion_factor: maximum supercell factor
            keep_number: number of expanded structures to keep
            target_cation: target cation
        """
        self.valence_yaml_path = valence_yaml_path or self.DEFAULT_VALENCE_PATH
        self.calculate_type = calculate_type
        self.max_expansion_factor = max_expansion_factor
        self.keep_number = keep_number
        self.target_cation = target_cation
        self.explict_element = [target_cation, f"{target_cation}+"]

        # Load the oxidation-state configuration
        self.valences = self._load_valences()

    def _load_valences(self) -> Dict[str, int]:
        """Load the oxidation-state configuration."""
        if os.path.exists(self.valence_yaml_path):
            with open(self.valence_yaml_path, 'r') as f:
                return yaml.safe_load(f)
        return {}

    def process_file(
        self,
        input_path: str,
        output_dir: str,
        needs_expansion: bool = False
    ) -> ProcessingResult:
        """
        Process a single CIF file.

        Args:
            input_path: input file path
            output_dir: output directory
            needs_expansion: whether supercell expansion is needed

        Returns:
            ProcessingResult: the processing result
        """
        result = ProcessingResult(input_file=input_path)

        try:
            # Read the structure
            structure = Structure.from_file(input_path)
            base_name = os.path.splitext(os.path.basename(input_path))[0]

            # Check whether expansion is needed
            occupation_list = self._process_cif_file(structure)

            if occupation_list and needs_expansion:
                # Expansion required
                result.needs_expansion = True
                output_files = self._expand_and_save(
                    structure, occupation_list, base_name, output_dir
                )
                result.output_files = output_files
                result.expansion_factor = occupation_list[0].get('denominator', 1) if occupation_list else 1
            else:
                # No expansion needed; just add oxidation states
                output_path = os.path.join(output_dir, f"{base_name}.cif")
                self._add_oxidation_states(structure)
                structure.to(filename=output_path)
                result.output_files = [output_path]

            result.success = True

        except Exception as e:
            result.success = False
            result.error_message = str(e)

        return result

    def _process_cif_file(self, structure: Structure) -> List[Dict]:
        """
        Collect the occupancy statistics for each site in the structure.
        """
        occupation_dict = defaultdict(list)
        split_dict = {}

        for i, site in enumerate(structure):
            occu = self._get_occu(site.species_string)

            if occu != 1.0:
                if site.species.chemical_system not in self.explict_element:
                    occupation_dict[occu].append(i + 1)

                # Extract the list of element names
                elements = []
                if ':' in site.species_string:
                    parts = site.species_string.split(',')
                    for part in parts:
                        element_with_valence = part.strip().split(':')[0].strip()
                        element_match = re.match(r'([A-Z][a-z]?)', element_with_valence)
                        if element_match:
                            elements.append(element_match.group(1))
                else:
                    element_match = re.match(r'([A-Z][a-z]?)', site.species_string)
                    if element_match:
                        elements = [element_match.group(1)]

                split_dict[occu] = elements

        # Convert to list form
        occupation_list = [
            {
                "occupation": occu,
                "atom_serial": serials,
                "numerator": None,
                "denominator": None,
                "split": split_dict.get(occu, [])
            }
            for occu, serials in occupation_dict.items()
        ]

        return occupation_list

    def _get_occu(self, s_str: str) -> float:
        """Get the occupancy from a species string."""
        if not s_str.strip():
            return 1.0

        pattern = r'([A-Za-z0-9+-]+):([0-9.]+)'
        matches = re.findall(pattern, s_str)

        for species, occu in matches:
            if species not in self.explict_element:
                try:
                    return float(occu)
                except ValueError:
                    continue

        return 1.0

    def _calculate_expansion_factor(self, occupation_list: List[Dict]) -> Tuple[int, List[Dict]]:
        """Compute the supercell expansion factor."""
        if not occupation_list:
            return 1, []

        precision_limits = {
            'high': None,
            'normal': 100,
            'low': 10,
            'very_low': 5
        }

        limit = precision_limits.get(self.calculate_type)

        for entry in occupation_list:
            occu = entry["occupation"]
            if limit:
                fraction = Fraction(occu).limit_denominator(limit)
            else:
                fraction = Fraction(occu).limit_denominator()

            entry["numerator"] = fraction.numerator
            entry["denominator"] = fraction.denominator

        # Least common multiple of the denominators
        denominators = [entry["denominator"] for entry in occupation_list]
        lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)

        # Rescale every entry to the common denominator
        for entry in occupation_list:
            denominator = entry["denominator"]
            entry["numerator"] = entry["numerator"] * (lcm // denominator)
            entry["denominator"] = lcm

        return lcm, occupation_list
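
    # Worked example of the fraction/LCM math above (illustrative): site
    # occupancies 0.5 and 0.3333 reduce to 1/2 and 1/3; lcm(2, 3) = 6, so the
    # supercell factor is 6 and the entries are rescaled to 3/6 and 2/6.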

    def _expand_and_save(
        self,
        structure: Structure,
        occupation_list: List[Dict],
        base_name: str,
        output_dir: str
    ) -> List[str]:
        """Expand the cell and save the results."""
        lcm, oc_list = self._calculate_expansion_factor(occupation_list)

        if lcm > self.max_expansion_factor:
            raise ValueError(f"Expansion factor {lcm} exceeds the maximum limit {self.max_expansion_factor}")

        # Determine the expansion strategies
        strategies = self._strategy_divide(structure, lcm)

        if not strategies:
            raise ValueError("No suitable expansion strategy found")

        # Generate the list of structures
        st_list = self._generate_structure_list(structure, oc_list)

        output_files = []
        keep_number = min(self.keep_number, len(strategies))

        for index in range(keep_number):
            merged = self._merge_structures(st_list, strategies[index])

            # Add oxidation states
            self._add_oxidation_states(merged)

            # No suffix when only one structure is kept
            if keep_number == 1:
                output_filename = f"{base_name}.cif"
            else:
                suffix = "x{}y{}z{}".format(
                    strategies[index]["x"],
                    strategies[index]["y"],
                    strategies[index]["z"]
                )
                output_filename = f"{base_name}-{suffix}.cif"

            output_path = os.path.join(output_dir, output_filename)
            merged.to(filename=output_path, fmt="cif")
            output_files.append(output_path)

        return output_files

    def _add_oxidation_states(self, structure: Structure):
        """Add oxidation states."""
        # Check whether oxidation states are already present
        has_oxidation = all(
            all(isinstance(sp, Specie) for sp in site.species.keys())
            for site in structure.sites
        )

        if not has_oxidation and self.valences:
            structure.add_oxidation_state_by_element(self.valences)

    def _strategy_divide(self, structure: Structure, total: int) -> List[Dict]:
        """Choose the expansion strategy based on the crystal system."""
        try:
            space_group_info = structure.get_space_group_info()
            space_group_symbol = space_group_info[0]

            # Look up the space-group type
            all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)]
            symbol = all_spacegroup_symbols[0]
            for symbol_i in all_spacegroup_symbols:
                if space_group_symbol == symbol_i.international_short:
                    symbol = symbol_i
                    break

            space_type = self._typejudge(symbol.number)

            if space_type == "Cubic":
                return self._factorize_to_three_factors(total, "xyz")
            else:
                return self._factorize_to_three_factors(total)
        except Exception:
            return self._factorize_to_three_factors(total)

    def _typejudge(self, number: int) -> str:
        """Determine the crystal system."""
        if number in [1, 2]:
            return "Triclinic"
        elif 3 <= number <= 15:
            return "Monoclinic"
        elif 16 <= number <= 74:
            return "Orthorhombic"
        elif 75 <= number <= 142:
            return "Tetragonal"
        elif 143 <= number <= 167:
            return "Trigonal"
        elif 168 <= number <= 194:
            return "Hexagonal"
        elif 195 <= number <= 230:
            return "Cubic"
        else:
            return "Unknown"

    def _factorize_to_three_factors(self, n: int, type_sym: str = None) -> List[Dict]:
        """Factorize n into three factors."""
        factors = []

        if type_sym == "xyz":
            for x in range(1, n + 1):
                if n % x == 0:
                    remaining_n = n // x
                    for y in range(1, remaining_n + 1):
                        if remaining_n % y == 0 and y <= x:
                            z = remaining_n // y
                            if z <= y:
                                factors.append({'x': x, 'y': y, 'z': z})
        else:
            for x in range(1, n + 1):
                if n % x == 0:
                    remaining_n = n // x
                    for y in range(1, remaining_n + 1):
                        if remaining_n % y == 0:
                            z = remaining_n // y
                            factors.append({'x': x, 'y': y, 'z': z})

        # Sort so the most balanced factorizations come first
        def sort_key(item):
            return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x'])

        return sorted(factors, key=sort_key)

    def _generate_structure_list(
        self,
        base_structure: Structure,
        occupation_list: List[Dict]
    ) -> List[Structure]:
        """Generate the list of structures."""
        if not occupation_list:
            return [base_structure.copy()]

        lcm = occupation_list[0]["denominator"]
        structure_list = [base_structure.copy() for _ in range(lcm)]

        for entry in occupation_list:
            numerator = entry["numerator"]
            denominator = entry["denominator"]
            atom_indices = entry["atom_serial"]

            for atom_idx in atom_indices:
                occupancy_dict = self._mark_atoms_randomly(numerator, denominator)
                original_site = base_structure.sites[atom_idx - 1]
                element = self._get_first_non_explicit_element(original_site.species_string)

                for copy_idx, occupy in occupancy_dict.items():
                    structure_list[copy_idx].remove_sites([atom_idx - 1])
                    oxi_state = self._extract_oxi_state(original_site.species_string, element)

                    if len(entry["split"]) == 1:
                        if occupy:
                            new_site = PeriodicSite(
                                species=Species(element, oxi_state),
                                coords=original_site.frac_coords,
                                lattice=structure_list[copy_idx].lattice,
                                to_unit_cell=True,
                                label=original_site.label
                            )
                        else:
                            species_dict = {Species(self.target_cation, 1.0): 0.0}
                            new_site = PeriodicSite(
                                species=species_dict,
                                coords=original_site.frac_coords,
                                lattice=structure_list[copy_idx].lattice,
                                to_unit_cell=True,
                                label=original_site.label
                            )
                    else:
                        if occupy:
                            new_site = PeriodicSite(
                                species=Species(element, oxi_state),
                                coords=original_site.frac_coords,
                                lattice=structure_list[copy_idx].lattice,
                                to_unit_cell=True,
                                label=original_site.label
                            )
                        else:
                            new_site = PeriodicSite(
                                species=Species(entry['split'][1], oxi_state),
                                coords=original_site.frac_coords,
                                lattice=structure_list[copy_idx].lattice,
                                to_unit_cell=True,
                                label=original_site.label
                            )

                    structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)

        return structure_list

    def _mark_atoms_randomly(self, numerator: int, denominator: int) -> Dict[int, int]:
        """Randomly mark atoms as occupied."""
        if numerator > denominator:
            raise ValueError(f"numerator ({numerator}) cannot exceed denominator ({denominator})")

        atom_dice = list(range(denominator))
        selected_atoms = random.sample(atom_dice, numerator)

        return {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice}

    def _get_first_non_explicit_element(self, species_str: str) -> str:
        """Get the first element that is not the target cation."""
        if not species_str.strip():
            return ""

        species_parts = [part.strip() for part in species_str.split(",") if part.strip()]

        for part in species_parts:
            element_with_charge = part.split(":")[0].strip()
            pure_element = ''.join([c for c in element_with_charge if c.isalpha()])

            if pure_element not in self.explict_element:
                return pure_element

        return ""

    def _extract_oxi_state(self, species_str: str, element: str) -> int:
        """Extract the oxidation state."""
        species_parts = [part.strip() for part in species_str.split(",") if part.strip()]

        for part in species_parts:
            element_with_charge = part.split(":")[0].strip()

            if element in element_with_charge:
                charge_part = element_with_charge[len(element):]

                if not any(c.isdigit() for c in charge_part):
                    if "+" in charge_part:
                        return 1
                    elif "-" in charge_part:
                        return -1
                    else:
                        return 0

                sign = 1
                if "-" in charge_part:
                    sign = -1

                digits = ""
                for c in charge_part:
                    if c.isdigit():
                        digits += c

                if digits:
                    return sign * int(digits)

        return 0

    def _merge_structures(self, structure_list: List[Structure], merge_dict: Dict) -> Structure:
        """Merge the structures into one supercell."""
        if not structure_list:
            raise ValueError("The structure list must not be empty")

        ref_lattice = structure_list[0].lattice

        total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1)
        if len(structure_list) != total_merge:
            raise ValueError(f"Number of structures ({len(structure_list)}) does not match the merge count ({total_merge})")

        a, b, c = ref_lattice.abc
        alpha, beta, gamma = ref_lattice.angles

        new_a = a * merge_dict.get("x", 1)
        new_b = b * merge_dict.get("y", 1)
        new_c = c * merge_dict.get("z", 1)
        new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma)

        all_sites = []
        for i, structure in enumerate(structure_list):
            x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1)
            y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1)
            z_offset = i % merge_dict.get("z", 1)

            for site in structure:
                coords = site.frac_coords.copy()
                coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1)
                coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1)
                coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1)
                all_sites.append({"species": site.species, "coords": coords})

        return Structure(
            new_lattice,
            [site["species"] for site in all_sites],
            [site["coords"] for site in all_sites]
        )


def process_batch(
    input_files: List[str],
    output_dir: str,
    needs_expansion_flags: List[bool] = None,
    valence_yaml_path: str = None,
    calculate_type: str = 'low',
    target_cation: str = "Li",
    show_progress: bool = True
) -> List[ProcessingResult]:
    """
    Process a batch of CIF files.

    Args:
        input_files: list of input files
        output_dir: output directory
        needs_expansion_flags: per-file flags for supercell expansion
        valence_yaml_path: path to the oxidation-state configuration file
        calculate_type: supercell precision
        target_cation: target cation
        show_progress: whether to show progress

    Returns:
        List of processing results
    """
    os.makedirs(output_dir, exist_ok=True)

    processor = StructureProcessor(
        valence_yaml_path=valence_yaml_path,
        calculate_type=calculate_type,
        target_cation=target_cation
    )

    if needs_expansion_flags is None:
        needs_expansion_flags = [False] * len(input_files)

    results = []
    total = len(input_files)

    for i, (input_file, needs_exp) in enumerate(zip(input_files, needs_expansion_flags)):
        if show_progress:
            print(f"\rProgress: {i+1}/{total} - {os.path.basename(input_file)}", end="")

        result = processor.process_file(input_file, output_dir, needs_exp)
        results.append(result)

    if show_progress:
        print()

    return results
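
A minimal usage sketch of the batch entry point above (file paths are hypothetical):

from src.preprocessing.processor import process_batch

results = process_batch(
    input_files=["structures/example.cif"],  # hypothetical path
    output_dir="processed",
    needs_expansion_flags=[True],
    target_cation="Li",
)
for r in results:
    print(r.input_file, "->", r.output_files if r.success else r.error_message)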
0 src/preprocessing/script_generator.py Normal file
0 src/preprocessing/validator.py Normal file
0 src/utils/__init__.py Normal file
0 src/utils/io.py Normal file
0 src/utils/logger.py Normal file
0 src/utils/structure.py Normal file
5 tool/Li/Br+O.yaml Normal file
@@ -0,0 +1,5 @@
SPECIE: Li+
ANION: Br
PERCO_R: 0.45
NEIGHBOR: 1.8
LONG: 2.2
5 tool/Li/Cl+O.yaml Normal file
@@ -0,0 +1,5 @@
SPECIE: Li+
ANION: Cl
PERCO_R: 0.45
NEIGHBOR: 1.8
LONG: 2.2
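
These per-chemistry YAML files carry the screening thresholds (target specie, anion, and the PERCO_R/NEIGHBOR/LONG cutoffs). A minimal sketch of loading one with PyYAML:

import yaml

with open("tool/Li/Br+O.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["SPECIE"], cfg["ANION"], cfg["PERCO_R"])  # Li+ Br 0.45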