Compare commits

18 Commits

Author SHA1 Message Date
koko
f78298e803 Stage-one high-throughput screening pipeline completed 2025-12-16 11:36:49 +08:00
6ea96c81d6 Add supercell expansion logic 2025-12-14 18:33:54 +08:00
9bde3e1229 Add supercell expansion logic 2025-12-14 18:11:00 +08:00
83647c2218 Add supercell expansion logic 2025-12-14 18:01:11 +08:00
9b36aa10ff Add supercell expansion logic 2025-12-14 17:57:42 +08:00
2378a3f2a2 Add supercell expansion logic 2025-12-14 16:59:01 +08:00
da8c85b830 Add supercell expansion logic 2025-12-14 16:57:57 +08:00
72cf0a79e1 Add supercell expansion logic 2025-12-14 16:52:14 +08:00
f27fd3e3ce Add parallel computation to preprocessing 2025-12-14 15:53:11 +08:00
1fee324c90 Add parallel computation to preprocessing 2025-12-14 15:47:27 +08:00
ae4e7280b4 Add parallel computation to preprocessing 2025-12-14 15:42:13 +08:00
c91998662a Refactor preprocessing pipeline 2025-12-14 15:10:18 +08:00
6eeb40d222 Refactor preprocessing pipeline 2025-12-14 14:34:26 +08:00
da26e0c619 CSM and TET/CS 2025-12-14 12:57:34 +08:00
cea5ab6d3f CSM and TET/CS 2025-12-07 22:30:46 +08:00
e885893484 CSM and TET/CS 2025-12-07 22:19:50 +08:00
3d44b31194 CSM and TET/CS 2025-12-07 20:08:19 +08:00
08f5a51fc4 Start adding CSM value calculation 2025-12-07 17:55:25 +08:00
52 changed files with 6295 additions and 107 deletions

21
.gitignore vendored Normal file

@@ -0,0 +1,21 @@
# --- General Python ignores ---
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.env
.venv/
# --- VS Code settings (optional, recommended to ignore) ---
.vscode/
# --- JetBrains (previous IDE config) ---
.idea/
/shelf/
/workspace.xml
/httpRequests/
/dataSources/
/dataSources.local.xml

8
.idea/.gitignore generated vendored

@@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

8
.idea/Screen.iml generated

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="C:\Users\PC20230606\Miniconda3" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>


@@ -1,23 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="9">
<item index="0" class="java.lang.String" itemvalue="torchvision" />
<item index="1" class="java.lang.String" itemvalue="torch" />
<item index="2" class="java.lang.String" itemvalue="tqdm" />
<item index="3" class="java.lang.String" itemvalue="scipy" />
<item index="4" class="java.lang.String" itemvalue="h5py" />
<item index="5" class="java.lang.String" itemvalue="matplotlib" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="opencv_python" />
<item index="8" class="java.lang.String" itemvalue="Pillow" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>


@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

8
.idea/modules.xml generated

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Screen.iml" filepath="$PROJECT_DIR$/.idea/Screen.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

0
config/settings.yaml Normal file

643
main.py Normal file

@@ -0,0 +1,643 @@
"""
高通量筛选与扩胞项目 - 主入口(支持断点续做)
"""
import os
import sys
import json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
from src.core.executor import TaskExecutor
from src.preprocessing.processor import StructureProcessor
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
from src.computation.result_processor import ResultProcessor, FilterCriteria
def print_banner():
print("""
╔═══════════════════════════════════════════════════════════════════╗
║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.2 ║
║ 支持断点续做与高性能并行计算 ║
╚═══════════════════════════════════════════════════════════════════╝
""")
def detect_and_show_environment():
"""检测并显示环境信息"""
env = TaskExecutor.detect_environment()
print("【运行环境检测】")
print(f" 主机名: {env['hostname']}")
print(f" 本地CPU核数: {env['total_cores']}")
print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}")
if env['has_slurm'] and env['slurm_partitions']:
print(f" 可用分区: {', '.join(env['slurm_partitions'])}")
if env['conda_env']:
print(f" 当前Conda: {env['conda_env']}")
return env
def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict:
"""
检测工作流程状态,确定可以从哪一步继续
Returns:
dict: {
'has_processed_data': bool, # 是否有扩胞处理后的数据
'has_zeo_results': bool, # 是否有 Zeo++ 计算结果
'total_structures': int, # 总结构数
'structures_with_log': int, # 有 log.txt 的结构数
'workspace_info': object, # 工作区信息
'ws_manager': object # 工作区管理器
}
"""
status = {
'has_processed_data': False,
'has_zeo_results': False,
'total_structures': 0,
'structures_with_log': 0,
'workspace_info': None,
'ws_manager': None
}
# 创建工作区管理器
ws_manager = WorkspaceManager(
workspace_path=workspace_path,
tool_dir="tool",
target_cation=target_cation
)
status['ws_manager'] = ws_manager
# 检查是否有处理后的数据
data_dir = os.path.join(workspace_path, "data")
if os.path.exists(data_dir):
existing = ws_manager.check_existing_workspace()
if existing and existing.total_structures > 0:
status['has_processed_data'] = True
status['total_structures'] = existing.total_structures
status['workspace_info'] = existing
# 检查有多少结构有 log.txt(即已完成 Zeo++ 计算)
log_count = 0
for anion_key in os.listdir(data_dir):
anion_dir = os.path.join(data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if os.path.isdir(struct_dir):
log_path = os.path.join(struct_dir, "log.txt")
if os.path.exists(log_path):
log_count += 1
status['structures_with_log'] = log_count
# 如果大部分结构都有 log.txt认为 Zeo++ 计算已完成
if log_count > 0 and log_count >= status['total_structures'] * 0.5:
status['has_zeo_results'] = True
return status
def print_workflow_status(status: dict):
"""打印工作流程状态"""
print("\n" + "" * 50)
print("【工作流程状态检测】")
print("" * 50)
if not status['has_processed_data']:
print(" Step 1 (扩胞+化合价): ❌ 未完成")
print(" Step 2-4 (Zeo++ 计算): ❌ 未完成")
print(" Step 5 (结果筛选): ❌ 未完成")
else:
print(f" Step 1 (扩胞+化合价): ✅ 已完成 ({status['total_structures']} 个结构)")
if status['has_zeo_results']:
print(f" Step 2-4 (Zeo++ 计算): ✅ 已完成 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
print(" Step 5 (结果筛选): ⏳ 可执行")
else:
print(f" Step 2-4 (Zeo++ 计算): ⏳ 可执行 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
print(" Step 5 (结果筛选): ❌ 需先完成 Zeo++ 计算")
print("" * 50)
def get_user_choice(status: dict) -> str:
"""
根据工作流程状态获取用户选择
Returns:
'step1': 从头开始(数据库分析 + 扩胞)
'step2': 从 Zeo++ 计算开始
'step5': 直接进行结果筛选
'exit': 退出
"""
print("\n请选择操作:")
options = []
if status['has_zeo_results']:
options.append(('5', '直接进行结果筛选 (Step 5)'))
options.append(('2', '重新运行 Zeo++ 计算 (Step 2-4)'))
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
elif status['has_processed_data']:
options.append(('2', '运行 Zeo++ 计算 (Step 2-4)'))
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
else:
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
options.append(('q', '退出'))
for key, desc in options:
print(f" {key}. {desc}")
choice = input("\n请选择 [默认: " + options[0][0] + "]: ").strip().lower()
if not choice:
choice = options[0][0]
if choice == 'q':
return 'exit'
elif choice == '5':
return 'step5'
elif choice == '2':
return 'step2'
else:
return 'step1'
def run_step1_database_analysis(env: dict, cation: str) -> dict:
"""
Step 1: 数据库分析与扩胞处理
Returns:
处理参数字典,如果用户取消则返回 None
"""
print("\n" + "" * 60)
print("【Step 1: 数据库分析与扩胞处理】")
print("" * 60)
# 数据库路径
while True:
db_path = input("\n📂 请输入数据库路径: ").strip()
if os.path.exists(db_path):
break
print(f"❌ 路径不存在: {db_path}")
# 检测当前Conda环境路径
default_conda = "/cluster/home/koko125/anaconda3/envs/screen"
conda_env_path = env.get('conda_env', '') or default_conda
print(f"\n检测到Conda环境: {conda_env_path}")
custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
if custom_env.lower() == 'n':
conda_env_path = input("请输入Conda环境完整路径: ").strip()
elif custom_env and custom_env.lower() != 'y':
conda_env_path = custom_env
# 目标阴离子
anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip()
anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
# 阴离子模式
print("\n阴离子模式:")
print(" 1. 仅单一阴离子")
print(" 2. 仅复合阴离子")
print(" 3. 全部 (默认)")
mode_choice = input("请选择 [1/2/3]: ").strip()
anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all')
# 并行配置
print("\n" + "" * 50)
print("【并行计算配置】")
default_cores = min(env['total_cores'], 32)
cores_input = input(f"💻 最大可用核数/Worker数 [默认: {default_cores}]: ").strip()
max_workers = int(cores_input) if cores_input.isdigit() else default_cores
params = {
'database_path': db_path,
'target_cation': cation,
'target_anions': anions,
'anion_mode': anion_mode,
'max_workers': max_workers,
'conda_env': conda_env_path,
}
print("\n" + "" * 60)
print("开始数据库分析...")
print("" * 60)
# 创建分析器
analyzer = DatabaseAnalyzer(
database_path=params['database_path'],
target_cation=params['target_cation'],
target_anions=params['target_anions'],
anion_mode=params['anion_mode'],
max_cores=params['max_workers'],
task_complexity='medium'
)
print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件")
# 执行分析
report = analyzer.analyze(show_progress=True)
# 打印报告
ReportGenerator.print_report(report, detailed=True)
# 保存选项
save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
if save_choice == 'y':
output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
output_path = output_path or "analysis_report.json"
report.save(output_path)
print(f"✅ 报告已保存到: {output_path}")
# CSV导出
csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
if csv_choice == 'y':
csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
csv_path = csv_path or "analysis_details.csv"
ReportGenerator.export_to_csv(report, csv_path)
# 生成最终数据库
process_choice = input("\n是否生成最终可用的数据库(扩胞+添加化合价)? [Y/n]: ").strip().lower()
if process_choice == 'n':
print("\n已跳过扩胞处理,可稍后继续")
return params
# 输出目录设置
print("\n输出目录设置:")
flat_dir = input(" 原始格式输出目录 [默认: workspace/processed]: ").strip()
flat_dir = flat_dir or "workspace/processed"
analysis_dir = input(" 分析格式输出目录 [默认: workspace/data]: ").strip()
analysis_dir = analysis_dir or "workspace/data"
# 扩胞保存数量
keep_input = input("\n扩胞结构保存数量 [默认: 1]: ").strip()
keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1
# 扩胞精度选择
print("\n扩胞计算精度:")
print(" 1. 高精度 (精确分数)")
print(" 2. 普通精度 (分母≤100)")
print(" 3. 低精度 (分母≤10) [默认]")
print(" 4. 极低精度 (分母≤5)")
precision_choice = input("请选择 [1/2/3/4]: ").strip()
calculate_type = {
'1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low'
}.get(precision_choice, 'low')
# 获取可处理文件
processable = report.get_processable_files(include_needs_expansion=True)
if not processable:
print("⚠️ 没有可处理的文件")
return params
print(f"\n发现 {len(processable)} 个可处理的文件")
# 准备文件列表和扩胞标记
input_files = [info.file_path for info in processable]
needs_expansion_flags = [info.needs_expansion for info in processable]
anion_types_list = [info.anion_types for info in processable]
direct_count = sum(1 for f in needs_expansion_flags if not f)
expansion_count = sum(1 for f in needs_expansion_flags if f)
print(f" - 可直接处理: {direct_count}")
print(f" - 需要扩胞: {expansion_count}")
print(f" - 扩胞保存数: {keep_number}")
confirm = input("\n确认开始处理? [Y/n]: ").strip().lower()
if confirm == 'n':
print("\n已取消处理")
return params
print("\n" + "" * 60)
print("开始处理结构文件...")
print("" * 60)
# 创建处理器
processor = StructureProcessor(
calculate_type=calculate_type,
keep_number=keep_number,
target_cation=params['target_cation']
)
# 创建输出目录
os.makedirs(flat_dir, exist_ok=True)
os.makedirs(analysis_dir, exist_ok=True)
results = []
total = len(input_files)
import shutil
for i, (input_file, needs_exp, anion_types) in enumerate(
zip(input_files, needs_expansion_flags, anion_types_list)
):
print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
# 处理文件到原始格式目录
result = processor.process_file(input_file, flat_dir, needs_exp)
results.append(result)
if result.success:
# 同时保存到分析格式目录
# 按阴离子类型创建子目录
anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other'
anion_dir = os.path.join(analysis_dir, anion_key)
os.makedirs(anion_dir, exist_ok=True)
# 获取基础文件名
base_name = os.path.splitext(os.path.basename(input_file))[0]
# 创建以文件名命名的子目录
file_dir = os.path.join(anion_dir, base_name)
os.makedirs(file_dir, exist_ok=True)
# 复制生成的文件到分析格式目录
for output_file in result.output_files:
dst_path = os.path.join(file_dir, os.path.basename(output_file))
shutil.copy2(output_file, dst_path)
print()
# 统计结果
success_count = sum(1 for r in results if r.success)
fail_count = sum(1 for r in results if not r.success)
total_output = sum(len(r.output_files) for r in results if r.success)
print("\n" + "-" * 60)
print("【处理结果统计】")
print("-" * 60)
print(f" 成功处理: {success_count}")
print(f" 处理失败: {fail_count}")
print(f" 生成文件: {total_output}")
print(f"\n 原始格式目录: {flat_dir}")
print(f" 分析格式目录: {analysis_dir}")
print(f" └── 结构: data/阴离子类型/文件名/文件名.cif")
# 显示失败的文件
if fail_count > 0:
print("\n失败的文件:")
for r in results:
if not r.success:
print(f" - {os.path.basename(r.input_file)}: {r.error_message}")
print("\n✅ Step 1 完成!")
return params
def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info = None):
"""
Step 2-4: Zeo++ Voronoi 分析
Args:
params: 参数字典,包含 target_cation 等
ws_manager: 工作区管理器(可选,如果已有则直接使用)
workspace_info: 工作区信息(可选,如果已有则直接使用)
"""
print("\n" + "" * 60)
print("【Step 2-4: Zeo++ Voronoi 分析】")
print("" * 60)
# 如果没有传入 ws_manager则创建新的
if ws_manager is None:
# 工作区路径
workspace_path = input("\n工作区路径 [默认: workspace]: ").strip() or "workspace"
tool_dir = input("工具目录路径 [默认: tool]: ").strip() or "tool"
# 创建工作区管理器
ws_manager = WorkspaceManager(
workspace_path=workspace_path,
tool_dir=tool_dir,
target_cation=params.get('target_cation', 'Li')
)
# 如果没有传入 workspace_info则检查现有工作区
if workspace_info is None:
existing = ws_manager.check_existing_workspace()
if existing and existing.total_structures > 0:
ws_manager.print_workspace_summary(existing)
workspace_info = existing
else:
print("⚠️ 工作区数据目录不存在或为空")
print(f" 请先运行 Step 1 生成处理后的数据到: {ws_manager.data_dir}")
return
# 检查是否需要创建软链接
if workspace_info.linked_structures < workspace_info.total_structures:
print("\n正在创建配置文件软链接...")
workspace_info = ws_manager.setup_workspace(force_relink=False)
# 获取计算任务
tasks = ws_manager.get_computation_tasks(workspace_info)
if not tasks:
print("⚠️ 没有找到可计算的任务")
return
print(f"\n发现 {len(tasks)} 个计算任务")
# Zeo++ 环境配置
print("\n" + "-" * 50)
print("【Zeo++ 计算配置】")
default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo"
zeo_env = input(f"Zeo++ Conda环境 [默认: {default_zeo_env}]: ").strip()
zeo_env = zeo_env or default_zeo_env
# SLURM 配置
partition = input("SLURM分区 [默认: cpu]: ").strip() or "cpu"
max_concurrent_input = input("最大并发任务数 [默认: 50]: ").strip()
max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50
time_limit = input("单任务时间限制 [默认: 2:00:00]: ").strip() or "2:00:00"
# 创建配置
zeo_config = ZeoConfig(
conda_env=zeo_env,
partition=partition,
max_concurrent=max_concurrent,
time_limit=time_limit
)
# 确认执行
print("\n" + "-" * 50)
print("【计算任务确认】")
print(f" 总任务数: {len(tasks)}")
print(f" Conda环境: {zeo_config.conda_env}")
print(f" SLURM分区: {zeo_config.partition}")
print(f" 最大并发: {zeo_config.max_concurrent}")
print(f" 时间限制: {zeo_config.time_limit}")
confirm = input("\n确认提交计算任务? [Y/n]: ").strip().lower()
if confirm == 'n':
print("已取消")
return
# 创建执行器并运行
executor = ZeoExecutor(zeo_config)
log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs")
results = executor.run_batch(tasks, output_dir=log_dir)
# 打印结果摘要
executor.print_results_summary(results)
# 保存结果
results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json")
results_data = [
{
'task_id': r.task_id,
'structure_name': r.structure_name,
'cif_path': r.cif_path,
'success': r.success,
'output_files': r.output_files,
'error_message': r.error_message
}
for r in results
]
with open(results_file, 'w') as f:
json.dump(results_data, f, indent=2)
print(f"\n结果已保存到: {results_file}")
print("\n✅ Step 2-4 完成!")
def run_step5_result_processing(workspace_path: str = "workspace"):
"""
Step 5: 结果处理与筛选
Args:
workspace_path: 工作区路径
"""
print("\n" + "" * 60)
print("【Step 5: 结果处理与筛选】")
print("" * 60)
# 创建结果处理器
processor = ResultProcessor(workspace_path=workspace_path)
# 获取筛选条件
print("\n设置筛选条件:")
print(" (直接回车使用默认值,输入 0 表示不限制)")
# 最小渗透直径(默认改为 1.0
perc_input = input(" 最小渗透直径 (Å) [默认: 1.0]: ").strip()
min_percolation = float(perc_input) if perc_input else 1.0
# 最小 d 值
d_input = input(" 最小 d 值 [默认: 2.0]: ").strip()
min_d = float(d_input) if d_input else 2.0
# 最大节点长度
node_input = input(" 最大节点长度 (Å) [默认: 不限制]: ").strip()
max_node = float(node_input) if node_input else float('inf')
# 创建筛选条件
criteria = FilterCriteria(
min_percolation_diameter=min_percolation,
min_d_value=min_d,
max_node_length=max_node
)
# 执行处理
results, stats = processor.process_and_filter(
criteria=criteria,
save_csv=True,
copy_passed=True
)
# 打印摘要
processor.print_summary(results, stats)
# 显示输出位置
print("\n输出文件位置:")
print(f" 汇总 CSV: {workspace_path}/results/summary.csv")
print(f" 分类 CSV: {workspace_path}/results/阴离子类型/阴离子类型.csv")
print(f" 通过筛选: {workspace_path}/passed/阴离子类型/结构名/")
print("\n✅ Step 5 完成!")
def main():
"""主函数"""
print_banner()
# 环境检测
env = detect_and_show_environment()
# 询问目标阳离子
cation = input("\n🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
# 检测工作流程状态
status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
print_workflow_status(status)
# 获取用户选择
choice = get_user_choice(status)
if choice == 'exit':
print("\n👋 再见!")
return
# 根据选择执行相应步骤
params = {'target_cation': cation}
if choice == 'step1':
# 从头开始
params = run_step1_database_analysis(env, cation)
if params is None:
return
# 询问是否继续
continue_choice = input("\n是否继续进行 Zeo++ 计算? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续 Zeo++ 计算")
return
# 重新检测状态
status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
# 询问是否继续筛选
continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续结果筛选")
return
run_step5_result_processing("workspace")
elif choice == 'step2':
# 从 Zeo++ 计算开始
run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
# 询问是否继续筛选
continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续结果筛选")
return
run_step5_result_processing("workspace")
elif choice == 'step5':
# 直接进行结果筛选
run_step5_result_processing("workspace")
print("\n✅ 全部完成!")
if __name__ == "__main__":
main()
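
The interactive prompts above are thin wrappers around the imported classes, so Step 5 can also be driven non-interactively. The snippet below is a minimal sketch using only the calls that already appear in run_step5_result_processing; the workspace path and thresholds mirror the defaults offered in the prompts and are otherwise arbitrary.

```python
# Minimal non-interactive driver for Step 5 (sketch; paths/thresholds are the
# same defaults the interactive prompts suggest, not mandated values).
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.computation.result_processor import ResultProcessor, FilterCriteria

if __name__ == "__main__":
    processor = ResultProcessor(workspace_path="workspace")
    criteria = FilterCriteria(
        min_percolation_diameter=1.0,   # minimum percolation diameter (Å)
        min_d_value=2.0,                # minimum d value
        max_node_length=float("inf"),   # no node-length limit
    )
    results, stats = processor.process_and_filter(criteria=criteria, save_csv=True, copy_passed=True)
    processor.print_summary(results, stats)
```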

85
main_property.sh Normal file

@@ -0,0 +1,85 @@
#!/bin/bash
# ==========================================
# End-to-end automation script (direct-pass screening version)
# ==========================================
# 1. Environment initialization
echo "============ Stage 0: Initialization ============"
chmod -R u+w ../Screen
source $(conda info --base)/etc/profile.d/conda.sh
# Activate the screen environment
conda activate ~/anaconda3/envs/screen
cd py/
export PYTHONPATH=$(pwd):$PYTHONPATH
# 2. Preprocessing and file organization (replaces the original Step 1)
echo "============ Stage 1: File Organization (Direct Pass) ============"
# Run preprocessing (optional; makes sure the input folder is ready)
python pre_process.py
# Run the direct-pass organization script
# Reads input, detects the anion type, and copies each structure to after_step1/Anion/ID/ID.cif
# Skips time-consuming checks such as check_basic
python step1_direct.py
# Generate the Zeo++ run scripts
# Walks after_step1 and generates analyze.sh and sh_all.sh
python make_sh.py
# 3. Run the Zeo++ calculations
echo "============ Stage 2: Zeo++ Calculations ============"
conda deactivate
conda activate ~/anaconda3/envs/zeo
# Enter the data directory
cd ../data/after_step1
if [ -f "sh_all.sh" ]; then
    # Run all calculations
    source sh_all.sh
    # Clean up the master script (optional)
    # rm sh_all.sh
else
    echo "Error: sh_all.sh not found! Please check Stage 1."
    exit 1
fi
# 4. Data extraction and higher-level analysis
echo "============ Stage 3: Data Extraction & Analysis ============"
# Switch back to the screen environment
conda deactivate
conda activate ~/anaconda3/envs/screen
cd ../../py
# 3.1 Extract the basic Zeo++ data
# Output: ../output/Anion/Anion.csv (with Perc, Min_d, Max_node)
python extract_data.py
# 3.2 Corner-sharing analysis
# Output: updates the CSVs with an Is_Only_Corner_Sharing column
echo "Running Corner Sharing Analysis..."
python analyze_cs.py
# 3.3 Combined screening
# Reads the CSVs, filters by the thresholds, and writes symlinks/files to ../data/after_screening
python step2_4_combined.py
# 3.4 CSM analysis (only for the screened materials)
# Output: ../output/CSM/Anion/ID.dat
echo "Running CSM Analysis..."
python analyze_csm.py
# 3.5 Tetrahedral occupancy statistics
# Reads the .dat files and updates the CSVs with a Tet_Li_Ratio column
echo "Updating Tetrahedral Li Ratio..."
python update_tet_occupancy.py
# 5. Done
echo "========================================================"
echo "All tasks completed!"
echo "Results stored in:"
echo " - CSV Data: ../output/"
echo " - Screened: ../data/after_screening/"
echo " - CSM Details: ../output/CSM/"
echo "========================================================"

224
py/CSM_reconstruct.py Normal file

@@ -0,0 +1,224 @@
import os
import sys
import numpy as np
import argparse
from tqdm import tqdm
from scipy.spatial import ConvexHull
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder
# ================= 配置区域 =================
# 建议使用绝对路径,避免找不到文件夹
INPUT_DIR = "../../solidstate-tools/corner-sharing/data/1209/input" # 请确保这里有你的 .cif 文件
OUTPUT_DIR = "../output/CSM"
TARGET_ELEMENT = 'Li'
ENV_TYPE = 'both'
# ===========================================
class HiddenPrints:
'''用于隐藏 pymatgen 繁杂的输出'''
def __enter__(self):
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._original_stdout
def non_elements(struct):
"""
【关键修复】保留卤素(F, Cl, Br, I) 和其他阴离子,防止氯化物结构被清空。
"""
# 这里加入了 F, Cl, Br, I, P, Se, Te 等
anions_to_keep = {"O", "S", "N", "F", "Cl", "Br", "I", "P", "Se", "Te", "As", "Sb", "C"}
stripped = struct.copy()
species_to_remove = [el.symbol for el in stripped.composition.elements
if el.symbol not in anions_to_keep]
if species_to_remove:
stripped.remove_species(species_to_remove)
return stripped
def site_env(coord, struct, sp="Li", envtype='both'):
stripped = non_elements(struct)
# 如果剥离后结构为空(例如纯金属锂),直接返回
if len(stripped) == 0:
return {'csm': np.nan, 'vol': np.nan, 'type': 'Error_NoAnions'}
with_li = stripped.copy()
# 插入一个探测用的 Li 原子
with_li.append(sp, coord, coords_are_cartesian=False, validate_proximity=False)
# 尝试排序,如果因为部分占据导致排序失败,则使用原始顺序
try:
with_li = with_li.get_sorted_structure()
except:
pass
tet_oct_competition = []
# ---------------- 四面体 (Tet) 检测 ----------------
if envtype == 'both' or envtype == 'tet':
for dist in np.linspace(1, 4, 601): # 扫描距离 1A 到 4A
neigh = with_li.get_neighbors(with_li.sites[0], dist)
if len(neigh) < 4:
continue
elif len(neigh) > 4:
break
neigh_coords = [i.coords for i in neigh]
try:
with HiddenPrints():
lgf = LocalGeometryFinder(only_symbols=["T:4"])
lgf.setup_structure(structure=with_li)
lgf.setup_local_geometry(isite=0, coords=neigh_coords)
site_volume = ConvexHull(neigh_coords).volume
# 获取 CSM
csm_val = lgf.get_coordination_symmetry_measures()['T:4']['csm']
tet_env = {'csm': csm_val, 'vol': site_volume, 'type': 'tet'}
tet_oct_competition.append(tet_env)
except Exception:
pass
if len(neigh) == 4: break
# ---------------- 八面体 (Oct) 检测 ----------------
if envtype == 'both' or envtype == 'oct':
for dist in np.linspace(1, 4, 601):
neigh = with_li.get_neighbors(with_li.sites[0], dist)
if len(neigh) < 6:
continue
elif len(neigh) > 6:
break
neigh_coords = [i.coords for i in neigh]
try:
with HiddenPrints():
lgf = LocalGeometryFinder(only_symbols=["O:6"], permutations_safe_override=False)
lgf.setup_structure(structure=with_li)
lgf.setup_local_geometry(isite=0, coords=neigh_coords)
site_volume = ConvexHull(neigh_coords).volume
csm_val = lgf.get_coordination_symmetry_measures()['O:6']['csm']
oct_env = {'csm': csm_val, 'vol': site_volume, 'type': 'oct'}
tet_oct_competition.append(oct_env)
except Exception:
pass
if len(neigh) == 6: break
# ---------------- 结果判定 ----------------
if len(tet_oct_competition) == 0:
return {'csm': np.nan, 'vol': np.nan, 'type': 'Non_' + envtype}
elif len(tet_oct_competition) == 1:
return tet_oct_competition[0]
elif len(tet_oct_competition) >= 2:
return min(tet_oct_competition, key=lambda x: x['csm'])
def extract_sites(struct, sp="Li", envtype='both'):
envlist = []
# 遍历所有位点寻找 Li
for i, site in enumerate(struct):
site_elements = [el.symbol for el in site.species.elements]
if sp in site_elements:
try:
# 传入结构副本以防修改原结构
singleenv = site_env(site.frac_coords, struct.copy(), sp, envtype)
envlist.append({
'site_index': i,
'frac_coords': site.frac_coords,
'type': singleenv.get('type', 'unknown'),
'csm': singleenv.get('csm', np.nan),
'volume': singleenv.get('vol', np.nan)
})
except Exception as e:
# 捕捉单个位点计算错误,不中断程序
# print(f" [Warn] Site {i} calculation failed: {e}")
pass
return envlist
def export_envs(envlist, sp, envtype, fname):
with open(fname, 'w') as f:
f.write('List of environment information\n')
f.write(f'Species : {sp}\n')
f.write(f'Envtype : {envtype}\n')
for item in envlist:
# 格式化输出,确保没有数据也能看懂
f.write(f"Site index {item['site_index']}: {item}\n")
# ================= 主程序 =================
def run_csm_analysis():
# 1. 检查目录
if not os.path.exists(INPUT_DIR):
print(f"错误: 输入目录不存在 -> {os.path.abspath(INPUT_DIR)}")
return
cif_files = []
for root, dirs, files in os.walk(INPUT_DIR):
for file in files:
if file.endswith(".cif"):
cif_files.append(os.path.join(root, file))
if not cif_files:
print(f"{INPUT_DIR} 中未找到 .cif 文件。")
return
print(f"开始分析 {len(cif_files)} 个文件 (目标元素: {TARGET_ELEMENT}, 包含阴离子: F,Cl,Br,I,O,S,N...)")
success_count = 0
for cif_path in tqdm(cif_files, desc="Calculating CSM"):
try:
# 准备路径
rel_path = os.path.relpath(cif_path, INPUT_DIR)
rel_dir = os.path.dirname(rel_path)
file_base = os.path.splitext(os.path.basename(cif_path))[0]
target_dir = os.path.join(OUTPUT_DIR, rel_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
target_dat_path = os.path.join(target_dir, f"{file_base}.dat")
# 如果文件已存在且不为空,可选择跳过
# if os.path.exists(target_dat_path) and os.path.getsize(target_dat_path) > 0:
# continue
# 读取结构
struct = Structure.from_file(cif_path)
# 检查是否含 Li
if Element(TARGET_ELEMENT) not in struct.composition.elements:
continue
# 计算环境
env_list = extract_sites(struct, sp=TARGET_ELEMENT, envtype=ENV_TYPE)
# 写入结果 (即使 env_list 为空也写入一个标记文件方便debug)
if env_list:
export_envs(env_list, sp=TARGET_ELEMENT, envtype=ENV_TYPE, fname=target_dat_path)
success_count += 1
else:
with open(target_dat_path, 'w') as f:
f.write(f"No {TARGET_ELEMENT} environments found (Check connectivity or anion types).")
except Exception as e:
print(f"\n[Error] File: {os.path.basename(cif_path)} -> {e}")
continue
print(f"\n分析完成!成功生成 {success_count} 个文件。")
print(f"输出目录: {os.path.abspath(OUTPUT_DIR)}")
if __name__ == "__main__":
run_csm_analysis()
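
For a single structure, the helpers defined above in py/CSM_reconstruct.py can be used directly; a small usage sketch follows (the CIF path is hypothetical, and the structure must contain the target element).

```python
# Usage sketch for py/CSM_reconstruct.py on one file (path is hypothetical).
from pymatgen.core import Structure

struct = Structure.from_file("example.cif")
envs = extract_sites(struct, sp="Li", envtype="both")        # per-site CSM, volume, tet/oct label
export_envs(envs, sp="Li", envtype="both", fname="example.dat")
for env in envs:
    print(env["site_index"], env["type"], env["csm"])
```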

118
py/CS_catulate.py Normal file

@@ -0,0 +1,118 @@
import os
import pandas as pd
from pymatgen.core import Structure
# 确保你的 utils 文件夹在 py 目录下,并且包含 CS_analyse.py
from utils.CS_analyse import CS_catulate, check_only_corner_sharing
from tqdm import tqdm
# 配置路径
CSV_ROOT_DIR = "../output"
DATA_SOURCE_DIR = "../data/after_step1"
def get_cif_path(group_name, anion_name, material_id):
"""
根据 CSV 的层级信息构建 CIF 文件的绝对路径
"""
# 构建路径: ../data/after_step1/Group/Anion/ID/ID.cif
# 注意处理单阴离子情况 (Group == Anion)
if group_name == anion_name:
# 路径: ../data/after_step1/S/123/123.cif
rel_path = os.path.join(DATA_SOURCE_DIR, group_name, material_id, f"{material_id}.cif")
else:
# 路径: ../data/after_step1/S+O/S/123/123.cif
rel_path = os.path.join(DATA_SOURCE_DIR, group_name, anion_name, material_id, f"{material_id}.cif")
return os.path.abspath(rel_path)
def process_single_csv(csv_path, group_name, anion_name):
"""
处理单个 CSV 文件:读取 -> 计算角共享 -> 添加列 -> 保存
"""
print(f"正在处理 CSV: {csv_path}")
# 读取 CSV强制 ID 为字符串
try:
df = pd.read_csv(csv_path, dtype={'Filename': str})
except Exception as e:
print(f"读取 CSV 失败: {e}")
return
# 检查是否已经存在该列,如果存在且想重新计算,可以先删除,或者跳过
if 'Is_Only_Corner_Sharing' in df.columns:
print(" - 'Is_Only_Corner_Sharing' 列已存在,将覆盖更新。")
results = []
# 使用 tqdm 显示进度
for index, row in tqdm(df.iterrows(), total=df.shape[0], desc=f"Analyzing {anion_name}"):
material_id = str(row['Filename']).replace('.0', '')
cif_path = get_cif_path(group_name, anion_name, material_id)
cs_result = None # 默认值
if os.path.exists(cif_path):
try:
# 1. 加载结构
struct = Structure.from_file(cif_path)
# 2. 计算共享关系 (默认检测 Li 和常见阴离子)
# 你可以根据需要调整 anion 列表,或者动态使用 anion_name
target_anions = ['O', 'S', 'Cl', 'F', 'Br', 'I', 'N', 'P']
sharing_details = CS_catulate(struct, sp='Li', anion=target_anions)
# 3. 判断是否仅角共享 (返回 1 或 0 或 True/False)
# 根据你提供的截图,似乎是返回 0 或 1
is_only_corner = check_only_corner_sharing(sharing_details)
cs_result = is_only_corner
except Exception as e:
# print(f"计算出错 {material_id}: {e}")
cs_result = "Error"
else:
print(f" - 警告: 找不到 CIF 文件 {cif_path}")
cs_result = "File_Not_Found"
results.append(cs_result)
# 将结果添加为新列
df['Is_Only_Corner_Sharing'] = results
# 保存覆盖原文件
df.to_csv(csv_path, index=False)
print(f" - 已更新 CSV: {csv_path}")
def run_cs_analysis():
"""
遍历所有 CSV 并运行分析
"""
if not os.path.exists(CSV_ROOT_DIR):
print(f"CSV 根目录不存在: {CSV_ROOT_DIR}")
return
for root, dirs, files in os.walk(CSV_ROOT_DIR):
for file in files:
if file.endswith(".csv"):
csv_path = os.path.join(root, file)
# 解析 Group 和 Anion (用于定位 CIF)
rel_root = os.path.relpath(root, CSV_ROOT_DIR)
path_parts = rel_root.split(os.sep)
if len(path_parts) == 1:
group_name = path_parts[0]
anion_name = path_parts[0]
elif len(path_parts) >= 2:
group_name = path_parts[0]
anion_name = path_parts[1]
else:
continue
process_single_csv(csv_path, group_name, anion_name)
if __name__ == "__main__":
run_cs_analysis()
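
As a quick illustration of the two layouts that get_cif_path resolves (the group, anion and ID values here are hypothetical):

```python
print(get_cif_path("S", "S", "123"))
# -> .../data/after_step1/S/123/123.cif        (single-anion group: Group == Anion)
print(get_cif_path("S+O", "S", "123"))
# -> .../data/after_step1/S+O/S/123/123.cif    (mixed-anion group)
```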

90
py/csm.py Normal file

@@ -0,0 +1,90 @@
import os
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element
# 导入你的CSM计算工具库 (根据 provided context [11])
try:
from utils.analyze_env_st import extract_sites, export_envs
except ImportError:
print("Error: 找不到 utils.analyze_env_st 模块,请检查 utils 文件夹。")
exit()
from tqdm import tqdm
# ================= 配置区域 =================
# 输入目录:使用筛选后的目录,只计算符合要求的材料
INPUT_DIR = "../../solidstate-tools/corner-sharing/data/1209/input"
# 输出目录
OUTPUT_DIR = "../output/CSM"
# 分析参数
TARGET_ELEMENT = 'Na'
ENV_TYPE = 'both' # 可选 'tet', 'oct', 'both'
# ===========================================
def run_csm_analysis():
"""
遍历 after_screening 文件夹,计算 CSM 并生成 .dat 文件到 output/CSM
"""
if not os.path.exists(INPUT_DIR):
print(f"输入目录不存在: {INPUT_DIR},请先运行筛选步骤。")
return
# 收集所有需要处理的 CIF 文件
cif_files = []
for root, dirs, files in os.walk(INPUT_DIR):
for file in files:
if file.endswith(".cif"):
# 保存完整路径
cif_files.append(os.path.join(root, file))
print(f"开始进行 CSM 分析,共找到 {len(cif_files)} 个筛选后的材料...")
for cif_path in tqdm(cif_files, desc="Calculating CSM"):
try:
# 1. 确定输出路径,保持目录结构
# 获取相对路径 (例如: S/195819.cif 或 S+O/S/195819.cif)
rel_path = os.path.relpath(cif_path, INPUT_DIR)
# 获取所在文件夹 (例如: S 或 S+O/S)
rel_dir = os.path.dirname(rel_path)
# 获取文件名 (例如: 195819)
file_base = os.path.splitext(os.path.basename(cif_path))[0]
# 构建目标文件夹: ../output/CSM/S/
target_dir = os.path.join(OUTPUT_DIR, rel_dir)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
# 构建目标文件路径: ../output/CSM/S/195819.dat
target_dat_path = os.path.join(target_dir, f"{file_base}.dat")
# 2. 如果已经存在,跳过 (可选,视需求而定,这里默认覆盖)
# if os.path.exists(target_dat_path):
# continue
# 3. 读取结构
struct = Structure.from_file(cif_path)
# 检查是否包含目标元素 (Li)
if Element(TARGET_ELEMENT) not in struct.composition.elements:
# print(f"Skipping {file_base}: No {TARGET_ELEMENT}")
continue
# 4. 计算 CSM (引用 utils 中的函数)
# extract_sites 返回环境列表
env_list = extract_sites(struct, sp=TARGET_ELEMENT, envtype=ENV_TYPE)
# 5. 导出结果 (引用 utils 中的函数)
# export_envs 将结果写入 .dat 文件
if env_list:
export_envs(env_list, sp=TARGET_ELEMENT, envtype=ENV_TYPE, fname=target_dat_path)
else:
# 如果没有提取到环境(例如没有配位环境),生成一个空文件或记录日志
with open(target_dat_path, 'w') as f:
f.write("No environments found.")
except Exception as e:
print(f"处理出错 {cif_path}: {e}")
print(f"CSM 分析完成,结果已保存至 {OUTPUT_DIR}")
if __name__ == "__main__":
run_csm_analysis()


@@ -1,69 +1,113 @@
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element, Specie
from pymatgen.io.cif import CifWriter
from crystal_2 import crystal
import crystal_2
import os
import shutil
def get_anion_type(structure):
"""
判断阴离子类型。
仅识别 O, S, Cl, Br 及其组合。
其他非金属元素(如 P, N, F 等)将被忽略。
"""
# 仅保留这四种目标阴离子
valid_anions = {'O', 'S', 'Cl', 'Br'}
# 获取结构中的所有元素符号
elements = set([e.symbol for e in structure.composition.elements])
# 取交集找到当前结构包含的目标阴离子
found_anions = elements.intersection(valid_anions)
if not found_anions:
return "Unknown"
# 如果有多个阴离子,按字母顺序排序并用 '+' 连接
sorted_anions = sorted(list(found_anions))
return "+".join(sorted_anions)
def read_files_check_basic(folder_path):
file_contents = []
"""
读取 CIF 文件,进行基础检查 (check_basic)
通过筛选后按自定义阴离子规则分类并整理到 after_step1 文件夹。
"""
# 输出基础路径
output_base = "../data/after_step1"
if not os.path.exists(folder_path):
print(f"{folder_path} 文件夹不存在")
return file_contents
return
for filename in os.listdir(folder_path):
# 确保输出目录存在
if not os.path.exists(output_base):
os.makedirs(output_base)
cif_files = [f for f in os.listdir(folder_path) if f.endswith(".cif")]
print(f"{folder_path} 发现 {len(cif_files)} 个 CIF 文件,开始筛选与整理...")
count_pass = 0
for filename in cif_files:
file_path = os.path.join(folder_path, filename)
if os.path.isfile(file_path):
try:
temp = crystal(file_path)
file_contents.append(temp)
except Exception as e:
print(e)
continue # 如果出错跳过当前循环避免temp未定义报错
print(f"正在处理{filename}")
# 1. 调用 crystal_2 进行基础筛选
try:
temp = crystal(file_path)
# 进行基础检查 (电荷平衡、化学式检查等)
temp.check_basic()
if temp.check_basic_result:
# 获取不带后缀的文件名,用于创建同名文件夹
file_base_name = os.path.splitext(filename)[0]
if not temp.check_basic_result:
print(f"Skipped: {filename} (未通过 check_basic)")
continue
if not "+" in temp.anion:
# 单一阴离子情况
# 路径变为: ../data/after_step1/Anion/FileBaseName/
base_anion_folder = os.path.join("../data/after_step1", f"{temp.anion}")
target_folder = os.path.join(base_anion_folder, file_base_name)
except Exception as e:
print(f"Error checking {filename}: {e}")
continue
# 2. 筛选通过,进行分类整理
try:
print(f"Processing: {filename} (Passed)")
count_pass += 1
# 为了确保分类逻辑与 Direct 版本一致,重新读取结构判断阴离子
# (忽略 crystal_2 内部可能基于 P/N 等元素的命名)
struct = Structure.from_file(file_path)
anion_type = get_anion_type(struct)
# 获取不带后缀的文件名 (ID)
file_base_name = os.path.splitext(filename)[0]
# --- 构建目标路径逻辑 (Anion/ID/ID.cif) ---
if "+" in anion_type:
# 混合阴离子情况 (如 S+O)
# 分别复制到 S+O/S 和 S+O/O 下
sub_anions = anion_type.split("+")
for sub in sub_anions:
# 路径: ../data/after_step1/S+O/S/123/123.cif
target_folder = os.path.join(output_base, anion_type, sub, file_base_name)
if not os.path.exists(target_folder):
os.makedirs(target_folder)
# 目标文件路径
target_file_path = os.path.join(target_folder, filename)
# 复制文件到目标文件夹
shutil.copy(file_path, target_file_path)
print(f"文件 {filename}通过基本筛选,已复制到 {target_folder}")
else:
# 混合阴离子情况
anions = temp.anion.split("+")
for anion in anions:
# 路径变为: ../data/after_step1/AnionCombination/Anion/FileBaseName/
base_group_folder = os.path.join("../data/after_step1", f"{temp.anion}")
base_anion_folder = os.path.join(base_group_folder, anion)
target_folder = os.path.join(base_anion_folder, file_base_name)
target_file = os.path.join(target_folder, filename)
shutil.copy(file_path, target_file)
else:
# 单一阴离子或 Unknown: ../data/after_step1/S/123/123.cif
target_folder = os.path.join(output_base, anion_type, file_base_name)
if not os.path.exists(target_folder):
os.makedirs(target_folder)
if not os.path.exists(target_folder):
os.makedirs(target_folder)
target_file = os.path.join(target_folder, filename)
shutil.copy(file_path, target_file)
# 目标文件路径
target_file_path = os.path.join(target_folder, filename)
# 复制文件到目标文件夹
shutil.copy(file_path, target_file_path)
print(f"文件 {filename}通过基本筛选,已复制到 {target_folder}")
except Exception as e:
print(f"Error copying {filename}: {e}")
print(f"处理完成。共 {len(cif_files)} 个文件,通过筛选 {count_pass} 个。")
if __name__ == "__main__":
read_files_check_basic("../data/input")
# 根据你的 readmeMP数据在 input_preICSD在 input
# 这里默认读取 input你可以根据实际情况修改
read_files_check_basic("../../solidstate-tools/corner-sharing/data/1209/input")

103
py/step1_direct.py Normal file

@@ -0,0 +1,103 @@
import os
import shutil
from pymatgen.core import Structure
def get_anion_type(structure):
"""
判断阴离子类型。
仅识别 O, S, Cl, Br 及其组合。
其他非金属元素(如 P, N, F 等)将被忽略:
- Li3PS4 (含 P, S) -> 识别为 S
- LiFePO4 (含 P, O) -> 识别为 O
- Li3P (仅 P) -> 识别为 Unknown
"""
# --- 修改处:仅保留这四种目标阴离子 ---
valid_anions = {'O', 'S', 'Cl', 'Br'}
# 获取结构中的所有元素符号
elements = set([e.symbol for e in structure.composition.elements])
# 取交集找到当前结构包含的目标阴离子
found_anions = elements.intersection(valid_anions)
if not found_anions:
return "Unknown"
# 如果有多个阴离子,按字母顺序排序并用 '+' 连接
sorted_anions = sorted(list(found_anions))
return "+".join(sorted_anions)
def organize_files_direct(input_folder, output_base):
if not os.path.exists(input_folder):
print(f"输入文件夹不存在: {input_folder}")
return
# 确保输出目录存在
if not os.path.exists(output_base):
os.makedirs(output_base)
cif_files = [f for f in os.listdir(input_folder) if f.endswith(".cif")]
print(f"发现 {len(cif_files)} 个 CIF 文件,开始直接整理...")
count_dict = {}
for filename in cif_files:
file_path = os.path.join(input_folder, filename)
try:
# 读取结构分类
struct = Structure.from_file(file_path)
anion_type = get_anion_type(struct)
# 统计一下分类情况(可选)
count_dict[anion_type] = count_dict.get(anion_type, 0) + 1
# 获取不带后缀的文件名 (ID)
file_base_name = os.path.splitext(filename)[0]
# --- 构建目标路径逻辑 ---
# 目标: ../data/after_step1 / AnionType / ID / ID.cif
if "+" in anion_type:
# 混合阴离子情况 (如 S+O)
# 将文件复制到 S+O 下的各个子阴离子文件夹中 (S+O/S/ID/ID.cif 和 S+O/O/ID/ID.cif)
# 这样既保留了组合关系,又方便后续脚本按元素查找
sub_anions = anion_type.split("+")
for sub in sub_anions:
# 路径: after_step1/S+O/S/123/123.cif
target_folder = os.path.join(output_base, anion_type, sub, file_base_name)
if not os.path.exists(target_folder):
os.makedirs(target_folder)
target_file = os.path.join(target_folder, filename)
shutil.copy(file_path, target_file)
# print(f"整理: {filename} -> {anion_type} (Split)")
else:
# 单一阴离子或 Unknown: after_step1/S/123/123.cif
target_folder = os.path.join(output_base, anion_type, file_base_name)
if not os.path.exists(target_folder):
os.makedirs(target_folder)
target_file = os.path.join(target_folder, filename)
shutil.copy(file_path, target_file)
# print(f"整理: {filename} -> {anion_type}")
except Exception as e:
print(f"处理 {filename} 失败: {e}")
print("整理完成。分类统计:")
for k, v in count_dict.items():
print(f" {k}: {v}")
if __name__ == "__main__":
# 输入路径
input_dir = "../../solidstate-tools/corner-sharing/data/1209/input" # 如果是MP数据请改为 ../data/input_pre
# 输出路径
output_dir = "../data/after_step1"
organize_files_direct(input_dir, output_dir)

129
py/update_tet_occupancy.py Normal file

@@ -0,0 +1,129 @@
import os
import pandas as pd
from tqdm import tqdm
# ================= 配置区域 =================
# CSV 所在的根目录
CSV_ROOT_DIR = "../output"
# CSM .dat 文件所在的根目录
CSM_ROOT_DIR = "../output/CSM"
# ===========================================
def calculate_tet_ratio_from_dat(dat_path):
"""
解析 .dat 文件,计算四面体位 Li 的占比。
返回: float (0.0 - 1.0) 或 None (如果文件不存在或为空)
"""
if not os.path.exists(dat_path):
return None
tet_count = 0
total_count = 0
try:
with open(dat_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
# 简单检查文件是否包含 "No environments found"
if len(lines) > 0 and "No environments found" in lines[0]:
return None
for line in lines:
# 根据截图,每行是一个位点的信息
# 简单字符串匹配,这比 eval 更安全且足够快
if "'type': 'tet'" in line:
tet_count += 1
total_count += 1
elif "'type': 'oct'" in line:
total_count += 1
# 如果还有其他类型,可以在这里加,或者只要是位点行都算进 total
if total_count == 0:
return 0.0
return round(tet_count / total_count, 4)
except Exception as e:
print(f"解析出错 {dat_path}: {e}")
return None
def process_single_csv(csv_path, group_name, anion_name):
"""
读取 CSV -> 寻找对应的 CSM dat 文件 -> 计算比例 -> 更新 CSV
"""
print(f"正在更新 CSV: {csv_path}")
# 读取 CSV确保 ID 是字符串
try:
df = pd.read_csv(csv_path, dtype={'Filename': str})
except Exception as e:
print(f"读取 CSV 失败: {e}")
return
tet_ratios = []
# 遍历 CSV 中的每一行
for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Updating Occupancy"):
material_id = str(row['Filename']).replace('.0', '')
# 构建对应的 .dat 文件路径
# 路径逻辑: ../output/CSM/Group/Anion/ID.dat
# 注意: 这里的 Group/Anion 结构必须与 analyze_csm.py 生成的一致
if group_name == anion_name:
# 单一阴离子: ../output/CSM/S/123.dat
dat_rel_path = os.path.join(group_name, f"{material_id}.dat")
else:
# 混合阴离子: ../output/CSM/S+O/S/123.dat
dat_rel_path = os.path.join(group_name, anion_name, f"{material_id}.dat")
dat_path = os.path.join(CSM_ROOT_DIR, dat_rel_path)
# 计算比例
ratio = calculate_tet_ratio_from_dat(dat_path)
tet_ratios.append(ratio)
# 添加或更新列
df['Tet_Li_Ratio'] = tet_ratios
# 保存
df.to_csv(csv_path, index=False)
print(f" - 已保存更新后的数据到: {csv_path}")
def run_update():
"""
主程序:遍历 output 目录下的 CSV
"""
if not os.path.exists(CSV_ROOT_DIR):
print(f"CSV 目录不存在: {CSV_ROOT_DIR}")
return
for root, dirs, files in os.walk(CSV_ROOT_DIR):
for file in files:
if file.endswith(".csv"):
csv_path = os.path.join(root, file)
# 解析路径获取 Group 和 Anion
# root: ../output/S --> rel: S
rel_root = os.path.relpath(root, CSV_ROOT_DIR)
path_parts = rel_root.split(os.sep)
if len(path_parts) == 1:
group_name = path_parts[0]
anion_name = path_parts[0]
elif len(path_parts) >= 2:
group_name = path_parts[0]
anion_name = path_parts[1]
else:
continue
# 只有当 CSM 目录里有对应的文件夹时才处理(可选)
process_single_csv(csv_path, group_name, anion_name)
if __name__ == "__main__":
run_update()
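
The .dat lines this parser expects are the ones export_envs writes ("Site index N: {...}"). Below is a self-contained sanity check on a made-up file; it should print 0.5 (one tet site out of two classified sites).

```python
# Sanity check for calculate_tet_ratio_from_dat on fabricated example data.
import os
import tempfile

sample = (
    "List of environment information\n"
    "Species : Li\n"
    "Envtype : both\n"
    "Site index 0: {'site_index': 0, 'type': 'tet', 'csm': 0.8, 'volume': 5.1}\n"
    "Site index 1: {'site_index': 1, 'type': 'oct', 'csm': 1.2, 'volume': 11.3}\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".dat", delete=False) as f:
    f.write(sample)
    path = f.name

print(calculate_tet_ratio_from_dat(path))   # 0.5
os.remove(path)
```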

356
py/utils/CS_analyse.py Normal file

@@ -0,0 +1,356 @@
from typing import List, Dict
from pymatgen.core.structure import Structure
from pymatgen.analysis.local_env import VoronoiNN
import numpy as np
def check_real(nearest):
real_nearest = []
for site in nearest:
if np.all((site.frac_coords >= 0) & (site.frac_coords <= 1)):
real_nearest.append(site)
return real_nearest
def special_check_for_3(site, nearest):
real_nearest = []
distances = []
for site2 in nearest:
distance = np.linalg.norm(np.array(site.frac_coords) - np.array(site2.frac_coords))
distances.append(distance)
sorted_indices = np.argsort(distances)
for index in sorted_indices[:3]:
real_nearest.append(nearest[index])
return real_nearest
def CS_catulate(
struct,
sp: str = 'Li',
anion: List[str] = ['O'],
tol: float = 0,
cutoff: float = 3.0,
notice: bool = False
) -> Dict[str, Dict[str, int]]:
"""
计算结构中不同类型阳离子多面体之间的共享关系(角、边、面共享)。
该函数会分别计算以下三种情况的共享数量:
1. 目标原子 vs 目标原子 (e.g., Li-Li)
2. 目标原子 vs 其他阳离子 (e.g., Li-X)
3. 其他阳离子 vs 其他阳离子 (e.g., X-Y)
参数:
struct (Structure): 输入的pymatgen结构对象。
sp (str): 目标元素符号,默认为 'Li'
anion (list): 阴离子元素符号列表,默认为 ['O']。
tol (float): VoronoiNN 的容差。对于Li通常设为0。
cutoff (float): VoronoiNN 的截断距离。对于Li通常设为3.0。
notice (bool): 是否打印详细的共享信息。
返回:
dict: 一个字典,包含三类共享关系的统计结果。
"sp_vs_sp", "sp_vs_other", "other_vs_other" 分别对应上述三种情况。
每个键的值是另一个字典统计了共享2个(边)、3个(面)等情况的数量。
例如: {'sp_vs_sp': {'1': 10, '2': 4}, 'sp_vs_other': ...}
共享1个阴离子为角共享2个为边共享3个为面共享。
"""
# 初始化 VoronoiNN 对象
voro_nn = VoronoiNN(tol=tol, cutoff=cutoff)
# 1. 分类存储所有阳离子的近邻阴离子信息
target_sites_info = []
other_cation_sites_info = []
for index, site in enumerate(struct.sites):
# 跳过阴离子本身
if site.species.chemical_system in anion:
continue
# 获取当前位点的近邻阴离子
try:
# 使用 get_nn_info 更直接
nn_info = voro_nn.get_nn_info(struct, index)
nearest_anions = [
nn["site"] for nn in nn_info
if nn["site"].species.chemical_system in anion
]
except Exception as e:
print(f"Warning: Could not get neighbors for site {index} ({site.species_string}): {e}")
continue
if not nearest_anions:
continue
# 整理信息
site_info = {
'index': index,
'element': site.species.chemical_system,
'nearest_anion_indices': {nn.index for nn in nearest_anions}
}
# 根据是否为目标原子进行分类
if site.species.chemical_system == sp:
target_sites_info.append(site_info)
else:
other_cation_sites_info.append(site_info)
# 2. 初始化结果字典
# 共享数量key: 1-角, 2-边, 3-面
results = {
"sp_vs_sp": {"1": 0, "2": 0, "3": 0, "4": 0},
"sp_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
"other_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
}
# 3. 计算不同类别之间的共享关系
# 3.1 目标原子 vs 目标原子 (sp_vs_sp)
for i in range(len(target_sites_info)):
for j in range(i + 1, len(target_sites_info)):
atom_i = target_sites_info[i]
atom_j = target_sites_info[j]
shared_anions = atom_i['nearest_anion_indices'].intersection(atom_j['nearest_anion_indices'])
shared_count = len(shared_anions)
if shared_count > 0 and str(shared_count) in results["sp_vs_sp"]:
results["sp_vs_sp"][str(shared_count)] += 1
if notice:
print(
f"[Li-Li] Atom {atom_i['index']} and {atom_j['index']} share {shared_count} anions: {shared_anions}")
# 3.2 目标原子 vs 其他阳离子 (sp_vs_other)
for atom_sp in target_sites_info:
for atom_other in other_cation_sites_info:
shared_anions = atom_sp['nearest_anion_indices'].intersection(atom_other['nearest_anion_indices'])
shared_count = len(shared_anions)
if shared_count > 0 and str(shared_count) in results["sp_vs_other"]:
results["sp_vs_other"][str(shared_count)] += 1
if notice:
print(
f"[Li-Other] Atom {atom_sp['index']} and {atom_other['index']} share {shared_count} anions: {shared_anions}")
# 3.3 其他阳离子 vs 其他阳离子 (other_vs_other)
for i in range(len(other_cation_sites_info)):
for j in range(i + 1, len(other_cation_sites_info)):
atom_i = other_cation_sites_info[i]
atom_j = other_cation_sites_info[j]
shared_anions = atom_i['nearest_anion_indices'].intersection(atom_j['nearest_anion_indices'])
shared_count = len(shared_anions)
if shared_count > 0 and str(shared_count) in results["other_vs_other"]:
results["other_vs_other"][str(shared_count)] += 1
if notice:
print(
f"[Other-Other] Atom {atom_i['index']} and {atom_j['index']} share {shared_count} anions: {shared_anions}")
return results
def CS_catulate_old(struct, sp='Li', anion=['O'], tol=0, cutoff=3.0,notice=False,ID=None):
"""
计算结构中目标元素与最近阴离子的共享关系。
参数:
struct (Structure): 输入结构。
sp (str): 目标元素符号,默认为 'Li'
anion (list): 阴离子列表,默认为 ['O']。
tol (float): VoronoiNN 的容差,默认为 0。
cutoff (float): VoronoiNN 的截断距离,默认为 3.0。
返回:
list: 包含每个目标位点及其最近阴离子索引的列表。
"""
# 初始化 VoronoiNN 对象
if sp=='Li':
tol = 0
cutoff = 3.0
voro_nn = VoronoiNN(tol=tol, cutoff=cutoff)
# 初始化字典,用于统计共享关系
shared_count = {"2": 0, "3": 0,"4":0,"5":0,"6":0}
# 存储结果的列表
atom_dice = []
# 遍历结构中的每个位点
for index,site in enumerate(struct.sites):
# 跳过阴离子位点
if site.species.chemical_system in anion:
continue
# 跳过Li原子
if site.species.chemical_system == sp:
continue
# 获取 Voronoi 多面体信息
voro_info = voro_nn.get_voronoi_polyhedra(struct, index)
# 找到最近的阴离子位点
nearest_anions = [
nn_info["site"] for nn_info in voro_info.values()
if nn_info["site"].species.chemical_system in anion
]
# 如果没有找到最近的阴离子,跳过
if not nearest_anions:
print(f"No nearest anions found for {ID} site {index}.")
continue
if site.species.chemical_system == 'B' or site.species.chemical_system == 'N':
nearest_anions = special_check_for_3(site,nearest_anions)
nearest_anions = check_real(nearest_anions)
# 将结果添加到 atom_dice 列表中
atom_dice.append({
'index': index,
'nearest_index': [nn.index for nn in nearest_anions]
})
# 枚举 atom_dice 中的所有原子对
for i, atom_i in enumerate(atom_dice):
for j, atom_j in enumerate(atom_dice[i + 1:], start=i + 1):
# 获取两个原子的最近阴离子索引
nearest_i = set(atom_i['nearest_index'])
nearest_j = set(atom_j['nearest_index'])
# 比较最近阴离子的交集大小
shared_count_key = str(len(nearest_i & nearest_j))
# 更新字典中的计数
if shared_count_key in shared_count:
shared_count[shared_count_key] += 1
if notice:
if shared_count_key=='2':
print(f"{atom_j['index']}{atom_i['index']}之间存在共线")
print(f"共线的阴离子为{nearest_i & nearest_j}")
if shared_count_key=='3':
print(f"{atom_j['index']}{atom_i['index']}之间存在共面")
print(f"共面的阴离子为{nearest_i & nearest_j}")
# # 最后将字典中的值除以 2因为每个共享关系被计算了两次
# for key in shared_count.keys():
# shared_count[key] //= 2
return shared_count
def CS_count(struct, sharing_results: Dict[str, Dict[str, int]], sp: str = 'Li') -> float:
"""
分析多面体共享结果,计算平均每个目标原子参与的共享阴离子数。
这个函数是 calculate_polyhedra_sharing 的配套函数。
参数:
struct (Structure): 输入的pymatgen结构对象用于统计目标原子总数。
sharing_results (dict): 来自 calculate_polyhedra_sharing 函数的输出结果。
sp (str): 目标元素符号,默认为 'Li'
返回:
float: 平均每个目标原子sp参与的共享阴离子数量。
例如结果为2.5意味着平均每个Li原子通过共享与其他阳离子
包括Li和其他阳离子连接了2.5个阴离子。
"""
# 1. 统计结构中目标原子的总数
target_atom_count = 0
for site in struct.sites:
if site.species.chemical_system == sp:
target_atom_count += 1
# 如果结构中没有目标原子直接返回0避免除以零错误
if target_atom_count == 0:
return 0.0
# 2. 计算加权的共享阴离子总数
total_shared_anions = 0
# 处理 sp_vs_sp (例如 Li-Li) 的共享
# 每个共享关系涉及两个目标原子,所以权重需要乘以 2
if "sp_vs_sp" in sharing_results:
sp_vs_sp_counts = sharing_results["sp_vs_sp"]
for num_shared_str, count in sp_vs_sp_counts.items():
num_shared = int(num_shared_str)
# 权重 = 共享阴离子数 * 涉及的目标原子数 (2) * 出现次数
total_shared_anions += num_shared * 2 * count
# 处理 sp_vs_other (例如 Li-X) 的共享
# 每个共享关系涉及一个目标原子,所以权重乘以 1
if "sp_vs_other" in sharing_results:
sp_vs_other_counts = sharing_results["sp_vs_other"]
for num_shared_str, count in sp_vs_other_counts.items():
num_shared = int(num_shared_str)
# 权重 = 共享阴离子数 * 涉及的目标原子数 (1) * 出现次数
total_shared_anions += num_shared * 1 * count
# 3. 计算平均值
# 平均每个目标原子参与的共享阴离子数 = 总的加权共享数 / 目标原子总数
average_sharing_per_atom = total_shared_anions / target_atom_count
return average_sharing_per_atom
def CS_count_old(struct, shared_count, sp='Li'):
count = 0
for site in struct.sites:
if site.species.chemical_system == sp:
count += 1 # 累加符合条件的原子数量
CS_count = 0
for i in range(2, 7): # 遍历范围 [2, 3, 4, 5]
if str(i) in shared_count: # 检查键是否存在
CS_count += shared_count[str(i)] * i # 累加计算结果
if count > 0: # 防止除以零
CS_count /= count # 平均化结果
else:
CS_count = 0 # 如果 count 为 0直接返回 0
return CS_count
def check_only_corner_sharing(sharing_results: Dict[str, Dict[str, int]]) -> int:
"""
检查目标原子(sp)是否只参与了角共享共享1个阴离子
该函数是 calculate_polyhedra_sharing 的配套函数。
参数:
sharing_results (dict): 来自 calculate_polyhedra_sharing 函数的输出结果。
返回:
int:
- 1: 如果 sp 的共享关系中,边共享(2)、面共享(3)等数量均为0
并且至少存在一个角共享(1)。
- 0: 如果 sp 存在任何边、面等共享,或者没有任何共享关系。
"""
# 提取与目标原子 sp 相关的共享数据
sp_vs_sp_counts = sharing_results.get("sp_vs_sp", {})
sp_vs_other_counts = sharing_results.get("sp_vs_other", {})
# 1. 检查是否存在任何边共享、面共享等 (共享数 > 1)
# 检查 sp-sp 的共享
for num_shared_str, count in sp_vs_sp_counts.items():
if int(num_shared_str) > 1 and count > 0:
return 0 # 发现了边/面共享,立即返回 0
# 检查 sp-other 的共享
for num_shared_str, count in sp_vs_other_counts.items():
if int(num_shared_str) > 1 and count > 0:
return 0 # 发现了边/面共享,立即返回 0
# 2. 检查是否存在至少一个角共享 (共享数 == 1)
# 运行到这里,说明已经没有任何边/面共享了。
# 现在需要确认是否真的存在角共享,而不是完全没有共享。
corner_share_sp_sp = sp_vs_sp_counts.get("1", 0) > 0
corner_share_sp_other = sp_vs_other_counts.get("1", 0) > 0
if corner_share_sp_sp or corner_share_sp_other:
return 1 # 确认只存在角共享
else:
return 0 # 没有任何共享关系,也返回 0
# structure = Structure.from_file("../raw/0921/wjy_001.cif")
# a = CS_catulate(structure,notice=True)
# b = CS_count(structure,a)
# print(f"{a}\n{b}")
# print(check_only_corner_sharing(a))
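
To make the bookkeeping in CS_count and check_only_corner_sharing concrete, here is a small worked example with synthetic input (the structure and sharing counts are made up; the import assumes the script is run from the py/ directory like the other scripts):

```python
# Worked example for utils/CS_analyse.py with synthetic data.
from pymatgen.core import Lattice, Structure
from utils.CS_analyse import CS_count, check_only_corner_sharing

# Hypothetical cell with two Li sites and two O sites.
struct = Structure(Lattice.cubic(4.0), ["Li", "Li", "O", "O"],
                   [[0, 0, 0], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], [0.75, 0.75, 0.75]])

# Pretend the Voronoi analysis found 2 Li-Li and 3 Li-other corner contacts, nothing else.
sharing = {
    "sp_vs_sp":       {"1": 2, "2": 0, "3": 0, "4": 0},
    "sp_vs_other":    {"1": 3, "2": 0, "3": 0, "4": 0},
    "other_vs_other": {"1": 0, "2": 0, "3": 0, "4": 0},
}

# CS_count: (1*2*2 + 1*1*3) / 2 Li atoms = 3.5 shared anions per Li on average.
print(CS_count(struct, sharing, sp="Li"))     # 3.5
# All counts for 2/3/4 shared anions are zero and corner sharing exists, so this is 1.
print(check_only_corner_sharing(sharing))     # 1
```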

0
py/utils/__init__.py Normal file

210
py/utils/analyze_env_st.py Normal file

@@ -0,0 +1,210 @@
#!/usr/bin/env python
# This code extracts the lithium environment of all of lithium sites provided in a structure file.
import os, sys
import numpy as np
import scipy
import argparse
from scipy.spatial import ConvexHull
from itertools import permutations
from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import *
from pymatgen.core.composition import *
from pymatgen.ext.matproj import MPRester
from pymatgen.io.vasp.outputs import *
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder
from pymatgen.analysis.chemenv.coordination_environments.structure_environments import LightStructureEnvironments
from pymatgen.analysis.chemenv.coordination_environments.chemenv_strategies import SimplestChemenvStrategy
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import *
__author__ = "KyuJung Jun"
__version__ = "0.1"
__maintainer__ = "KyuJung Jun"
__email__ = "kjun@berkeley.edu"
__status__ = "Development"
'''
Input for the script : path to the structure file supported by Pymatgen
Structures with partial occupancy should be ordered or modified to full occupancy by Pymatgen.
'''
parser = argparse.ArgumentParser()
parser.add_argument('structure', help='path to the structure file supported by Pymatgen', nargs='?')
parser.add_argument('envtype', help='both, tet, oct, choosing which perfect environment to reference to', nargs='?')
args = parser.parse_args()
class HiddenPrints:
'''
class to reduce the output lines
'''
def __enter__(self):
self._original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout.close()
sys.stdout = self._original_stdout
def non_elements(struct, sp='Li'):
"""
struct : 必须是一个有序结构
sp : the mobile specie
returns a new structure containing only the framework anions (O, S, N).
"""
anions_to_keep = {"O", "S", "N","Br","Cl"}
stripped = struct.copy()
species_to_remove = [el.symbol for el in stripped.composition.elements
if el.symbol not in anions_to_keep]
if species_to_remove:
stripped.remove_species(species_to_remove)
return stripped
def site_env(coord, struct, sp="Li", envtype='both'):
'''
coord : Fractional coordinate of the target atom
struct : structure object from Pymatgen
sp : the mobile specie
envtype : This sets the reference perfect structure. 'both' compares CSM_tet and CSM_oct and assigns to the lower one.
'tet' refers to the perfect tetrahedron and 'oct' refers to the perfect octahedron
result : a dictionary of environment information
'''
stripped = non_elements(struct)
with_li = stripped.copy()
with_li.append(sp, coord, coords_are_cartesian=False, validate_proximity=False)
with_li = with_li.get_sorted_structure()
tet_oct_competition = []
if envtype == 'both' or envtype == 'tet':
for dist in np.linspace(1, 4, 601):
neigh = with_li.get_neighbors(with_li.sites[0], dist)
if len(neigh) < 4:
continue
elif len(neigh) > 4:
break
neigh_coords = [i.coords for i in neigh]
with HiddenPrints():
lgf = LocalGeometryFinder(only_symbols=["T:4"])
lgf.setup_structure(structure=with_li)
lgf.setup_local_geometry(isite=0, coords=neigh_coords)
try:
site_volume = ConvexHull(neigh_coords).volume
tet_env_list = []
for i in range(20):
tet_env = {'csm': lgf.get_coordination_symmetry_measures()['T:4']['csm'], 'vol': site_volume,
'type': 'tet'}
tet_env_list.append(tet_env)
tet_env = min(tet_env_list, key=lambda x: x['csm'])
tet_oct_competition.append(tet_env)
except Exception as e:
print(e)
print("This site cannot be recognized as tetrahedral site")
if len(neigh) == 4:
break
if envtype == 'both' or envtype == 'oct':
for dist in np.linspace(1, 4, 601):
neigh = with_li.get_neighbors(with_li.sites[0], dist)
if len(neigh) < 6:
continue
elif len(neigh) > 6:
break
neigh_coords = [i.coords for i in neigh]
with HiddenPrints():
lgf = LocalGeometryFinder(only_symbols=["O:6"], permutations_safe_override=False)
lgf.setup_structure(structure=with_li)
lgf.setup_local_geometry(isite=0, coords=neigh_coords)
try:
site_volume = ConvexHull(neigh_coords).volume
oct_env_list = []
for i in range(20):
'''
20 times sampled in case of the algorithm "APPROXIMATE_FALLBACK" is used. Large number of permutations
are performed, but the default value in the function "coordination_geometry_symmetry_measures_fallback_random"
(NRANDOM=10) is often too small. This is not a problem if algorithm of "SEPARATION_PLANE" is used.
'''
oct_env = {'csm': lgf.get_coordination_symmetry_measures()['O:6']['csm'], 'vol': site_volume,
'type': 'oct'}
oct_env_list.append(oct_env)
oct_env = min(oct_env_list, key=lambda x: x['csm'])
tet_oct_competition.append(oct_env)
except Exception as e:
print(e)
print("This site cannot be recognized as octahedral site")
if len(neigh) == 6:
break
if len(tet_oct_competition) == 0:
return {'csm': np.nan, 'vol': np.nan, 'type': 'Non_' + envtype}
elif len(tet_oct_competition) == 1:
return tet_oct_competition[0]
elif len(tet_oct_competition) == 2:
csm1 = tet_oct_competition[0]
csm2 = tet_oct_competition[1]
if csm1['csm'] > csm2['csm']:
return csm2
else:
return csm1
def extract_sites(struct, sp="Li", envtype='both'):
"""
struct : structure object from Pymatgen
envtype : 'tet', 'oct', or 'both'
sp : target element to analyze environment
"""
envlist = []
# --- 关键修改:直接遍历原始结构,即使它是无序的 ---
# 我们不再调用 get_sorted_structure()
# 我们只关心那些含有目标元素 sp 的位点
# 遍历每一个位点 (site)
for i, site in enumerate(struct):
# 检查当前位点的组分(site.species)中是否包含我们感兴趣的元素(sp)
# site.species.elements 返回该位点上的元素列表,例如 [Element Li, Element Fe]
# [el.symbol for el in site.species.elements] 将其转换为符号列表 ['Li', 'Fe']
site_elements = [el.symbol for el in site.species.elements]
if sp in site_elements:
# 如果找到了Li我们就对这个位点进行环境分析
# 注意:我们将原始的、可能无序的 struct 传递给 site_env
# 因为 site_env 内部的函数 (如 LocalGeometryFinder) 知道如何处理它
# 为了让下游函数(特别是 non_elements能够工作
# 我们在这里创建一个一次性的、临时的有序结构副本给它
# 这可以避免我们之前遇到的所有 'ordered structures only' 错误
temp_ordered_struct = struct.get_sorted_structure()
singleenv = site_env(site.frac_coords, temp_ordered_struct, sp, envtype)
envlist.append({'frac_coords': site.frac_coords, 'type': singleenv['type'], 'csm': singleenv['csm'],
'volume': singleenv['vol']})
if not envlist:
print(f"警告: 在结构中未找到元素 {sp} 的占位。")
return envlist
def export_envs(envlist, sp='Li', envtype='both', fname=None):
'''
envlist : list of dictionaries of environment information
fname : Output file name
'''
if not fname:
fname = "extracted_environment_info" + "_" + sp + "_" + envtype + ".dat"
with open(fname, 'w') as f:
f.write('List of environment information\n')
f.write('Species : ' + sp + "\n")
f.write('Envtype : ' + envtype + "\n")
for index, i in enumerate(envlist):
f.write("Site index " + str(index) + ": " + str(i) + '\n')
# struct = Structure.from_file("../raw/0921/wjy_475.cif")
# site_info = extract_sites(struct, envtype="both")
# export_envs(site_info, sp="Li", envtype="both")
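# A minimal usage sketch, kept commented out like the example above (the CIF path is a placeholder,
# not a file shipped with this repository):
# if __name__ == "__main__":
#     struct = Structure.from_file("example.cif")                 # any CIF readable by pymatgen
#     site_info = extract_sites(struct, sp="Li", envtype="both")  # per-Li-site dicts: type / csm / volume
#     export_envs(site_info, sp="Li", envtype="both")             # writes extracted_environment_info_Li_both.dat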

readme.md
@@ -1,4 +1,4 @@
# 高通量筛选与扩胞项目
# 高通量筛选与扩胞项目 v2.3
## 环境配置需求
@@ -12,9 +12,29 @@
### 2. screen 环境 (用于逻辑筛选与数据处理)
* **Python**: 3.11.4
* **核心库**: `pymatgen==2024.11.13`, `pandas` (新增用于处理CSV)
* **路径**: `/cluster/home/koko125/anaconda3/envs/screen`
## 快速开始
### 方式一:使用新版主程序(推荐)
```bash
# 激活 screen 环境
conda activate /cluster/home/koko125/anaconda3/envs/screen
# 运行主程序
python main.py
```
主程序提供交互式界面,支持:
- 数据库分析与筛选
- 本地多进程并行
- SLURM 直接提交(无需生成脚本文件)
- 实时进度条显示
- 扩胞处理与化合价添加
### 方式二:传统方式
1. **数据准备**:
* 如果数据来源为 **Materials Project (MP)**,请将 CIF 文件放入 `data/input_pre`
* 如果数据来源为 **ICSD**,请直接将 CIF 文件放入 `data/input`
@@ -25,6 +45,52 @@
bash main.sh
```
## 新版功能特性 (v2.1)
### 执行模式
1. **本地多进程模式** (`local`)
- 使用 Python multiprocessing 在本地并行执行
- 适合小规模任务或测试
2. **SLURM 直接提交模式** (`slurm`)
- 直接在 Python 中提交 SLURM 作业
- 无需生成脚本文件
- 实时监控作业状态和进度
- 适合大规模高通量计算
### 进度显示
```
分析CIF文件: |████████████████████░░░░░░░░░░░░░░░░░░░░| 500/1000 (50.0%) [0:05:23<0:05:20, 1.6it/s] ✓480 ✗20
```
### 输出格式
处理后的文件支持两种输出格式:
1. **原始格式**(平铺)
```
workspace/processed/
├── 1514027.cif
├── 1514072.cif
└── ...
```
2. **分析格式**(按阴离子分类)
```
workspace/data/
├── O/
│ ├── 1514027/
│ │ └── 1514027.cif
│ └── 1514072/
│ └── 1514072.cif
├── S/
│ └── ...
└── Cl+O/
└── ...
```
## 处理流程详解
### Stage 1: 预处理与基础筛选 (Step 1)
@@ -51,9 +117,12 @@
---
## 扩胞逻辑 (Step 5 - 待后续执行)
## 扩胞逻辑 (Step 5)
目前扩胞逻辑维持原状,基于筛选后的结构进行处理。
扩胞处理已集成到新版主程序中,支持:
- 自动计算扩胞因子
- 可选保存数量当为1时不加后缀
- 自动添加化合价信息
### 算法分解
1. **读取结构**: 解析 CIF 文件。
@@ -72,3 +141,275 @@
### 假设条件
* 只考虑两个原子在同一位置上的共占位情况。
* 不考虑 Li 原子的共占位情况,对 Li 原子不做处理。
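The expansion factor used in this step can be derived the same way as in `src/analysis/structure_inspector.py`: approximate each partial occupancy by a fraction and take the least common multiple of the denominators. The sketch below is illustrative only; the helper name `expansion_factor`, the placeholder CIF path, and the choice to expand along the a axis only are assumptions, not necessarily what `main.py` does.

```python
from fractions import Fraction
from functools import reduce
from math import gcd
from pymatgen.core import Structure

def expansion_factor(structure: Structure, max_denominator: int = 10) -> int:
    """LCM of the denominators of all fractional occupancies ('low' precision = denominator <= 10)."""
    denoms = []
    for site in structure.sites:
        for _, occu in site.species.items():
            if 0.0 < occu < 1.0:
                denoms.append(Fraction(occu).limit_denominator(max_denominator).denominator)
    return reduce(lambda a, b: a * b // gcd(a, b), denoms, 1)

struct = Structure.from_file("example.cif")   # placeholder path
factor = expansion_factor(struct)
if factor > 1:
    struct.make_supercell([factor, 1, 1])     # assumption: expand along a only
```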
## 项目结构
```
screen/
├── main.py # 主入口(新版)
├── main.sh # 传统脚本入口
├── readme.md # 本文档
├── config/ # 配置文件
│ ├── settings.yaml
│ └── valence_states.yaml
├── src/ # 源代码
│ ├── analysis/ # 分析模块
│ │ ├── database_analyzer.py
│ │ ├── report_generator.py
│ │ ├── structure_inspector.py
│ │ └── worker.py
│ ├── core/ # 核心模块
│ │ ├── executor.py # 任务执行器(新)
│ │ ├── scheduler.py # 调度器
│ │ └── progress.py # 进度管理
│ ├── preprocessing/ # 预处理模块
│ │ ├── processor.py # 结构处理器
│ │ └── ...
│ └── utils/ # 工具函数
├── py/ # 传统脚本
├── tool/ # 工具和配置
│ ├── analyze_voronoi_nodes.py
│ └── Li/ # 化合价配置
└── workspace/ # 工作区
├── data/ # 分析格式输出
└── processed/ # 原始格式输出
```
## API 使用示例
### 使用执行器
```python
from src.core.executor import create_executor, TaskExecutor
# 创建执行器
executor = create_executor(
mode="slurm", # 或 "local"
max_workers=32,
conda_env="/cluster/home/koko125/anaconda3/envs/screen"
)
# 定义任务
tasks = [
(file_path, "Li", {"O", "S"})
for file_path in cif_files
]
# 执行
from src.analysis.worker import analyze_single_file
results = executor.run(tasks, analyze_single_file, desc="分析CIF文件")
```
### 使用数据库分析器
```python
from src.analysis.database_analyzer import DatabaseAnalyzer
analyzer = DatabaseAnalyzer(
database_path="/path/to/cif/files",
target_cation="Li",
target_anions={"O", "S", "Cl", "Br"},
anion_mode="all"
)
report = analyzer.analyze(show_progress=True)
report.save("analysis_report.json")
```
### 使用 Zeo++ 执行器
```python
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
# 设置工作区
ws_manager = WorkspaceManager(
workspace_path="workspace",
tool_dir="tool",
target_cation="Li"
)
# 创建软链接
workspace_info = ws_manager.setup_workspace()
# 获取计算任务
tasks = ws_manager.get_computation_tasks(workspace_info)
# 配置 Zeo++ 执行器
config = ZeoConfig(
conda_env="/cluster/home/koko125/anaconda3/envs/zeo",
partition="cpu",
max_concurrent=50,
time_limit="2:00:00"
)
# 执行计算
executor = ZeoExecutor(config)
results = executor.run_batch(tasks, output_dir="slurm_logs")
executor.print_results_summary(results)
```
## 新版功能特性 (v2.2)
### Zeo++ Voronoi 分析
新增 SLURM 作业数组支持,可高效调度大量 Zeo++ 计算任务:
1. **自动工作区管理**
- 检测现有工作区数据
- 自动创建配置文件软链接
- 按阴离子类型组织目录结构
2. **SLURM 作业数组**
- 使用 `--array` 参数批量提交任务
   - 支持最大并发数限制(如 `%50`)
- 自动分批处理超大任务集
3. **实时进度监控**
- 通过状态文件跟踪任务完成情况
- 支持 Ctrl+C 中断监控(作业继续运行)
- 自动收集输出文件
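A stripped-down sketch of what the array submission amounts to: write a job script whose `--array` range caps concurrency with `%50`, submit it with `sbatch`, and parse the job id from its stdout. The helper name `submit_array` and the `cpu` partition default are illustrative, not part of the project's API.

```python
import subprocess

def submit_array(script_body: str, script_path: str, n_tasks: int, max_concurrent: int = 50) -> str:
    """Write a SLURM job-array script and return the job id printed by sbatch."""
    header = (
        "#!/bin/bash\n"
        "#SBATCH --job-name=zeo_array\n"
        "#SBATCH --partition=cpu\n"
        f"#SBATCH --array=0-{n_tasks - 1}%{max_concurrent}\n"
    )
    with open(script_path, "w") as f:
        f.write(header + script_body)
    out = subprocess.run(["sbatch", script_path], capture_output=True, text=True, check=True)
    return out.stdout.strip().split()[-1]   # "Submitted batch job 12345" -> "12345"
```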
### 工作流程
```
Step 1: 数据库分析
Step 1.5: 扩胞处理 + 化合价添加
Step 2-4: Zeo++ Voronoi 分析
├── 创建软链接 (yaml 配置 + 计算脚本)
├── 提交 SLURM 作业数组
└── 监控进度并收集结果
Step 5: 结果处理与筛选
├── 从 log.txt 提取关键参数
├── 汇总到 CSV 文件
├── 应用筛选条件
└── 复制通过筛选的结构到 passed/ 目录
```
### 筛选条件
Step 5 支持以下筛选条件:
- **最小渗透直径** (Percolation Diameter): 默认 1.0 Å
- **最小 d 值** (Minimum of d): 默认 2.0
- **最大节点长度** (Maximum Node Length): 默认不限制
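These thresholds map directly onto `FilterCriteria` in `src/computation/result_processor.py`; a minimal sketch of running the Step 5 filter from Python:

```python
from src.computation.result_processor import ResultProcessor, FilterCriteria

processor = ResultProcessor(workspace_path="workspace")
criteria = FilterCriteria(
    min_percolation_diameter=1.0,   # Å
    min_d_value=2.0,
    max_node_length=float("inf"),   # no upper limit by default
)
results, stats = processor.process_and_filter(criteria, save_csv=True, copy_passed=True)
processor.print_summary(results, stats)
```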
### 日志输出
Zeo++ 计算的输出会重定向到每个结构目录下的 `log.txt` 文件:
```bash
python analyze_voronoi_nodes.py *.cif -i O.yaml > log.txt 2>&1
```
日志中包含的关键信息:
- `Percolation diameter (A): X.XX` - 渗透直径
- `the minium of d\nX.XX` - 最小 d 值
- `Maximum node length detected: X.XX A` - 最大节点长度
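These three values are extracted from `log.txt` with the same regular expressions used by `ResultProcessor.extract_from_log` (note that `minium` reproduces the spelling in the Topological_Analysis output and must not be "corrected"):

```python
import re

def parse_log(text: str):
    """Return (percolation_diameter, min_d, max_node_length); None for anything not found."""
    perc = re.search(r"Percolation diameter \(A\):\s*([\d\.]+)", text)
    min_d = re.search(r"the minium of d\s*\n\s*([\d\.]+)", text)
    node = re.search(r"Maximum node length detected:\s*([\d\.]+)\s*A", text)
    to_f = lambda m: float(m.group(1)) if m else None
    return to_f(perc), to_f(min_d), to_f(node)
```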
### 目录结构
```
workspace/
├── data/ # 分析格式数据
│ ├── O/ # 氧化物
│ │ ├── O.yaml -> tool/Li/O.yaml
│ │ ├── analyze_voronoi_nodes.py -> tool/analyze_voronoi_nodes.py
│ │ ├── 1514027/
│ │ │ ├── 1514027.cif
│ │ │ ├── 1514027_all_accessed_node.cif # Zeo++ 输出
│ │ │ ├── 1514027_bond_valence_filtered.cif
│ │ │ └── 1514027_bv_info.csv
│ │ └── ...
│ ├── S/ # 硫化物
│ └── Cl+O/ # 复合阴离子
├── processed/ # 原始格式数据
├── slurm_logs/ # SLURM 日志
│ ├── tasks.json
│ ├── submit_array.sh
│ ├── task_0.out
│ ├── task_0.err
│ ├── status_0.txt
│ └── ...
└── zeo_results.json # 计算结果汇总
```
### 配置文件说明
`tool/Li/O.yaml` 示例:
```yaml
SPECIE: Li+
ANION: O
PERCO_R: 0.5
NEIGHBOR: 1.8
LONG: 2.2
```
参数说明:
- `SPECIE`: 目标扩散离子(带电荷)
- `ANION`: 阴离子类型
- `PERCO_R`: 渗透半径阈值
- `NEIGHBOR`: 邻近距离阈值
- `LONG`: 长节点判定阈值
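A small sketch of reading such a config from Python with PyYAML (availability of `pyyaml` in the `screen` environment is an assumption here; the analysis script itself consumes the file through its `-i` option):

```python
import yaml

with open("tool/Li/O.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["SPECIE"], cfg["ANION"], cfg["PERCO_R"])   # e.g. Li+ O 0.5
```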
## 新版功能特性 (v2.3)
### 断点续做功能
v2.3 新增智能断点续做功能,支持从任意步骤继续执行:
1. **自动状态检测**
- 启动时自动检测工作流程状态
- 检测 `workspace/data/` 是否存在且有结构 → 判断 Step 1 是否完成
- 检测结构目录下是否有 `log.txt` → 判断 Zeo++ 计算是否完成
- 如果 50% 以上结构有 log.txt认为 Zeo++ 计算已完成
2. **智能流程跳转**
- 如果已完成 Zeo++ 计算 → 可直接进行筛选
- 如果已完成扩胞处理 → 可直接进行 Zeo++ 计算
- 从后往前检测,自动跳过已完成的步骤
3. **分步执行与中断**
- 每个大步骤完成后询问是否继续
- 支持中途退出,下次运行时自动检测进度
- 三大步骤:扩胞与加化合价 → Zeo++ 计算 → 筛选
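A simplified sketch of the detection rule described above; the function name `detect_state` is illustrative, and the real logic in `main.py` / the core modules may differ in detail:

```python
import os

def detect_state(data_dir: str = "workspace/data") -> str:
    """Return 'fresh', 'expanded', or 'zeo_done' based on what already exists on disk."""
    if not os.path.isdir(data_dir):
        return "fresh"                      # Step 1 not done yet
    struct_dirs = [
        os.path.join(data_dir, anion, s)
        for anion in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, anion))
        for s in os.listdir(os.path.join(data_dir, anion))
        if os.path.isdir(os.path.join(data_dir, anion, s))
    ]
    if not struct_dirs:
        return "fresh"
    with_log = sum(os.path.isfile(os.path.join(d, "log.txt")) for d in struct_dirs)
    return "zeo_done" if with_log / len(struct_dirs) >= 0.5 else "expanded"
```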
### 工作流程状态示例
```
工作流程状态检测
检测到现有工作区数据:
- 结构总数: 1234
- 已完成 Zeo++ 计算: 1200 (97.2%)
- 未完成 Zeo++ 计算: 34
当前状态: Zeo++ 计算已完成
可选操作:
[1] 直接进行结果筛选 (Step 5)
[2] 重新运行 Zeo++ 计算 (Step 2-4)
[3] 从头开始 (Step 1)
[0] 退出
请选择 [1]:
```
### 使用场景
1. **首次运行**
- 从 Step 1 开始完整执行
- 每步完成后可选择继续或退出
2. **中断后继续**
- 自动检测已完成的步骤
- 提供从当前进度继续的选项
3. **重新筛选**
- 修改筛选条件后
- 可直接运行 Step 5 而无需重新计算
4. **部分重算**
- 如需重新计算部分结构
- 可选择重新运行 Zeo++ 计算

src/__init__.py Normal file
@@ -0,0 +1,12 @@
"""
高通量筛选与扩胞项目 - 源代码包
"""
from . import analysis
from . import core
from . import preprocessing
from . import computation
from . import utils
__version__ = "2.2.0"
__all__ = ['analysis', 'core', 'preprocessing', 'computation', 'utils']

src/analysis/__init__.py Normal file
src/analysis/database_analyzer.py Normal file
@@ -0,0 +1,468 @@
"""
数据库分析器:支持高性能并行分析
"""
import os
import pickle
import json
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Set, Optional
from pathlib import Path
from .structure_inspector import StructureInspector, StructureInfo
from .worker import analyze_single_file
from ..core.scheduler import ParallelScheduler, ResourceConfig
# 在 DatabaseReport 类中添加缺失的字段
@dataclass
class DatabaseReport:
"""数据库分析报告"""
# 基础统计
database_path: str = ""
total_files: int = 0
valid_files: int = 0
invalid_files: int = 0
# 目标元素统计
target_cation: str = ""
target_anions: Set[str] = field(default_factory=set)
anion_mode: str = ""
# 含目标阳离子的统计
cation_containing_count: int = 0
cation_containing_ratio: float = 0.0
# 阴离子分布
anion_distribution: Dict[str, int] = field(default_factory=dict)
anion_ratios: Dict[str, float] = field(default_factory=dict)
single_anion_count: int = 0
mixed_anion_count: int = 0
# 数据质量统计
with_oxidation_states: int = 0
without_oxidation_states: int = 0
needs_expansion_count: int = 0
cation_with_vacancy_count: int = 0 # Li与空位共占位新增
cation_with_other_cation_count: int = 0 # Li与其他阳离子共占位新增
anion_partial_occupancy_count: int = 0
binary_compound_count: int = 0
has_water_count: int = 0
has_radioactive_count: int = 0
# 可处理性统计
directly_processable: int = 0
needs_preprocessing: int = 0
cannot_process: int = 0
# 详细信息
all_structures: List[StructureInfo] = field(default_factory=list)
skip_reasons_summary: Dict[str, int] = field(default_factory=dict)
# 扩胞相关统计(新增)
expansion_stats: Dict[str, int] = field(default_factory=lambda: {
'no_expansion_needed': 0,
'expansion_factor_2': 0,
'expansion_factor_3': 0,
'expansion_factor_4_8': 0,
'expansion_factor_large': 0,
'cannot_expand': 0,
})
expansion_factor_distribution: Dict[int, int] = field(default_factory=dict)
def to_dict(self) -> dict:
"""转换为可序列化的字典"""
from dataclasses import fields as dataclass_fields
def convert_value(val):
"""递归转换值为可序列化类型"""
if isinstance(val, set):
return list(val)
elif isinstance(val, dict):
return {k: convert_value(v) for k, v in val.items()}
elif isinstance(val, list):
return [convert_value(item) for item in val]
elif hasattr(val, '__dataclass_fields__'):
# 处理 dataclass 对象
return {k: convert_value(v) for k, v in asdict(val).items()}
else:
return val
result = {}
for f in dataclass_fields(self):
value = getattr(self, f.name)
result[f.name] = convert_value(value)
return result
def save(self, path: str):
"""保存报告到JSON文件"""
with open(path, 'w', encoding='utf-8') as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
print(f"✅ 报告已保存到: {path}")
@classmethod
def load(cls, path: str) -> 'DatabaseReport':
"""从JSON文件加载报告"""
with open(path, 'r', encoding='utf-8') as f:
d = json.load(f)
# 处理 set 类型
if 'target_anions' in d:
d['target_anions'] = set(d['target_anions'])
# 处理 StructureInfo 列表(简化处理,不恢复完整对象)
if 'all_structures' in d:
d['all_structures'] = []
return cls(**d)
def get_processable_files(self, include_needs_expansion: bool = True) -> List[StructureInfo]:
"""
获取可处理的文件列表
Args:
include_needs_expansion: 是否包含需要扩胞的文件
Returns:
可处理的 StructureInfo 列表
"""
result = []
for info in self.all_structures:
if info is None or not info.is_valid:
continue
if not info.contains_target_cation:
continue
if not info.can_process:
continue
if not include_needs_expansion and info.needs_expansion:
continue
result.append(info)
return result
def copy_processable_files(
self,
output_dir: str,
include_needs_expansion: bool = True,
organize_by_anion: bool = True
) -> Dict[str, int]:
"""
将可处理的CIF文件复制到工作区
Args:
output_dir: 输出目录(如 workspace/data
include_needs_expansion: 是否包含需要扩胞的文件
organize_by_anion: 是否按阴离子类型组织子目录
Returns:
复制统计信息 {类别: 数量}
"""
import shutil
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 获取可处理文件
processable = self.get_processable_files(include_needs_expansion)
stats = {
'direct': 0, # 可直接处理
'needs_expansion': 0, # 需要扩胞
'total': 0
}
# 按类型创建子目录
if organize_by_anion:
anion_dirs = {}
for info in processable:
# 确定目标目录
if organize_by_anion and info.anion_types:
# 使用主要阴离子作为目录名
anion_key = '+'.join(sorted(info.anion_types))
if anion_key not in anion_dirs:
anion_dir = os.path.join(output_dir, anion_key)
os.makedirs(anion_dir, exist_ok=True)
anion_dirs[anion_key] = anion_dir
target_dir = anion_dirs[anion_key]
else:
target_dir = output_dir
# 进一步按处理类型分类
if info.needs_expansion:
sub_dir = os.path.join(target_dir, 'needs_expansion')
stats['needs_expansion'] += 1
else:
sub_dir = os.path.join(target_dir, 'direct')
stats['direct'] += 1
os.makedirs(sub_dir, exist_ok=True)
# 复制文件
src_path = info.file_path
dst_path = os.path.join(sub_dir, info.file_name)
try:
shutil.copy2(src_path, dst_path)
stats['total'] += 1
except Exception as e:
print(f"⚠️ 复制失败 {info.file_name}: {e}")
# 打印统计
print(f"\n📁 文件已复制到: {output_dir}")
print(f" 可直接处理: {stats['direct']}")
print(f" 需要扩胞: {stats['needs_expansion']}")
print(f" 总计: {stats['total']}")
return stats
class DatabaseAnalyzer:
"""数据库分析器 - 支持高性能并行"""
def __init__(
self,
database_path: str,
target_cation: str = "Li",
target_anions: Set[str] = None,
anion_mode: str = "all",
max_cores: int = 4,
task_complexity: str = "medium"
):
"""
初始化分析器
Args:
database_path: 数据库路径
target_cation: 目标阳离子
target_anions: 目标阴离子集合
anion_mode: 阴离子模式
max_cores: 最大可用核数
task_complexity: 任务复杂度 ('low', 'medium', 'high')
"""
self.database_path = database_path
self.target_cation = target_cation
self.target_anions = target_anions or {'O', 'S', 'Cl', 'Br'}
self.anion_mode = anion_mode
self.max_cores = max_cores
self.task_complexity = task_complexity
# 获取文件列表
self.cif_files = self._get_cif_files()
# 配置调度器
self.resource_config = ParallelScheduler.recommend_config(
num_tasks=len(self.cif_files),
task_complexity=task_complexity,
max_cores=max_cores
)
self.scheduler = ParallelScheduler(self.resource_config)
def _get_cif_files(self) -> List[str]:
"""获取所有CIF文件路径"""
cif_files = []
if os.path.isfile(self.database_path):
if self.database_path.endswith('.cif'):
cif_files.append(self.database_path)
else:
for root, dirs, files in os.walk(self.database_path):
for f in files:
if f.endswith('.cif'):
cif_files.append(os.path.join(root, f))
return sorted(cif_files)
def analyze(self, show_progress: bool = True) -> DatabaseReport:
"""
执行并行分析
Args:
show_progress: 是否显示进度
Returns:
DatabaseReport: 分析报告
"""
report = DatabaseReport(
database_path=self.database_path,
target_cation=self.target_cation,
target_anions=self.target_anions,
anion_mode=self.anion_mode,
total_files=len(self.cif_files)
)
if report.total_files == 0:
print(f"⚠️ 警告: 在 {self.database_path} 中未找到CIF文件")
return report
# 准备任务
tasks = [
(f, self.target_cation, self.target_anions)
for f in self.cif_files
]
# 执行并行分析
results = self.scheduler.run_local(
tasks=tasks,
worker_func=analyze_single_file,
desc="分析CIF文件"
)
# 过滤有效结果
report.all_structures = [r for r in results if r is not None]
# 统计结果
self._compute_statistics(report)
return report
def analyze_slurm(
self,
output_dir: str,
job_name: str = "cif_analysis"
) -> str:
"""
提交SLURM作业进行分析
Args:
output_dir: 输出目录
job_name: 作业名称
Returns:
作业ID
"""
os.makedirs(output_dir, exist_ok=True)
# 保存任务配置
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump({
'files': self.cif_files,
'target_cation': self.target_cation,
'target_anions': list(self.target_anions),
'anion_mode': self.anion_mode
}, f)
# 生成SLURM脚本
worker_script = os.path.join(
os.path.dirname(__file__), 'worker.py'
)
script = self.scheduler.generate_slurm_script(
tasks_file=tasks_file,
worker_script=worker_script,
output_dir=output_dir,
job_name=job_name
)
# 保存并提交
script_path = os.path.join(output_dir, "submit.sh")
return self.scheduler.submit_slurm_job(script, script_path)
# 更新 _compute_statistics 方法
def _compute_statistics(self, report: DatabaseReport):
"""计算统计数据"""
for info in report.all_structures:
# 确保 info 不是 None
if info is None:
report.invalid_files += 1
continue
# 检查有效性
if info.is_valid:
report.valid_files += 1
else:
report.invalid_files += 1
continue # 无效文件不继续统计
# 关键修复:只有当结构确实含有目标阳离子时才计入统计
if not info.contains_target_cation:
continue # 不含目标阳离子的文件不继续统计
report.cation_containing_count += 1
for anion in info.anion_types:
report.anion_distribution[anion] = \
report.anion_distribution.get(anion, 0) + 1
if info.anion_mode == "single":
report.single_anion_count += 1
elif info.anion_mode == "mixed":
report.mixed_anion_count += 1
# 根据阴离子模式过滤
if self.anion_mode == "single" and info.anion_mode != "single":
continue
if self.anion_mode == "mixed" and info.anion_mode != "mixed":
continue
if info.anion_mode == "none":
continue
# 各项统计
if info.has_oxidation_states:
report.with_oxidation_states += 1
else:
report.without_oxidation_states += 1
# Li共占位统计修改
if info.cation_with_vacancy:
report.cation_with_vacancy_count += 1
if info.cation_with_other_cation:
report.cation_with_other_cation_count += 1
if info.anion_has_partial_occupancy:
report.anion_partial_occupancy_count += 1
if info.is_binary_compound:
report.binary_compound_count += 1
if info.has_water_molecule:
report.has_water_count += 1
if info.has_radioactive_elements:
report.has_radioactive_count += 1
# 可处理性
if info.can_process:
if info.needs_expansion:
report.needs_preprocessing += 1
else:
report.directly_processable += 1
else:
report.cannot_process += 1
if info.skip_reason:
for reason in info.skip_reason.split("; "):
report.skip_reasons_summary[reason] = \
report.skip_reasons_summary.get(reason, 0) + 1
# 扩胞统计(新增)
exp_info = info.expansion_info
factor = exp_info.expansion_factor
if not exp_info.needs_expansion:
report.expansion_stats['no_expansion_needed'] += 1
elif not exp_info.can_expand:
report.expansion_stats['cannot_expand'] += 1
elif factor == 2:
report.expansion_stats['expansion_factor_2'] += 1
elif factor == 3:
report.expansion_stats['expansion_factor_3'] += 1
elif 4 <= factor <= 8:
report.expansion_stats['expansion_factor_4_8'] += 1
else:
report.expansion_stats['expansion_factor_large'] += 1
# 详细分布
if exp_info.needs_expansion and exp_info.can_expand:
report.expansion_factor_distribution[factor] = \
report.expansion_factor_distribution.get(factor, 0) + 1
report.needs_expansion_count += 1
# 计算比例
if report.valid_files > 0:
report.cation_containing_ratio = \
report.cation_containing_count / report.valid_files
if report.cation_containing_count > 0:
for anion, count in report.anion_distribution.items():
report.anion_ratios[anion] = \
count / report.cation_containing_count

src/analysis/report_generator.py Normal file
@@ -0,0 +1,159 @@
"""
报告生成器:生成格式化的分析报告
"""
from typing import Optional
from .database_analyzer import DatabaseReport
class ReportGenerator:
"""报告生成器"""
@staticmethod
def print_report(report: DatabaseReport, detailed: bool = False):
"""打印分析报告"""
print("\n" + "=" * 70)
print(" 数据库分析报告")
print("=" * 70)
# 基础信息
print(f"\n📁 数据库路径: {report.database_path}")
print(f"🎯 目标阳离子: {report.target_cation}")
print(f"🎯 目标阴离子: {', '.join(sorted(report.target_anions))}")
print(f"🎯 阴离子模式: {report.anion_mode}")
# 基础统计
print("\n" + "-" * 70)
print("【1. 基础统计】")
print("-" * 70)
print(f" 总 CIF 文件数: {report.total_files}")
print(f" 有效文件数: {report.valid_files}")
print(f" 无效文件数: {report.invalid_files}")
print(f"{report.target_cation} 化合物数: {report.cation_containing_count}")
print(f"{report.target_cation} 化合物占比: {report.cation_containing_ratio:.1%}")
# 阴离子分布
print("\n" + "-" * 70)
print(f"【2. 阴离子分布】(在含 {report.target_cation} 的化合物中)")
print("-" * 70)
if report.anion_distribution:
for anion in sorted(report.anion_distribution.keys()):
count = report.anion_distribution[anion]
ratio = report.anion_ratios.get(anion, 0)
bar = "" * int(ratio * 30)
print(f" {anion:5s}: {count:6d} ({ratio:6.1%}) {bar}")
print(f"\n 单一阴离子化合物: {report.single_anion_count}")
print(f" 复合阴离子化合物: {report.mixed_anion_count}")
# 数据质量
print("\n" + "-" * 70)
print("【3. 数据质量检查】")
print("-" * 70)
total_target = report.cation_containing_count
if total_target > 0:
print(f" 含化合价信息: {report.with_oxidation_states:6d} "
f"({report.with_oxidation_states / total_target:.1%})")
print(f" 缺化合价信息: {report.without_oxidation_states:6d} "
f"({report.without_oxidation_states / total_target:.1%})")
print()
print(f" {report.target_cation}与空位共占位(无需处理): {report.cation_with_vacancy_count:6d}")
print(f" {report.target_cation}与阳离子共占位(需扩胞): {report.cation_with_other_cation_count:6d}")
print(f" 阴离子共占位: {report.anion_partial_occupancy_count:6d}")
print(f" 需扩胞处理(总计): {report.needs_expansion_count:6d}")
print()
print(f" 二元化合物: {report.binary_compound_count:6d}")
print(f" 含水分子: {report.has_water_count:6d}")
print(f" 含放射性元素: {report.has_radioactive_count:6d}")
# 可处理性评估
print("\n" + "-" * 70)
print("【4. 可处理性评估】")
print("-" * 70)
total_processable = report.directly_processable + report.needs_preprocessing
print(f" ✅ 可直接处理: {report.directly_processable:6d}")
print(f" ⚠️ 需预处理(扩胞): {report.needs_preprocessing:6d}")
print(f" ❌ 无法处理: {report.cannot_process:6d}")
print(f" ─────────────────────────────")
print(f" 📊 可处理总数: {total_processable:6d}")
# 跳过原因汇总
if report.skip_reasons_summary:
print("\n" + "-" * 70)
print("【5. 无法处理的原因统计】")
print("-" * 70)
sorted_reasons = sorted(
report.skip_reasons_summary.items(),
key=lambda x: x[1],
reverse=True
)
for reason, count in sorted_reasons:
print(f" {reason:35s}: {count:6d}")
# 扩胞分析
print("\n" + "-" * 70)
print("【6. 扩胞需求分析】")
print("-" * 70)
exp = report.expansion_stats
if total_processable > 0:
print(f" 无需扩胞: {exp['no_expansion_needed']:6d}")
print(f" 扩胞因子=2: {exp['expansion_factor_2']:6d}")
print(f" 扩胞因子=3: {exp['expansion_factor_3']:6d}")
print(f" 扩胞因子=4~8: {exp['expansion_factor_4_8']:6d}")
print(f" 扩胞因子>8: {exp['expansion_factor_large']:6d}")
print(f" 无法扩胞(因子过大): {exp['cannot_expand']:6d}")
# 详细分布
if detailed and report.expansion_factor_distribution:
print("\n 扩胞因子分布:")
for factor in sorted(report.expansion_factor_distribution.keys()):
count = report.expansion_factor_distribution[factor]
bar = "" * min(count, 30)
print(f" {factor:3d}x: {count:5d} {bar}")
print("\n" + "=" * 70)
@staticmethod
def export_to_csv(report: DatabaseReport, output_path: str):
"""导出详细结果到CSV"""
import csv
with open(output_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
# 写入表头
headers = [
'file_name', 'is_valid', 'contains_target_cation',
'anion_types', 'anion_mode', 'has_oxidation_states',
'has_partial_occupancy', 'cation_partial_occupancy',
'anion_partial_occupancy', 'needs_expansion',
'is_binary', 'has_water', 'has_radioactive',
'can_process', 'skip_reason'
]
writer.writerow(headers)
# 写入数据
for info in report.all_structures:
row = [
info.file_name,
info.is_valid,
info.contains_target_cation,
'+'.join(sorted(info.anion_types)) if info.anion_types else '',
info.anion_mode,
info.has_oxidation_states,
info.has_partial_occupancy,
info.cation_with_other_cation, # 修复:使用正确的属性名
info.anion_has_partial_occupancy,
info.needs_expansion,
info.is_binary_compound,
info.has_water_molecule,
info.has_radioactive_elements,
info.can_process,
info.skip_reason
]
writer.writerow(row)
print(f"详细结果已导出到: {output_path}")

src/analysis/structure_inspector.py Normal file
@@ -0,0 +1,443 @@
"""
结构检查器对单个CIF文件进行深度分析含扩胞需求判断
"""
from dataclasses import dataclass, field
from typing import Set, Dict, List, Optional, Tuple
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element, Specie
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import re
import os
@dataclass
class OccupancyInfo:
"""共占位信息"""
occupation: float # 占据率
atom_serials: List[int] = field(default_factory=list) # 原子序号
elements: List[str] = field(default_factory=list) # 涉及的元素
numerator: int = 0 # 分子
denominator: int = 1 # 分母
involves_target_cation: bool = False # 是否涉及目标阳离子
involves_anion: bool = False # 是否涉及阴离子
@dataclass
class ExpansionInfo:
"""扩胞信息"""
needs_expansion: bool = False # 是否需要扩胞
expansion_factor: int = 1 # 扩胞因子(最小公倍数)
occupancy_details: List[OccupancyInfo] = field(default_factory=list) # 共占位详情
problematic_sites: int = 0 # 问题位点数
can_expand: bool = True # 是否可以扩胞处理
skip_reason: str = "" # 无法扩胞的原因
@dataclass
class StructureInfo:
"""单个结构的分析结果"""
file_path: str
file_name: str
# 基础信息
is_valid: bool = False
error_message: str = ""
# 元素组成
elements: Set[str] = field(default_factory=set)
num_sites: int = 0
formula: str = ""
# 阳离子/阴离子信息
contains_target_cation: bool = False
anion_types: Set[str] = field(default_factory=set)
anion_mode: str = "" # "single", "mixed", "none"
# 数据质量标记
has_oxidation_states: bool = False
has_partial_occupancy: bool = False # 是否有共占位
has_water_molecule: bool = False
has_radioactive_elements: bool = False
is_binary_compound: bool = False
# 共占位详细分析(新增)
cation_with_vacancy: bool = False # Li与空位共占位不需处理
cation_with_other_cation: bool = False # Li与其他阳离子共占位需扩胞
anion_has_partial_occupancy: bool = False # 阴离子共占位
other_has_partial_occupancy: bool = False # 其他元素共占位(需扩胞)
expansion_info: ExpansionInfo = field(default_factory=ExpansionInfo)
# 可处理性
needs_expansion: bool = False
can_process: bool = False
skip_reason: str = ""
class StructureInspector:
"""结构检查器(含扩胞分析)"""
# 预定义的阴离子集合
VALID_ANIONS = {'O', 'S', 'Cl', 'Br'}
# 放射性元素
RADIOACTIVE_ELEMENTS = {
'U', 'Th', 'Pu', 'Ra', 'Rn', 'Po', 'Np', 'Am',
'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'
}
# 扩胞精度模式
PRECISION_LIMITS = {
'high': None, # 精确分数
'normal': 100, # 分母≤100
'low': 10, # 分母≤10
'very_low': 5 # 分母≤5
}
def __init__(
self,
target_cation: str = "Li",
target_anions: Set[str] = None,
expansion_precision: str = "low"
):
"""
初始化检查器
Args:
target_cation: 目标阳离子 (如 "Li", "Na")
target_anions: 目标阴离子集合 (如 {"O", "S"})
expansion_precision: 扩胞计算精度 ('high', 'normal', 'low', 'very_low')
"""
self.target_cation = target_cation
self.target_anions = target_anions or self.VALID_ANIONS
self.expansion_precision = expansion_precision
# 目标阳离子的各种可能表示形式
self.target_cation_variants = {
target_cation,
f"{target_cation}+",
f"{target_cation}1+",
}
def inspect(self, file_path: str) -> StructureInfo:
"""分析单个CIF文件"""
import os
info = StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path)
)
# 尝试读取结构
try:
structure = Structure.from_file(file_path)
except Exception as e:
info.is_valid = False
info.error_message = f"读取CIF失败: {str(e)}"
return info
info.is_valid = True
try:
# ===== 关键修复:正确提取元素符号 =====
# structure.composition.elements 返回 Element 对象列表
# 需要用 .symbol 属性获取字符串
element_symbols = set()
for el in structure.composition.elements:
# el 是 Element 对象el.symbol 是字符串如 "Li"
element_symbols.add(el.symbol)
info.elements = element_symbols
info.num_sites = structure.num_sites
info.formula = structure.composition.reduced_formula
# 检查是否为二元化合物
info.is_binary_compound = len(element_symbols) == 2
# ===== 关键修复:直接比较字符串 =====
info.contains_target_cation = self.target_cation in element_symbols
# 检查阴离子类型
info.anion_types = element_symbols.intersection(self.target_anions)
if len(info.anion_types) == 0:
info.anion_mode = "none"
elif len(info.anion_types) == 1:
info.anion_mode = "single"
else:
info.anion_mode = "mixed"
# 检查氧化态
info.has_oxidation_states = self._check_oxidation_states(structure)
# 检查共占位(核心分析)
try:
self._analyze_partial_occupancy(structure, info)
except Exception as e:
# 共占位分析失败,记录但继续
pass
# 检查水分子
try:
info.has_water_molecule = self._check_water_molecule(structure)
except:
info.has_water_molecule = False
# 检查放射性元素
info.has_radioactive_elements = bool(
info.elements.intersection(self.RADIOACTIVE_ELEMENTS)
)
# 判断可处理性
self._evaluate_processability(info)
except Exception as e:
# 分析过程出错,但文件本身是有效的
# 保留 is_valid = True但记录错误
info.error_message = f"分析过程出错: {str(e)}"
return info
def _check_oxidation_states(self, structure: Structure) -> bool:
"""检查结构是否包含氧化态信息"""
try:
for site in structure.sites:
for specie in site.species.keys():
if isinstance(specie, Specie):
return True
return False
except:
return False
def _get_element_from_species_string(self, species_str: str) -> str:
"""从物种字符串提取纯元素符号"""
match = re.match(r'([A-Z][a-z]?)', species_str)
return match.group(1) if match else ""
def _get_occupancy_from_species_string(self, species_str: str, exclude_elements: Set[str]) -> Optional[float]:
"""
从物种字符串获取非目标元素的占据率
格式如: "Li+:0.689, Sc3+:0.311"
"""
if ':' not in species_str:
return None
parts = [p.strip() for p in species_str.split(',')]
for part in parts:
if ':' in part:
element_part, occu_part = part.split(':')
element = self._get_element_from_species_string(element_part.strip())
if element and element not in exclude_elements:
try:
return float(occu_part.strip())
except ValueError:
continue
return None
# 在 StructureInspector 类中,替换 _analyze_partial_occupancy 方法
def _analyze_partial_occupancy(self, structure: Structure, info: StructureInfo):
"""
分析共占位情况(修正版)
关键规则:
- Li与空位共占位 → 不需要处理cation_with_vacancy
- Li与其他阳离子共占位 → 需要扩胞cation_with_other_cation
- 阴离子共占位 → 通常不处理
- 其他阳离子共占位 → 需要扩胞
"""
occupancy_dict = defaultdict(list) # {occupation: [site_indices]}
occupancy_elements = {} # {occupation: [elements]}
for i, site in enumerate(structure.sites):
site_species = site.species
species_string = str(site.species)
# 提取各元素及其占据率
species_occu = {} # {element: occupancy}
for sp, occu in site_species.items():
elem = sp.symbol if hasattr(sp, 'symbol') else str(sp)
elem = self._get_element_from_species_string(elem)
if elem:
species_occu[elem] = occu
total_occupancy = sum(species_occu.values())
elements_at_site = list(species_occu.keys())
# 检查是否有部分占据
has_partial = any(occu < 1.0 for occu in species_occu.values()) or len(species_occu) > 1
if not has_partial:
continue
info.has_partial_occupancy = True
# 判断Li的共占位情况
if self.target_cation in elements_at_site:
li_occu = species_occu.get(self.target_cation, 0)
other_elements = [e for e in elements_at_site if e != self.target_cation]
if not other_elements and li_occu < 1.0:
# Li与空位共占位Li占据率<1但没有其他元素
info.cation_with_vacancy = True
elif other_elements:
# Li与其他元素共占位
other_are_anions = all(e in self.target_anions for e in other_elements)
if other_are_anions:
# Li与阴离子共占位罕见标记为阴离子共占位
info.anion_has_partial_occupancy = True
else:
# Li与其他阳离子共占位 → 需要扩胞
info.cation_with_other_cation = True
# 记录需要扩胞的占据率取非Li元素的占据率
for elem in other_elements:
if elem not in self.target_anions:
occu = species_occu.get(elem, 0)
if occu > 0 and occu < 1.0:
occupancy_dict[occu].append(i)
occupancy_elements[occu] = elements_at_site
else:
# 不涉及Li的位点
# 判断是否涉及阴离子
if any(elem in self.target_anions for elem in elements_at_site):
info.anion_has_partial_occupancy = True
else:
# 其他阳离子的共占位 → 需要扩胞
info.other_has_partial_occupancy = True
# 获取占据率
for elem, occu in species_occu.items():
if occu > 0 and occu < 1.0:
occupancy_dict[occu].append(i)
occupancy_elements[occu] = elements_at_site
break # 只记录一次
# 计算扩胞信息
self._calculate_expansion_info(info, occupancy_dict, occupancy_elements)
def _evaluate_processability(self, info: StructureInfo):
"""评估可处理性(修正版)"""
skip_reasons = []
if not info.is_valid:
skip_reasons.append("无法解析CIF文件")
if not info.contains_target_cation:
skip_reasons.append(f"不含{self.target_cation}")
if info.anion_mode == "none":
skip_reasons.append("不含目标阴离子")
if info.is_binary_compound:
skip_reasons.append("二元化合物")
if info.has_radioactive_elements:
skip_reasons.append("含放射性元素")
# Li与空位共占位 → 不需要处理不加入skip_reasons
# info.cation_with_vacancy 不影响可处理性
# Li与其他阳离子共占位 → 需要扩胞(如果扩胞因子合理则可处理)
if info.cation_with_other_cation:
if info.expansion_info.can_expand:
info.needs_expansion = True
else:
skip_reasons.append(f"{self.target_cation}与其他阳离子共占位且{info.expansion_info.skip_reason}")
# 阴离子共占位 → 不处理
if info.anion_has_partial_occupancy:
skip_reasons.append("阴离子存在共占位")
if info.has_water_molecule:
skip_reasons.append("含水分子")
# 其他阳离子共占位不涉及Li→ 需要扩胞
if info.other_has_partial_occupancy:
if info.expansion_info.can_expand:
info.needs_expansion = True
else:
skip_reasons.append(info.expansion_info.skip_reason)
if skip_reasons:
info.can_process = False
info.skip_reason = "; ".join(skip_reasons)
else:
info.can_process = True
def _calculate_expansion_info(
self,
info: StructureInfo,
occupancy_dict: Dict[float, List[int]],
occupancy_elements: Dict[float, List[str]]
):
"""计算扩胞相关信息"""
expansion_info = ExpansionInfo()
if not occupancy_dict:
info.expansion_info = expansion_info
return
# 需要扩胞(有非目标阳离子的共占位)
expansion_info.needs_expansion = True
expansion_info.problematic_sites = sum(len(v) for v in occupancy_dict.values())
# 转换为OccupancyInfo列表
occupancy_list = []
for occu, serials in occupancy_dict.items():
elements = occupancy_elements.get(occu, [])
# 根据精度计算分数
limit = self.PRECISION_LIMITS.get(self.expansion_precision)
if limit:
fraction = Fraction(occu).limit_denominator(limit)
else:
fraction = Fraction(occu).limit_denominator()
occ_info = OccupancyInfo(
occupation=occu,
atom_serials=[s + 1 for s in serials], # 转为1-based
elements=elements,
numerator=fraction.numerator,
denominator=fraction.denominator,
involves_target_cation=self.target_cation in elements,
involves_anion=any(e in self.target_anions for e in elements)
)
occupancy_list.append(occ_info)
expansion_info.occupancy_details = occupancy_list
# 计算最小公倍数(扩胞因子)
denominators = [occ.denominator for occ in occupancy_list]
if denominators:
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
expansion_info.expansion_factor = lcm
# 判断是否可以扩胞(因子过大则不可处理)
if lcm > 64: # 扩胞超过64倍通常不可行
expansion_info.can_expand = False
expansion_info.skip_reason = f"扩胞因子过大({lcm})"
info.expansion_info = expansion_info
info.needs_expansion = expansion_info.needs_expansion and expansion_info.can_expand
def _check_water_molecule(self, structure: Structure) -> bool:
"""检查是否含有水分子"""
try:
oxygen_sites = []
hydrogen_sites = []
for site in structure.sites:
species_str = str(site.species)
if 'O' in species_str:
oxygen_sites.append(site)
if 'H' in species_str:
hydrogen_sites.append(site)
for o_site in oxygen_sites:
nearby_h = [h for h in hydrogen_sites if o_site.distance(h) < 1.2]
if len(nearby_h) >= 2:
return True
return False
except:
return False

src/analysis/worker.py Normal file
@@ -0,0 +1,199 @@
"""
工作进程:处理单个分析任务
设计为可以独立运行用于SLURM作业数组
"""
import os
import pickle
from typing import List, Tuple, Optional
from dataclasses import asdict, fields
from .structure_inspector import StructureInspector, StructureInfo
def analyze_single_file(args: Tuple[str, str, set]) -> Optional[StructureInfo]:
"""
分析单个CIF文件Worker函数
Args:
args: (file_path, target_cation, target_anions)
Returns:
StructureInfo 或 None如果分析失败
"""
file_path, target_cation, target_anions = args
try:
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
result = inspector.inspect(file_path)
return result
except Exception as e:
# 返回一个标记失败的结果(而不是 None
return StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=f"Worker异常: {str(e)}"
)
def structure_info_to_dict(info: StructureInfo) -> dict:
"""
将 StructureInfo 转换为可序列化的字典
处理 set、dataclass 等特殊类型
"""
result = {}
for field in fields(info):
value = getattr(info, field.name)
# 处理 set 类型
if isinstance(value, set):
result[field.name] = list(value)
# 处理嵌套的 dataclass (如 ExpansionInfo)
elif hasattr(value, '__dataclass_fields__'):
result[field.name] = asdict(value)
# 处理 list 中可能包含的 dataclass
elif isinstance(value, list):
result[field.name] = [
asdict(item) if hasattr(item, '__dataclass_fields__') else item
for item in value
]
else:
result[field.name] = value
return result
def dict_to_structure_info(d: dict) -> StructureInfo:
"""
从字典恢复 StructureInfo 对象
"""
from .structure_inspector import ExpansionInfo, OccupancyInfo
# 处理 set 类型字段
if 'elements' in d and isinstance(d['elements'], list):
d['elements'] = set(d['elements'])
if 'anion_types' in d and isinstance(d['anion_types'], list):
d['anion_types'] = set(d['anion_types'])
if 'target_anions' in d and isinstance(d['target_anions'], list):
d['target_anions'] = set(d['target_anions'])
# 处理 ExpansionInfo
if 'expansion_info' in d and isinstance(d['expansion_info'], dict):
exp_dict = d['expansion_info']
# 处理 OccupancyInfo 列表
if 'occupancy_details' in exp_dict:
exp_dict['occupancy_details'] = [
OccupancyInfo(**occ) if isinstance(occ, dict) else occ
for occ in exp_dict['occupancy_details']
]
d['expansion_info'] = ExpansionInfo(**exp_dict)
return StructureInfo(**d)
def batch_analyze(
file_paths: List[str],
target_cation: str,
target_anions: set,
output_file: str = None
) -> List[StructureInfo]:
"""
批量分析文件用于SLURM子任务
Args:
file_paths: CIF文件路径列表
target_cation: 目标阳离子
target_anions: 目标阴离子集合
output_file: 输出文件路径pickle格式
Returns:
StructureInfo列表
"""
results = []
inspector = StructureInspector(
target_cation=target_cation,
target_anions=target_anions
)
for file_path in file_paths:
try:
info = inspector.inspect(file_path)
results.append(info)
except Exception as e:
results.append(StructureInfo(
file_path=file_path,
file_name=os.path.basename(file_path),
is_valid=False,
error_message=str(e)
))
# 保存结果
if output_file:
serializable_results = [structure_info_to_dict(r) for r in results]
with open(output_file, 'wb') as f:
pickle.dump(serializable_results, f)
return results
def load_results(result_file: str) -> List[StructureInfo]:
"""
从pickle文件加载结果
"""
with open(result_file, 'rb') as f:
data = pickle.load(f)
return [dict_to_structure_info(d) for d in data]
def merge_results(result_files: List[str]) -> List[StructureInfo]:
"""
合并多个结果文件用于汇总SLURM作业数组的输出
"""
all_results = []
for f in result_files:
if os.path.exists(f):
all_results.extend(load_results(f))
return all_results
# 用于SLURM作业数组的命令行入口
if __name__ == "__main__":
import argparse
import json
parser = argparse.ArgumentParser(description="CIF Analysis Worker")
parser.add_argument("--tasks-file", required=True, help="任务文件路径(JSON)")
parser.add_argument("--output-dir", required=True, help="输出目录")
parser.add_argument("--task-id", type=int, default=0, help="任务ID(用于数组作业)")
parser.add_argument("--num-workers", type=int, default=1, help="并行worker数")
args = parser.parse_args()
# 加载任务
with open(args.tasks_file, 'r') as f:
task_config = json.load(f)
file_paths = task_config['files']
target_cation = task_config['target_cation']
target_anions = set(task_config['target_anions'])
# 如果是数组作业,只处理分配的部分
if args.task_id >= 0:
chunk_size = len(file_paths) // args.num_workers + 1
start_idx = args.task_id * chunk_size
end_idx = min(start_idx + chunk_size, len(file_paths))
file_paths = file_paths[start_idx:end_idx]
# 输出文件
output_file = os.path.join(args.output_dir, f"results_{args.task_id}.pkl")
# 执行分析
print(f"Worker {args.task_id}: 处理 {len(file_paths)} 个文件")
results = batch_analyze(file_paths, target_cation, target_anions, output_file)
print(f"Worker {args.task_id}: 完成,结果保存到 {output_file}")

src/computation/__init__.py Normal file
@@ -0,0 +1,15 @@
"""
计算模块Zeo++ Voronoi 分析
"""
from .workspace_manager import WorkspaceManager
from .zeo_executor import ZeoExecutor, ZeoConfig
from .result_processor import ResultProcessor, FilterCriteria, StructureResult
__all__ = [
'WorkspaceManager',
'ZeoExecutor',
'ZeoConfig',
'ResultProcessor',
'FilterCriteria',
'StructureResult'
]

src/computation/result_processor.py Normal file
@@ -0,0 +1,426 @@
"""
Zeo++ 计算结果处理器:提取数据、筛选结构
"""
import os
import re
import shutil
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
import pandas as pd
@dataclass
class FilterCriteria:
"""筛选条件"""
min_percolation_diameter: float = 1.0 # 最小渗透直径 (Å),默认 1.0
min_d_value: float = 2.0 # 最小 d 值,默认 2.0
max_node_length: float = float('inf') # 最大节点长度 (Å)
@dataclass
class StructureResult:
"""单个结构的计算结果"""
structure_name: str
anion_type: str
work_dir: str
# 提取的参数
percolation_diameter: Optional[float] = None
min_d: Optional[float] = None
max_node_length: Optional[float] = None
# 筛选结果
passed_filter: bool = False
filter_reason: str = ""
class ResultProcessor:
"""
Zeo++ 计算结果处理器
功能:
1. 从每个结构目录的 log.txt 提取关键参数
2. 汇总所有结果到 CSV 文件
3. 根据筛选条件筛选结构
4. 将通过筛选的结构复制到新文件夹
"""
def __init__(
self,
workspace_path: str = "workspace",
data_dir: str = None,
output_dir: str = None
):
"""
初始化结果处理器
Args:
workspace_path: 工作区根目录
data_dir: 数据目录(默认 workspace/data
output_dir: 输出目录(默认 workspace/results
"""
self.workspace_path = os.path.abspath(workspace_path)
self.data_dir = data_dir or os.path.join(self.workspace_path, "data")
self.output_dir = output_dir or os.path.join(self.workspace_path, "results")
def extract_from_log(self, log_path: str) -> Tuple[Optional[float], Optional[float], Optional[float]]:
"""
从 log.txt 中提取三个关键参数
Args:
log_path: log.txt 文件路径
Returns:
(percolation_diameter, min_d, max_node_length)
"""
if not os.path.exists(log_path):
return None, None, None
try:
with open(log_path, 'r', encoding='utf-8') as f:
content = f.read()
except Exception:
return None, None, None
# 正则表达式 - 与 py/extract_data.py 保持一致
# 1. Percolation diameter: "# Percolation diameter (A): 1.06"
re_percolation = r"Percolation diameter \(A\):\s*([\d\.]+)"
# 2. Minimum of d: "the minium of d\n3.862140561244235"
# 注意:这是 Topological_Analysis 库输出的格式
re_min_d = r"the minium of d\s*\n\s*([\d\.]+)"
# 3. Maximum node length: "# Maximum node length detected: 1.332 A"
re_max_node = r"Maximum node length detected:\s*([\d\.]+)\s*A"
# 提取数据
match_perc = re.search(re_percolation, content)
match_d = re.search(re_min_d, content)
match_node = re.search(re_max_node, content)
val_perc = float(match_perc.group(1)) if match_perc else None
val_d = float(match_d.group(1)) if match_d else None
val_node = float(match_node.group(1)) if match_node else None
return val_perc, val_d, val_node
def process_all_structures(self) -> List[StructureResult]:
"""
处理所有结构,提取计算结果
Returns:
StructureResult 列表
"""
results = []
if not os.path.exists(self.data_dir):
print(f"⚠️ 数据目录不存在: {self.data_dir}")
return results
print("\n正在提取计算结果...")
# 遍历阴离子目录
for anion_key in os.listdir(self.data_dir):
anion_dir = os.path.join(self.data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
# 遍历结构目录
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 查找 log.txt
log_path = os.path.join(struct_dir, "log.txt")
# 提取参数
perc, min_d, max_node = self.extract_from_log(log_path)
result = StructureResult(
structure_name=struct_name,
anion_type=anion_key,
work_dir=struct_dir,
percolation_diameter=perc,
min_d=min_d,
max_node_length=max_node
)
results.append(result)
print(f" 共处理 {len(results)} 个结构")
return results
def apply_filter(
self,
results: List[StructureResult],
criteria: FilterCriteria
) -> List[StructureResult]:
"""
应用筛选条件
Args:
results: 结构结果列表
criteria: 筛选条件
Returns:
更新后的结果列表(包含筛选状态)
"""
print("\n应用筛选条件...")
print(f" 最小渗透直径: {criteria.min_percolation_diameter} Å")
print(f" 最小 d 值: {criteria.min_d_value}")
print(f" 最大节点长度: {criteria.max_node_length} Å")
passed_count = 0
for result in results:
# 检查是否有有效数据
if result.percolation_diameter is None or result.min_d is None:
result.passed_filter = False
result.filter_reason = "数据缺失"
continue
# 检查渗透直径
if result.percolation_diameter < criteria.min_percolation_diameter:
result.passed_filter = False
result.filter_reason = f"渗透直径 {result.percolation_diameter:.3f} < {criteria.min_percolation_diameter}"
continue
# 检查 d 值
if result.min_d < criteria.min_d_value:
result.passed_filter = False
result.filter_reason = f"d 值 {result.min_d:.3f} < {criteria.min_d_value}"
continue
# 检查节点长度(如果有数据)
if result.max_node_length is not None:
if result.max_node_length > criteria.max_node_length:
result.passed_filter = False
result.filter_reason = f"节点长度 {result.max_node_length:.3f} > {criteria.max_node_length}"
continue
# 通过所有筛选
result.passed_filter = True
result.filter_reason = "通过"
passed_count += 1
print(f" 通过筛选: {passed_count}/{len(results)}")
return results
def save_summary_csv(
self,
results: List[StructureResult],
output_path: str = None
) -> str:
"""
保存汇总 CSV 文件
Args:
results: 结构结果列表
output_path: 输出路径(默认 workspace/results/summary.csv
Returns:
CSV 文件路径
"""
if output_path is None:
os.makedirs(self.output_dir, exist_ok=True)
output_path = os.path.join(self.output_dir, "summary.csv")
# 构建数据
data = []
for r in results:
data.append({
'Structure': r.structure_name,
'Anion_Type': r.anion_type,
'Percolation_Diameter_A': r.percolation_diameter,
'Min_d': r.min_d,
'Max_Node_Length_A': r.max_node_length,
'Passed_Filter': r.passed_filter,
'Filter_Reason': r.filter_reason
})
df = pd.DataFrame(data)
# 按阴离子类型和结构名排序
df = df.sort_values(['Anion_Type', 'Structure'])
# 保存
os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_csv(output_path, index=False)
print(f"\n汇总 CSV 已保存: {output_path}")
return output_path
def save_anion_csv(
self,
results: List[StructureResult],
output_dir: str = None
) -> List[str]:
"""
按阴离子类型分别保存 CSV 文件
Args:
results: 结构结果列表
output_dir: 输出目录
Returns:
生成的 CSV 文件路径列表
"""
if output_dir is None:
output_dir = self.output_dir
# 按阴离子类型分组
anion_groups: Dict[str, List[StructureResult]] = {}
for r in results:
if r.anion_type not in anion_groups:
anion_groups[r.anion_type] = []
anion_groups[r.anion_type].append(r)
csv_files = []
for anion_type, group_results in anion_groups.items():
# 构建数据
data = []
for r in group_results:
data.append({
'Structure': r.structure_name,
'Percolation_Diameter_A': r.percolation_diameter,
'Min_d': r.min_d,
'Max_Node_Length_A': r.max_node_length,
'Passed_Filter': r.passed_filter,
'Filter_Reason': r.filter_reason
})
df = pd.DataFrame(data)
df = df.sort_values('Structure')
# 保存到对应目录
anion_output_dir = os.path.join(output_dir, anion_type)
os.makedirs(anion_output_dir, exist_ok=True)
csv_path = os.path.join(anion_output_dir, f"{anion_type}.csv")
df.to_csv(csv_path, index=False)
csv_files.append(csv_path)
print(f" {anion_type}: {len(group_results)} 个结构 -> {csv_path}")
return csv_files
def copy_passed_structures(
self,
results: List[StructureResult],
output_dir: str = None
) -> int:
"""
将通过筛选的结构复制到新文件夹
Args:
results: 结构结果列表
output_dir: 输出目录(默认 workspace/passed
Returns:
复制的结构数量
"""
if output_dir is None:
output_dir = os.path.join(self.workspace_path, "passed")
passed_results = [r for r in results if r.passed_filter]
if not passed_results:
print("\n没有通过筛选的结构")
return 0
print(f"\n正在复制 {len(passed_results)} 个通过筛选的结构...")
copied = 0
for r in passed_results:
# 目标目录passed/阴离子类型/结构名/
dst_dir = os.path.join(output_dir, r.anion_type, r.structure_name)
try:
# 如果目标已存在,先删除
if os.path.exists(dst_dir):
shutil.rmtree(dst_dir)
# 复制整个目录
shutil.copytree(r.work_dir, dst_dir)
copied += 1
except Exception as e:
print(f" ⚠️ 复制失败 {r.structure_name}: {e}")
print(f" 已复制 {copied} 个结构到: {output_dir}")
return copied
def process_and_filter(
self,
criteria: FilterCriteria = None,
save_csv: bool = True,
copy_passed: bool = True
) -> Tuple[List[StructureResult], Dict]:
"""
完整的处理流程:提取数据 -> 筛选 -> 保存 CSV -> 复制通过的结构
Args:
criteria: 筛选条件(如果为 None则不筛选
save_csv: 是否保存 CSV
copy_passed: 是否复制通过筛选的结构
Returns:
(结果列表, 统计信息字典)
"""
# 1. 提取所有结构的计算结果
results = self.process_all_structures()
if not results:
return results, {'total': 0, 'passed': 0, 'failed': 0}
# 2. 应用筛选条件
if criteria is not None:
results = self.apply_filter(results, criteria)
# 3. 保存 CSV
if save_csv:
print("\n保存结果 CSV...")
self.save_summary_csv(results)
self.save_anion_csv(results)
# 4. 复制通过筛选的结构
if copy_passed and criteria is not None:
self.copy_passed_structures(results)
# 统计
stats = {
'total': len(results),
'passed': sum(1 for r in results if r.passed_filter),
'failed': sum(1 for r in results if not r.passed_filter),
'missing_data': sum(1 for r in results if r.filter_reason == "数据缺失")
}
return results, stats
def print_summary(self, results: List[StructureResult], stats: Dict):
"""打印结果摘要"""
print("\n" + "=" * 60)
print("【计算结果摘要】")
print("=" * 60)
print(f" 总结构数: {stats['total']}")
print(f" 通过筛选: {stats['passed']}")
print(f" 未通过筛选: {stats['failed']}")
print(f" 数据缺失: {stats.get('missing_data', 0)}")
# 按阴离子类型统计
anion_stats: Dict[str, Dict] = {}
for r in results:
if r.anion_type not in anion_stats:
anion_stats[r.anion_type] = {'total': 0, 'passed': 0}
anion_stats[r.anion_type]['total'] += 1
if r.passed_filter:
anion_stats[r.anion_type]['passed'] += 1
print("\n 按阴离子类型:")
for anion, s in sorted(anion_stats.items()):
print(f" {anion}: {s['passed']}/{s['total']} 通过")
print("=" * 60)

src/computation/workspace_manager.py Normal file
@@ -0,0 +1,288 @@
"""
工作区管理器:管理计算工作区的创建和软链接
"""
import os
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
@dataclass
class WorkspaceInfo:
"""工作区信息"""
workspace_path: str
data_dir: str # workspace/data
tool_dir: str # tool 目录
target_cation: str
target_anions: Set[str]
# 统计信息
total_structures: int = 0
anion_counts: Dict[str, int] = field(default_factory=dict)
linked_structures: int = 0 # 已创建软链接的结构数
class WorkspaceManager:
"""
工作区管理器
负责:
1. 检测现有工作区
2. 创建软链接yaml 配置文件和计算脚本放在每个结构目录下)
3. 准备计算任务
"""
# 支持的阴离子及其配置文件
SUPPORTED_ANIONS = {'O', 'S', 'Cl', 'Br'}
def __init__(
self,
workspace_path: str = "workspace",
tool_dir: str = "tool",
target_cation: str = "Li"
):
"""
初始化工作区管理器
Args:
workspace_path: 工作区根目录
tool_dir: 工具目录(包含 yaml 配置和计算脚本)
target_cation: 目标阳离子
"""
self.workspace_path = os.path.abspath(workspace_path)
self.tool_dir = os.path.abspath(tool_dir)
self.target_cation = target_cation
# 数据目录
self.data_dir = os.path.join(self.workspace_path, "data")
def check_existing_workspace(self) -> Optional[WorkspaceInfo]:
"""
检查现有工作区
Returns:
WorkspaceInfo 如果存在,否则 None
"""
if not os.path.exists(self.data_dir):
return None
# 扫描数据目录
anion_counts = {}
total = 0
linked = 0
for item in os.listdir(self.data_dir):
item_path = os.path.join(self.data_dir, item)
if os.path.isdir(item_path):
# 可能是阴离子目录(如 O, S, O+S
# 统计其中的结构数量
count = 0
for sub_item in os.listdir(item_path):
sub_path = os.path.join(item_path, sub_item)
if os.path.isdir(sub_path):
# 检查是否包含 CIF 文件
cif_files = [f for f in os.listdir(sub_path) if f.endswith('.cif')]
if cif_files:
count += 1
# 检查是否已有软链接
yaml_files = [f for f in os.listdir(sub_path) if f.endswith('.yaml')]
if yaml_files:
linked += 1
if count > 0:
anion_counts[item] = count
total += count
if total == 0:
return None
return WorkspaceInfo(
workspace_path=self.workspace_path,
data_dir=self.data_dir,
tool_dir=self.tool_dir,
target_cation=self.target_cation,
target_anions=set(anion_counts.keys()),
total_structures=total,
anion_counts=anion_counts,
linked_structures=linked
)
def setup_workspace(
self,
target_anions: Set[str] = None,
force_relink: bool = False
) -> WorkspaceInfo:
"""
设置工作区:在每个结构目录下创建软链接
软链接规则:
- yaml 文件:使用与阴离子目录同名的 yaml如 O 目录用 O.yamlCl+O 目录用 Cl+O.yaml
- python 脚本analyze_voronoi_nodes.py
Args:
target_anions: 目标阴离子集合
force_relink: 是否强制重新创建软链接
Returns:
WorkspaceInfo
"""
if target_anions is None:
target_anions = self.SUPPORTED_ANIONS
# 确保数据目录存在
if not os.path.exists(self.data_dir):
raise FileNotFoundError(f"数据目录不存在: {self.data_dir}")
# 获取计算脚本路径
analyze_script = os.path.join(self.tool_dir, "analyze_voronoi_nodes.py")
if not os.path.exists(analyze_script):
raise FileNotFoundError(f"计算脚本不存在: {analyze_script}")
anion_counts = {}
total = 0
linked = 0
print("\n正在设置工作区软链接...")
# 遍历数据目录中的阴离子子目录
for anion_key in os.listdir(self.data_dir):
anion_dir = os.path.join(self.data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
# 确定使用哪个 yaml 配置文件
# 使用与阴离子目录同名的 yaml 文件(如 O.yaml, Cl+O.yaml
yaml_name = f"{anion_key}.yaml"
yaml_source = os.path.join(self.tool_dir, self.target_cation, yaml_name)
if not os.path.exists(yaml_source):
print(f" ⚠️ 配置文件不存在: {yaml_source}")
continue
# 统计并处理该阴离子目录下的所有结构
count = 0
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 检查是否包含 CIF 文件
cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
if not cif_files:
continue
count += 1
# 在结构目录下创建软链接
yaml_link = os.path.join(struct_dir, yaml_name)
script_link = os.path.join(struct_dir, "analyze_voronoi_nodes.py")
# 创建 yaml 软链接
if os.path.exists(yaml_link) or os.path.islink(yaml_link):
if force_relink:
os.remove(yaml_link)
os.symlink(yaml_source, yaml_link)
linked += 1
else:
os.symlink(yaml_source, yaml_link)
linked += 1
# 创建计算脚本软链接
if os.path.exists(script_link) or os.path.islink(script_link):
if force_relink:
os.remove(script_link)
os.symlink(analyze_script, script_link)
else:
os.symlink(analyze_script, script_link)
if count > 0:
anion_counts[anion_key] = count
total += count
print(f"{anion_key}: {count} 个结构, 配置 -> {yaml_name}")
print(f"\n 总计: {total} 个结构, 新建软链接: {linked}")
return WorkspaceInfo(
workspace_path=self.workspace_path,
data_dir=self.data_dir,
tool_dir=self.tool_dir,
target_cation=self.target_cation,
target_anions=set(anion_counts.keys()),
total_structures=total,
anion_counts=anion_counts,
linked_structures=linked
)
def get_computation_tasks(
self,
workspace_info: WorkspaceInfo = None
) -> List[Dict]:
"""
获取所有计算任务
Returns:
任务列表,每个任务包含:
- cif_path: CIF 文件路径
- yaml_name: YAML 配置文件名(如 O.yaml
- work_dir: 工作目录(结构目录)
- anion_type: 阴离子类型
- structure_name: 结构名称
"""
if workspace_info is None:
workspace_info = self.check_existing_workspace()
if workspace_info is None:
return []
tasks = []
for anion_key in workspace_info.anion_counts.keys():
anion_dir = os.path.join(self.data_dir, anion_key)
yaml_name = f"{anion_key}.yaml"
# 遍历该阴离子目录下的所有结构
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if not os.path.isdir(struct_dir):
continue
# 查找 CIF 文件
cif_files = [f for f in os.listdir(struct_dir) if f.endswith('.cif')]
# 检查是否有 yaml 软链接
yaml_path = os.path.join(struct_dir, yaml_name)
if not os.path.exists(yaml_path):
continue
for cif_file in cif_files:
cif_path = os.path.join(struct_dir, cif_file)
tasks.append({
'cif_path': cif_path,
'yaml_name': yaml_name,
'work_dir': struct_dir,
'anion_type': anion_key,
'structure_name': struct_name,
'cif_name': cif_file
})
return tasks
def print_workspace_summary(self, workspace_info: WorkspaceInfo):
"""打印工作区摘要"""
print("\n" + "=" * 60)
print("【工作区摘要】")
print("=" * 60)
print(f" 工作区路径: {workspace_info.workspace_path}")
print(f" 数据目录: {workspace_info.data_dir}")
print(f" 目标阳离子: {workspace_info.target_cation}")
print(f" 总结构数: {workspace_info.total_structures}")
print(f" 已配置软链接: {workspace_info.linked_structures}")
print()
print(" 阴离子分布:")
for anion, count in sorted(workspace_info.anion_counts.items()):
print(f" - {anion}: {count} 个结构")
print("=" * 60)

src/computation/zeo_executor.py Normal file
@@ -0,0 +1,446 @@
"""
Zeo++ 计算执行器:使用 SLURM 作业数组高效调度大量计算任务
"""
import os
import subprocess
import time
import json
import tempfile
from typing import List, Dict, Optional, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
import threading
from ..core.progress import ProgressManager
@dataclass
class ZeoConfig:
"""Zeo++ 计算配置"""
# 环境配置
conda_env: str = "/cluster/home/koko125/anaconda3/envs/zeo"
# SLURM 配置
partition: str = "cpu"
time_limit: str = "2:00:00" # 单个任务时间限制
memory_per_task: str = "4G"
# 作业数组配置
max_array_size: int = 1000 # SLURM 作业数组最大大小
max_concurrent: int = 50 # 最大并发任务数
# 轮询配置
poll_interval: float = 5.0 # 状态检查间隔(秒)
# 过滤器配置
filters: List[str] = field(default_factory=lambda: [
"Ordered", "PropOxi", "VoroPerco", "Coulomb", "VoroBV", "VoroInfo", "MergeSite"
])
@dataclass
class ZeoTaskResult:
"""单个任务结果"""
task_id: int
structure_name: str
cif_path: str
success: bool
output_files: List[str] = field(default_factory=list)
error_message: str = ""
duration: float = 0.0
class ZeoExecutor:
"""
Zeo++ 计算执行器
使用 SLURM 作业数组高效调度大量 Voronoi 分析任务
"""
def __init__(self, config: ZeoConfig = None):
self.config = config or ZeoConfig()
self.progress_manager = None
self._stop_event = threading.Event()
def run_batch(
self,
tasks: List[Dict],
output_dir: str = None,
desc: str = "Zeo++ 计算"
) -> List[ZeoTaskResult]:
"""
批量执行 Zeo++ 计算
Args:
tasks: 任务列表,每个任务包含 cif_path, yaml_path, work_dir 等
output_dir: SLURM 日志输出目录
desc: 进度条描述
Returns:
ZeoTaskResult 列表
"""
if not tasks:
print("⚠️ 没有任务需要执行")
return []
total = len(tasks)
# 创建输出目录
if output_dir is None:
output_dir = os.path.join(os.getcwd(), "slurm_logs")
os.makedirs(output_dir, exist_ok=True)
print(f"\n{'='*60}")
print(f"【Zeo++ 批量计算】")
print(f"{'='*60}")
print(f" 总任务数: {total}")
print(f" Conda环境: {self.config.conda_env}")
print(f" SLURM分区: {self.config.partition}")
print(f" 最大并发: {self.config.max_concurrent}")
print(f" 日志目录: {output_dir}")
print(f"{'='*60}\n")
# 保存任务列表到文件
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump(tasks, f, indent=2)
# 生成并提交作业数组
if total <= self.config.max_array_size:
# 单个作业数组
return self._submit_array_job(tasks, output_dir, desc)
else:
# 分批提交多个作业数组
return self._submit_batched_arrays(tasks, output_dir, desc)
def _submit_array_job(
self,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""提交单个作业数组"""
total = len(tasks)
# 保存任务列表
tasks_file = os.path.join(output_dir, "tasks.json")
with open(tasks_file, 'w') as f:
json.dump(tasks, f, indent=2)
# 生成作业脚本
script_content = self._generate_array_script(
tasks_file=tasks_file,
output_dir=output_dir,
array_range=f"0-{total-1}%{self.config.max_concurrent}"
)
script_path = os.path.join(output_dir, "submit_array.sh")
with open(script_path, 'w') as f:
f.write(script_content)
os.chmod(script_path, 0o755)
print(f"生成作业脚本: {script_path}")
# 提交作业
result = subprocess.run(
['sbatch', script_path],
capture_output=True,
text=True
)
if result.returncode != 0:
print(f"❌ 作业提交失败: {result.stderr}")
return [ZeoTaskResult(
task_id=i,
structure_name=t.get('structure_name', ''),
cif_path=t.get('cif_path', ''),
success=False,
error_message=f"提交失败: {result.stderr}"
) for i, t in enumerate(tasks)]
# 提取作业 ID
job_id = result.stdout.strip().split()[-1]
print(f"✓ 作业已提交: {job_id}")
print(f" 作业数组范围: 0-{total-1}")
print(f" 最大并发: {self.config.max_concurrent}")
# 监控作业进度
return self._monitor_array_job(job_id, tasks, output_dir, desc)
def _submit_batched_arrays(
self,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""分批提交多个作业数组"""
total = len(tasks)
batch_size = self.config.max_array_size
num_batches = (total + batch_size - 1) // batch_size
print(f"任务数超过作业数组限制,分 {num_batches} 批提交")
all_results = []
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, total)
batch_tasks = tasks[start_idx:end_idx]
batch_output_dir = os.path.join(output_dir, f"batch_{batch_idx}")
os.makedirs(batch_output_dir, exist_ok=True)
print(f"\n--- 批次 {batch_idx + 1}/{num_batches} ---")
print(f"任务范围: {start_idx} - {end_idx - 1}")
batch_results = self._submit_array_job(
batch_tasks,
batch_output_dir,
f"{desc} (批次 {batch_idx + 1}/{num_batches})"
)
# 调整任务 ID
for r in batch_results:
r.task_id += start_idx
all_results.extend(batch_results)
return all_results
def _generate_array_script(
self,
tasks_file: str,
output_dir: str,
array_range: str
) -> str:
"""生成 SLURM 作业数组脚本"""
# 获取项目根目录
project_root = os.getcwd()
script = f"""#!/bin/bash
#SBATCH --job-name=zeo_array
#SBATCH --partition={self.config.partition}
#SBATCH --array={array_range}
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=1
#SBATCH --mem={self.config.memory_per_task}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/task_%a.out
#SBATCH --error={output_dir}/task_%a.err
# ============================================
# Zeo++ Voronoi 分析 - 作业数组
# ============================================
echo "===== 任务信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "数组任务ID: $SLURM_ARRAY_TASK_ID"
echo "节点: $SLURM_NODELIST"
echo "开始时间: $(date)"
echo "===================="
# ============ 环境初始化 ============
# 加载 bashrc
if [ -f ~/.bashrc ]; then
source ~/.bashrc
fi
# 初始化 Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
source /opt/anaconda3/etc/profile.d/conda.sh
fi
# 激活 Zeo++ 环境
conda activate {self.config.conda_env}
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "===================="
echo ""
# ============ 读取任务信息 ============
TASKS_FILE="{tasks_file}"
TASK_ID=$SLURM_ARRAY_TASK_ID
# 使用 Python 解析任务
TASK_INFO=$(python3 -c "
import json
with open('$TASKS_FILE', 'r') as f:
tasks = json.load(f)
if $TASK_ID < len(tasks):
task = tasks[$TASK_ID]
print(task['work_dir'])
print(task['yaml_name'])
else:
print('ERROR')
")
WORK_DIR=$(echo "$TASK_INFO" | sed -n '1p')
YAML_NAME=$(echo "$TASK_INFO" | sed -n '2p')
if [ "$WORK_DIR" == "ERROR" ]; then
echo "错误: 任务ID $TASK_ID 超出范围"
exit 1
fi
echo "工作目录: $WORK_DIR"
echo "配置文件: $YAML_NAME"
echo ""
# ============ 执行计算 ============
cd "$WORK_DIR"
echo "开始 Voronoi 分析..."
# 软链接已在工作目录下,直接使用相对路径
# 将输出重定向到 log.txt 以便后续提取结果
python analyze_voronoi_nodes.py *.cif -i "$YAML_NAME" > log.txt 2>&1
EXIT_CODE=$?
# 显示日志内容(用于调试)
echo ""
echo "===== 计算日志 ====="
cat log.txt
echo "===================="
# ============ 完成 ============
echo ""
echo "===== 任务完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
# 写入状态文件
if [ $EXIT_CODE -eq 0 ]; then
echo "SUCCESS" > "{output_dir}/status_$TASK_ID.txt"
else
echo "FAILED" > "{output_dir}/status_$TASK_ID.txt"
fi
echo "===================="
exit $EXIT_CODE
"""
return script
def _monitor_array_job(
self,
job_id: str,
tasks: List[Dict],
output_dir: str,
desc: str
) -> List[ZeoTaskResult]:
"""监控作业数组进度"""
total = len(tasks)
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = [None] * total
completed = set()
print(f"\n监控作业进度 (每 {self.config.poll_interval} 秒检查一次)...")
print("按 Ctrl+C 可中断监控(作业将继续在后台运行)\n")
try:
while len(completed) < total:
time.sleep(self.config.poll_interval)
# 检查状态文件
for i in range(total):
if i in completed:
continue
status_file = os.path.join(output_dir, f"status_{i}.txt")
if os.path.exists(status_file):
with open(status_file, 'r') as f:
status = f.read().strip()
task = tasks[i]
success = (status == "SUCCESS")
# 收集输出文件
output_files = []
if success:
work_dir = task['work_dir']
for f in os.listdir(work_dir):
if f.endswith(('.cif', '.csv')) and f != task['cif_name']:
output_files.append(os.path.join(work_dir, f))
results[i] = ZeoTaskResult(
task_id=i,
structure_name=task.get('structure_name', ''),
cif_path=task.get('cif_path', ''),
success=success,
output_files=output_files
)
completed.add(i)
self.progress_manager.update(success=success)
self.progress_manager.display()
# 检查作业是否还在运行
if not self._is_job_running(job_id) and len(completed) < total:
# 作业已结束但有任务未完成
print(f"\n⚠️ 作业已结束,但有 {total - len(completed)} 个任务未完成")
break
except KeyboardInterrupt:
print("\n\n⚠️ 监控已中断,作业将继续在后台运行")
print(f" 可使用 'squeue -j {job_id}' 查看作业状态")
print(f" 可使用 'scancel {job_id}' 取消作业")
self.progress_manager.finish()
# 填充未完成的任务
for i in range(total):
if results[i] is None:
task = tasks[i]
results[i] = ZeoTaskResult(
task_id=i,
structure_name=task.get('structure_name', ''),
cif_path=task.get('cif_path', ''),
success=False,
error_message="任务未完成或状态未知"
)
return results
def _is_job_running(self, job_id: str) -> bool:
"""检查作业是否还在运行"""
try:
result = subprocess.run(
['squeue', '-j', job_id, '-h'],
capture_output=True,
text=True,
timeout=10
)
return bool(result.stdout.strip())
except Exception:
return False
def print_results_summary(self, results: List[ZeoTaskResult]):
"""打印结果摘要"""
total = len(results)
success = sum(1 for r in results if r.success)
failed = total - success
print("\n" + "=" * 60)
print("【计算结果摘要】")
print("=" * 60)
print(f" 总任务数: {total}")
print(f" 成功: {success} ({100*success/total:.1f}%)")
print(f" 失败: {failed} ({100*failed/total:.1f}%)")
if failed > 0 and failed <= 10:
print("\n 失败的任务:")
for r in results:
if not r.success:
print(f" - {r.structure_name}: {r.error_message}")
elif failed > 10:
print(f"\n 失败任务过多,请检查日志文件")
print("=" * 60)

18
src/core/__init__.py Normal file
View File

@@ -0,0 +1,18 @@
"""
核心模块:调度器、执行器和进度管理
"""
from .scheduler import ParallelScheduler, ResourceConfig, ExecutionMode as SchedulerMode
from .executor import TaskExecutor, ExecutorConfig, ExecutionMode, TaskResult, create_executor
from .progress import ProgressManager
__all__ = [
'ParallelScheduler',
'ResourceConfig',
'SchedulerMode',
'TaskExecutor',
'ExecutorConfig',
'ExecutionMode',
'TaskResult',
'create_executor',
'ProgressManager',
]

0
src/core/controller.py Normal file
View File

431
src/core/executor.py Normal file
View File

@@ -0,0 +1,431 @@
"""
任务执行器:支持本地执行和 SLURM 直接提交
不生成脚本文件,直接在 Python 中管理任务
"""
import os
import subprocess
import time
import json
from typing import List, Callable, Any, Optional, Dict, Tuple
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass, field
from enum import Enum
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL = "local" # 本地多进程
SLURM_DIRECT = "slurm" # SLURM 直接提交(不生成脚本)
@dataclass
class ExecutorConfig:
"""执行器配置"""
mode: ExecutionMode = ExecutionMode.LOCAL
max_workers: int = 4
conda_env: str = "/cluster/home/koko125/anaconda3/envs/screen"
partition: str = "cpu"
time_limit: str = "7-00:00:00"
memory_per_task: str = "4G"
# SLURM 相关
poll_interval: float = 2.0 # 轮询间隔(秒)
max_concurrent_jobs: int = 50 # 最大并发作业数
@dataclass
class TaskResult:
"""任务结果"""
task_id: Any
success: bool
result: Any = None
error: str = None
duration: float = 0.0
class TaskExecutor:
"""
任务执行器
支持两种模式:
1. LOCAL: 本地多进程执行
2. SLURM_DIRECT: 直接提交 SLURM 作业,实时监控进度
"""
def __init__(self, config: ExecutorConfig = None):
self.config = config or ExecutorConfig()
self.progress_manager = None
self._stop_event = threading.Event()
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""检测运行环境"""
env_info = {
'hostname': os.uname().nodename,
'total_cores': cpu_count(),
'has_slurm': False,
'slurm_partitions': [],
'conda_env': os.environ.get('CONDA_PREFIX', ''),
}
# 检测 SLURM
try:
result = subprocess.run(
['sinfo', '-h', '-o', '%P %a %c %D'],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
env_info['has_slurm'] = True
lines = result.stdout.strip().split('\n')
for line in lines:
parts = line.split()
if len(parts) >= 4:
partition = parts[0].rstrip('*')
avail = parts[1]
if avail == 'up':
env_info['slurm_partitions'].append(partition)
except Exception:
pass
return env_info
def run(
self,
tasks: List[Any],
worker_func: Callable,
desc: str = "Processing"
) -> List[TaskResult]:
"""
执行任务
Args:
tasks: 任务列表
worker_func: 工作函数,接收单个任务,返回结果
desc: 进度条描述
Returns:
TaskResult 列表
"""
if self.config.mode == ExecutionMode.LOCAL:
return self._run_local(tasks, worker_func, desc)
elif self.config.mode == ExecutionMode.SLURM_DIRECT:
return self._run_slurm_direct(tasks, worker_func, desc)
else:
raise ValueError(f"不支持的执行模式: {self.config.mode}")
def _run_local(
self,
tasks: List[Any],
worker_func: Callable,
desc: str
) -> List[TaskResult]:
"""本地多进程执行"""
total = len(tasks)
num_workers = min(self.config.max_workers, total)
print(f"\n{'='*60}")
print(f"本地执行配置:")
print(f" 总任务数: {total}")
print(f" Worker数: {num_workers}")
print(f"{'='*60}\n")
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
# 单进程执行
for i, task in enumerate(tasks):
start_time = time.time()
try:
result = worker_func(task)
duration = time.time() - start_time
results.append(TaskResult(
task_id=i,
success=True,
result=result,
duration=duration
))
self.progress_manager.update(success=True)
except Exception as e:
duration = time.time() - start_time
results.append(TaskResult(
task_id=i,
success=False,
error=str(e),
duration=duration
))
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
# 多进程执行
with Pool(processes=num_workers) as pool:
for i, result in enumerate(pool.imap_unordered(worker_func, tasks)):
if result is not None:
results.append(TaskResult(
task_id=i,
success=True,
result=result
))
self.progress_manager.update(success=True)
else:
results.append(TaskResult(
task_id=i,
success=False,
error="Worker returned None"
))
self.progress_manager.update(success=False)
self.progress_manager.display()
self.progress_manager.finish()
return results
def _run_slurm_direct(
self,
tasks: List[Any],
worker_func: Callable,
desc: str
) -> List[TaskResult]:
"""
SLURM 直接提交模式
注意:对于数据库分析这类快速任务,建议使用本地多进程模式
SLURM 模式更适合耗时的计算任务(如 Zeo++ 分析)
这里回退到本地模式,因为 srun 在登录节点直接调用效率不高
"""
print("\n⚠️ 注意:数据库分析阶段自动使用本地多进程模式")
print(" SLURM 模式将在后续耗时计算步骤中使用")
# 回退到本地模式
return self._run_local(tasks, worker_func, desc)
class SlurmJobManager:
"""
SLURM 作业管理器
用于批量提交和监控 SLURM 作业
"""
def __init__(self, config: ExecutorConfig):
self.config = config
self.active_jobs = {} # job_id -> task_info
def submit_batch(
self,
tasks: List[Tuple[str, str, set]], # (file_path, target_cation, target_anions)
output_dir: str,
desc: str = "Processing"
) -> List[TaskResult]:
"""
批量提交任务到 SLURM
使用 sbatch --wrap 直接提交,不生成脚本文件
"""
total = len(tasks)
os.makedirs(output_dir, exist_ok=True)
print(f"\n{'='*60}")
print(f"SLURM 批量提交:")
print(f" 总任务数: {total}")
print(f" 输出目录: {output_dir}")
print(f" Conda环境: {self.config.conda_env}")
print(f"{'='*60}\n")
progress = ProgressManager(total, desc)
progress.start()
results = []
job_ids = []
# 提交所有任务
for i, task in enumerate(tasks):
file_path, target_cation, target_anions = task
# 构建 Python 命令
anions_str = ','.join(target_anions)
python_cmd = (
f"python -c \""
f"import sys; sys.path.insert(0, '{os.getcwd()}'); "
f"from src.analysis.worker import analyze_single_file; "
f"result = analyze_single_file(('{file_path}', '{target_cation}', set('{anions_str}'.split(',')))); "
f"print('SUCCESS' if result and result.is_valid else 'FAILED')"
f"\""
)
# 构建完整的 bash 命令
bash_cmd = (
f"source {os.path.dirname(self.config.conda_env)}/../../etc/profile.d/conda.sh && "
f"conda activate {self.config.conda_env} && "
f"{python_cmd}"
)
# 使用 sbatch --wrap 提交
sbatch_cmd = [
'sbatch',
'--partition', self.config.partition,
'--ntasks', '1',
'--cpus-per-task', '1',
'--mem', self.config.memory_per_task,
'--time', '01:00:00',
'--output', os.path.join(output_dir, f'task_{i}.out'),
'--error', os.path.join(output_dir, f'task_{i}.err'),
'--wrap', bash_cmd
]
try:
result = subprocess.run(
sbatch_cmd,
capture_output=True,
text=True
)
if result.returncode == 0:
# 提取 job_id
job_id = result.stdout.strip().split()[-1]
job_ids.append((i, job_id, file_path))
self.active_jobs[job_id] = {
'task_index': i,
'file_path': file_path,
'status': 'PENDING'
}
else:
results.append(TaskResult(
task_id=i,
success=False,
error=f"提交失败: {result.stderr}"
))
progress.update(success=False)
progress.display()
except Exception as e:
results.append(TaskResult(
task_id=i,
success=False,
error=str(e)
))
progress.update(success=False)
progress.display()
print(f"\n已提交 {len(job_ids)} 个作业,等待完成...")
# 监控作业状态
while self.active_jobs:
time.sleep(self.config.poll_interval)
# 检查作业状态
completed_jobs = self._check_job_status()
for job_id, status in completed_jobs:
job_info = self.active_jobs.pop(job_id, None)
if job_info:
task_idx = job_info['task_index']
if status == 'COMPLETED':
# 检查输出文件
out_file = os.path.join(output_dir, f'task_{task_idx}.out')
success = False
if os.path.exists(out_file):
with open(out_file, 'r') as f:
content = f.read()
success = 'SUCCESS' in content
results.append(TaskResult(
task_id=task_idx,
success=success,
result=job_info['file_path']
))
progress.update(success=success)
else:
# 作业失败
err_file = os.path.join(output_dir, f'task_{task_idx}.err')
error_msg = status
if os.path.exists(err_file):
with open(err_file, 'r') as f:
error_msg = f.read()[:500] # 只取前500字符
results.append(TaskResult(
task_id=task_idx,
success=False,
error=error_msg
))
progress.update(success=False)
progress.display()
progress.finish()
return results
def _check_job_status(self) -> List[Tuple[str, str]]:
"""检查作业状态,返回已完成的作业列表"""
if not self.active_jobs:
return []
job_ids = list(self.active_jobs.keys())
try:
result = subprocess.run(
['sacct', '-j', ','.join(job_ids), '--format=JobID,State', '--noheader', '--parsable2'],
capture_output=True,
text=True,
timeout=30
)
completed = []
if result.returncode == 0:
for line in result.stdout.strip().split('\n'):
if line:
parts = line.split('|')
if len(parts) >= 2:
job_id = parts[0].split('.')[0] # 去掉 .batch 后缀
status = parts[1]
if job_id in self.active_jobs:
if status in ['COMPLETED', 'FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL']:
completed.append((job_id, status))
return completed
except Exception:
return []
def create_executor(
mode: str = "local",
max_workers: int = None,
conda_env: str = None,
**kwargs
) -> TaskExecutor:
"""
创建任务执行器的便捷函数
Args:
mode: "local""slurm"
max_workers: 最大工作进程数
conda_env: Conda 环境路径
**kwargs: 其他配置参数
"""
env = TaskExecutor.detect_environment()
if max_workers is None:
max_workers = min(env['total_cores'], 32)
if conda_env is None:
conda_env = env.get('conda_env') or "/cluster/home/koko125/anaconda3/envs/screen"
exec_mode = ExecutionMode.SLURM_DIRECT if mode.lower() == "slurm" else ExecutionMode.LOCAL
config = ExecutorConfig(
mode=exec_mode,
max_workers=max_workers,
conda_env=conda_env,
**kwargs
)
return TaskExecutor(config)
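
A short usage sketch of create_executor / TaskExecutor.run (illustrative; the worker must be a module-level, picklable function for the local multiprocessing mode):

# usage sketch -- illustrative only
from src.core.executor import create_executor

def square(x):
    """Toy worker: defined at module level so multiprocessing can pickle it."""
    return x * x

if __name__ == "__main__":
    executor = create_executor(mode="local", max_workers=4)
    results = executor.run(list(range(100)), square, desc="示例任务")
    ok = sum(1 for r in results if r.success)
    print(f"{ok}/{len(results)} tasks succeeded")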

115
src/core/progress.py Normal file
View File

@@ -0,0 +1,115 @@
"""
进度管理器:支持多进程的实时进度显示
"""
import os
import sys
import time
from multiprocessing import Manager, Value
from typing import Optional
from datetime import datetime, timedelta
class ProgressManager:
"""多进程安全的进度管理器"""
def __init__(self, total: int, desc: str = "Processing"):
self.total = total
self.desc = desc
self.manager = Manager()
self.completed = self.manager.Value('i', 0)
self.failed = self.manager.Value('i', 0)
self.start_time = None
self._lock = self.manager.Lock()
def start(self):
"""开始计时"""
self.start_time = time.time()
def update(self, success: bool = True):
"""更新进度(进程安全)"""
with self._lock:
if success:
self.completed.value += 1
else:
self.failed.value += 1
def get_progress(self) -> dict:
"""获取当前进度"""
completed = self.completed.value
failed = self.failed.value
total_done = completed + failed
elapsed = time.time() - self.start_time if self.start_time else 0
if total_done > 0:
speed = total_done / elapsed # items/sec
remaining = (self.total - total_done) / speed if speed > 0 else 0
else:
speed = 0
remaining = 0
return {
'total': self.total,
'completed': completed,
'failed': failed,
'total_done': total_done,
'percent': total_done / self.total * 100 if self.total > 0 else 0,
'elapsed': elapsed,
'remaining': remaining,
'speed': speed
}
def display(self):
"""显示进度条"""
p = self.get_progress()
# 进度条
bar_width = 40
filled = int(bar_width * p['total_done'] / p['total']) if p['total'] > 0 else 0
bar = '█' * filled + '░' * (bar_width - filled)
# 时间格式化
elapsed_str = str(timedelta(seconds=int(p['elapsed'])))
remaining_str = str(timedelta(seconds=int(p['remaining'])))
# 构建显示字符串
status = (
f"\r{self.desc}: |{bar}| "
f"{p['total_done']}/{p['total']} ({p['percent']:.1f}%) "
f"[{elapsed_str}<{remaining_str}, {p['speed']:.1f}it/s] "
f"{p['completed']}{p['failed']}"
)
sys.stdout.write(status)
sys.stdout.flush()
def finish(self):
"""完成显示"""
self.display()
print() # 换行
class ProgressReporter:
"""进度报告器(用于后台监控)"""
def __init__(self, progress_file: str = ".progress"):
self.progress_file = progress_file
def save(self, progress: dict):
"""保存进度到文件"""
import json
with open(self.progress_file, 'w') as f:
json.dump(progress, f)
def load(self) -> Optional[dict]:
"""从文件加载进度"""
import json
if os.path.exists(self.progress_file):
with open(self.progress_file, 'r') as f:
return json.load(f)
return None
def cleanup(self):
"""清理进度文件"""
if os.path.exists(self.progress_file):
os.remove(self.progress_file)
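
A minimal usage sketch of ProgressManager (illustrative):

# usage sketch -- illustrative only
import time
from src.core.progress import ProgressManager

if __name__ == "__main__":
    pm = ProgressManager(total=50, desc="演示")
    pm.start()
    for i in range(50):
        time.sleep(0.01)                      # stand-in for real work
        pm.update(success=(i % 10 != 0))      # mark every 10th item as failed
        pm.display()
    pm.finish()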

291
src/core/scheduler.py Normal file
View File

@@ -0,0 +1,291 @@
"""
并行调度器:支持本地多进程和SLURM集群调度
"""
import os
import subprocess
import tempfile
from typing import List, Callable, Any, Optional, Dict
from multiprocessing import Pool, cpu_count
from dataclasses import dataclass
from enum import Enum
import math
from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL_SINGLE = "local_single"
LOCAL_MULTI = "local_multi"
SLURM_SINGLE = "slurm_single"
SLURM_ARRAY = "slurm_array"
@dataclass
class ResourceConfig:
"""资源配置"""
max_cores: int = 4
cores_per_worker: int = 1
memory_per_core: str = "4G"
partition: str = "cpu"
time_limit: str = "7-00:00:00"
conda_env_path: str = "~/anaconda3/envs/screen" # 新增Conda环境路径
@property
def num_workers(self) -> int:
return max(1, self.max_cores // self.cores_per_worker)
class ParallelScheduler:
"""并行调度器"""
COMPLEXITY_CORES = {
'low': 1,
'medium': 2,
'high': 4,
}
def __init__(self, resource_config: ResourceConfig = None):
self.config = resource_config or ResourceConfig()
self.progress_manager = None
@staticmethod
def detect_environment() -> Dict[str, Any]:
"""检测运行环境"""
env_info = {
'hostname': os.uname().nodename,
'total_cores': cpu_count(),
'has_slurm': False,
'slurm_partitions': [],
'available_nodes': 0,
'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
}
# 检测SLURM
try:
result = subprocess.run(
['sinfo', '-h', '-o', '%P %a %c %D'],
capture_output=True, text=True, timeout=5
)
if result.returncode == 0:
env_info['has_slurm'] = True
lines = result.stdout.strip().split('\n')
for line in lines:
parts = line.split()
if len(parts) >= 4:
partition = parts[0].rstrip('*')
avail = parts[1]
cores = int(parts[2])
nodes = int(parts[3])
if avail == 'up':
env_info['slurm_partitions'].append({
'name': partition,
'cores_per_node': cores,
'nodes': nodes
})
env_info['available_nodes'] += nodes
except Exception:
pass
return env_info
@staticmethod
def recommend_config(
num_tasks: int,
task_complexity: str = 'medium',
max_cores: int = None,
conda_env_path: str = None
) -> ResourceConfig:
"""根据任务量和复杂度推荐配置"""
env = ParallelScheduler.detect_environment()
if max_cores is None:
max_cores = min(env['total_cores'], 32)
cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
max_workers = max_cores // cores_per_worker
optimal_workers = min(num_tasks, max_workers)
actual_cores = optimal_workers * cores_per_worker
# 自动检测Conda环境路径
if conda_env_path is None:
conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen')
config = ResourceConfig(
max_cores=actual_cores,
cores_per_worker=cores_per_worker,
conda_env_path=conda_env_path,
)
return config
def run_local(
self,
tasks: List[Any],
worker_func: Callable,
desc: str = "Processing"
) -> List[Any]:
"""本地多进程执行"""
num_workers = self.config.num_workers
total = len(tasks)
print(f"\n{'='*60}")
print(f"并行配置:")
print(f" 总任务数: {total}")
print(f" Worker数: {num_workers}")
print(f" 每Worker核数: {self.config.cores_per_worker}")
print(f" 总使用核数: {self.config.max_cores}")
print(f"{'='*60}\n")
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
for task in tasks:
try:
result = worker_func(task)
results.append(result)
self.progress_manager.update(success=True)
except Exception as e:
results.append(None)
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
with Pool(processes=num_workers) as pool:
for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
if result is not None:
results.append(result)
self.progress_manager.update(success=True)
else:
self.progress_manager.update(success=False)
self.progress_manager.display()
self.progress_manager.finish()
return results
def generate_slurm_script(
self,
tasks_file: str,
worker_script: str,
output_dir: str,
job_name: str = "analysis"
) -> str:
"""生成SLURM作业脚本修复版"""
# 获取当前工作目录的绝对路径
submit_dir = os.getcwd()
# 转换为绝对路径
abs_tasks_file = os.path.abspath(tasks_file)
abs_worker_script = os.path.abspath(worker_script)
abs_output_dir = os.path.abspath(output_dir)
script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --partition={self.config.partition}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err
# ============================================
# SLURM作业脚本 - 自动生成
# ============================================
echo "===== 作业信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "节点: $SLURM_NODELIST"
echo "CPU数: $SLURM_CPUS_PER_TASK"
echo "开始时间: $(date)"
echo "===================="
# ============ 环境初始化 ============
# 关键:确保 bashrc 被加载
if [ -f ~/.bashrc ]; then
source ~/.bashrc
fi
# 初始化Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
source /opt/anaconda3/etc/profile.d/conda.sh
else
echo "错误: 找不到conda.sh"
exit 1
fi
# 激活环境 (使用完整路径)
conda activate {self.config.conda_env_path}
# 验证环境
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "Python版本: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
echo "===================="
echo ""
# 设置工作目录
cd {submit_dir}
export PYTHONPATH={submit_dir}:$PYTHONPATH
echo "工作目录: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""
# ============ 运行分析 ============
echo "开始执行分析任务..."
python {abs_worker_script} \\
--tasks-file {abs_tasks_file} \\
--output-dir {abs_output_dir} \\
--num-workers {self.config.num_workers}
EXIT_CODE=$?
# ============ 完成 ============
echo ""
echo "===== 作业完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
echo "===================="
exit $EXIT_CODE
"""
return script
def submit_slurm_job(self, script_content: str, script_path: str = None) -> str:
"""提交SLURM作业"""
if script_path is None:
fd, script_path = tempfile.mkstemp(suffix='.sh')
os.close(fd)
with open(script_path, 'w') as f:
f.write(script_content)
os.chmod(script_path, 0o755)
# 打印脚本内容用于调试
print(f"\n生成的SLURM脚本保存到: {script_path}")
result = subprocess.run(
['sbatch', script_path],
capture_output=True, text=True
)
if result.returncode == 0:
job_id = result.stdout.strip().split()[-1]
return job_id
else:
raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")
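
A usage sketch of the scheduler's local mode (illustrative; the worker must be a module-level function so the process pool can pickle it):

# usage sketch -- illustrative only
from src.core.scheduler import ParallelScheduler

def add_one(x):
    return x + 1

if __name__ == "__main__":
    env = ParallelScheduler.detect_environment()
    print("检测到核数:", env["total_cores"], "SLURM:", env["has_slurm"])

    config = ParallelScheduler.recommend_config(num_tasks=500, task_complexity="medium")
    scheduler = ParallelScheduler(config)
    results = scheduler.run_local(list(range(500)), add_one, desc="预处理")
    print("完成:", len(results))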

View File

@@ -0,0 +1,562 @@
"""
结构预处理器:扩胞和添加化合价
"""
import os
import re
import yaml
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from pymatgen.core.structure import Structure
from pymatgen.core.periodic_table import Specie
from pymatgen.core import Lattice, Species, PeriodicSite
from collections import defaultdict
from fractions import Fraction
from functools import reduce
import math
import random
import spglib
import numpy as np
@dataclass
class ProcessingResult:
"""处理结果"""
input_file: str
output_files: List[str] = field(default_factory=list)
success: bool = False
needs_expansion: bool = False
expansion_factor: int = 1
error_message: str = ""
class StructureProcessor:
"""结构预处理器"""
# 默认化合价配置
DEFAULT_VALENCE_PATH = os.path.join(
os.path.dirname(__file__), '..', '..', 'tool', 'valence_states.yaml'
)
def __init__(
self,
valence_yaml_path: str = None,
calculate_type: str = 'low',
max_expansion_factor: int = 64,
keep_number: int = 3,
target_cation: str = "Li"
):
"""
初始化处理器
Args:
valence_yaml_path: 化合价配置文件路径
calculate_type: 扩胞计算精度 ('high', 'normal', 'low', 'very_low')
max_expansion_factor: 最大扩胞因子
keep_number: 保留的扩胞结构数量
target_cation: 目标阳离子
"""
self.valence_yaml_path = valence_yaml_path or self.DEFAULT_VALENCE_PATH
self.calculate_type = calculate_type
self.max_expansion_factor = max_expansion_factor
self.keep_number = keep_number
self.target_cation = target_cation
self.explict_element = [target_cation, f"{target_cation}+"]
# 加载化合价配置
self.valences = self._load_valences()
def _load_valences(self) -> Dict[str, int]:
"""加载化合价配置"""
if os.path.exists(self.valence_yaml_path):
with open(self.valence_yaml_path, 'r') as f:
return yaml.safe_load(f)
return {}
def process_file(
self,
input_path: str,
output_dir: str,
needs_expansion: bool = False
) -> ProcessingResult:
"""
处理单个CIF文件
Args:
input_path: 输入文件路径
output_dir: 输出目录
needs_expansion: 是否需要扩胞
Returns:
ProcessingResult: 处理结果
"""
result = ProcessingResult(input_file=input_path)
try:
# 读取结构
structure = Structure.from_file(input_path)
base_name = os.path.splitext(os.path.basename(input_path))[0]
# 检查是否需要扩胞
occupation_list = self._process_cif_file(structure)
if occupation_list and needs_expansion:
# 需要扩胞处理
result.needs_expansion = True
output_files = self._expand_and_save(
structure, occupation_list, base_name, output_dir
)
result.output_files = output_files
result.expansion_factor = occupation_list[0].get('denominator', 1) if occupation_list else 1
else:
# 不需要扩胞,直接添加化合价
output_path = os.path.join(output_dir, f"{base_name}.cif")
self._add_oxidation_states(structure)
structure.to(filename=output_path)
result.output_files = [output_path]
result.success = True
except Exception as e:
result.success = False
result.error_message = str(e)
return result
def _process_cif_file(self, structure: Structure) -> List[Dict]:
"""
统计结构中各原子的occupation情况
"""
occupation_dict = defaultdict(list)
split_dict = {}
for i, site in enumerate(structure):
occu = self._get_occu(site.species_string)
if occu != 1.0:
if site.species.chemical_system not in self.explict_element:
occupation_dict[occu].append(i + 1)
# 提取元素名称列表
elements = []
if ':' in site.species_string:
parts = site.species_string.split(',')
for part in parts:
element_with_valence = part.strip().split(':')[0].strip()
element_match = re.match(r'([A-Z][a-z]?)', element_with_valence)
if element_match:
elements.append(element_match.group(1))
else:
element_match = re.match(r'([A-Z][a-z]?)', site.species_string)
if element_match:
elements = [element_match.group(1)]
split_dict[occu] = elements
# 转换为列表格式
occupation_list = [
{
"occupation": occu,
"atom_serial": serials,
"numerator": None,
"denominator": None,
"split": split_dict.get(occu, [])
}
for occu, serials in occupation_dict.items()
]
return occupation_list
def _get_occu(self, s_str: str) -> float:
"""从物种字符串获取占据率"""
if not s_str.strip():
return 1.0
pattern = r'([A-Za-z0-9+-]+):([0-9.]+)'
matches = re.findall(pattern, s_str)
for species, occu in matches:
if species not in self.explict_element:
try:
return float(occu)
except ValueError:
continue
return 1.0
def _calculate_expansion_factor(self, occupation_list: List[Dict]) -> Tuple[int, List[Dict]]:
"""计算扩胞因子"""
if not occupation_list:
return 1, []
precision_limits = {
'high': None,
'normal': 100,
'low': 10,
'very_low': 5
}
limit = precision_limits.get(self.calculate_type)
for entry in occupation_list:
occu = entry["occupation"]
if limit:
fraction = Fraction(occu).limit_denominator(limit)
else:
fraction = Fraction(occu).limit_denominator()
entry["numerator"] = fraction.numerator
entry["denominator"] = fraction.denominator
# 计算最小公倍数
denominators = [entry["denominator"] for entry in occupation_list]
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
# 统一分母
for entry in occupation_list:
denominator = entry["denominator"]
entry["numerator"] = entry["numerator"] * (lcm // denominator)
entry["denominator"] = lcm
return lcm, occupation_list
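# Worked example (illustrative): occupancies 0.75 and 0.5 become the fractions 3/4 and 1/2;
# the LCM of the denominators is 4, so the cell is replicated 4 times and the numerators are
# rescaled to 3/4 and 2/4 -- i.e. the site stays occupied in 3 (resp. 2) of the 4 copies.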
def _expand_and_save(
self,
structure: Structure,
occupation_list: List[Dict],
base_name: str,
output_dir: str
) -> List[str]:
"""扩胞并保存"""
lcm, oc_list = self._calculate_expansion_factor(occupation_list)
if lcm > self.max_expansion_factor:
raise ValueError(f"扩胞因子 {lcm} 超过最大限制 {self.max_expansion_factor}")
# 获取扩胞策略
strategies = self._strategy_divide(structure, lcm)
if not strategies:
raise ValueError("无法找到合适的扩胞策略")
# 生成结构列表
st_list = self._generate_structure_list(structure, oc_list)
output_files = []
keep_number = min(self.keep_number, len(strategies))
for index in range(keep_number):
merged = self._merge_structures(st_list, strategies[index])
# 添加化合价
self._add_oxidation_states(merged)
# 当只保存1个时不加后缀
if keep_number == 1:
output_filename = f"{base_name}.cif"
else:
suffix = "x{}y{}z{}".format(
strategies[index]["x"],
strategies[index]["y"],
strategies[index]["z"]
)
output_filename = f"{base_name}-{suffix}.cif"
output_path = os.path.join(output_dir, output_filename)
merged.to(filename=output_path, fmt="cif")
output_files.append(output_path)
return output_files
def _add_oxidation_states(self, structure: Structure):
"""添加化合价"""
# 检查是否已有化合价
has_oxidation = all(
all(isinstance(sp, Specie) for sp in site.species.keys())
for site in structure.sites
)
if not has_oxidation and self.valences:
structure.add_oxidation_state_by_element(self.valences)
def _strategy_divide(self, structure: Structure, total: int) -> List[Dict]:
"""根据晶体类型确定扩胞策略"""
try:
space_group_info = structure.get_space_group_info()
space_group_symbol = space_group_info[0]
# 获取空间群类型
all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)]
symbol = all_spacegroup_symbols[0]
for symbol_i in all_spacegroup_symbols:
if space_group_symbol == symbol_i.international_short:
symbol = symbol_i
break
space_type = self._typejudge(symbol.number)
if space_type == "Cubic":
return self._factorize_to_three_factors(total, "xyz")
else:
return self._factorize_to_three_factors(total)
except Exception:
return self._factorize_to_three_factors(total)
def _typejudge(self, number: int) -> str:
"""判断晶体类型"""
if number in [1, 2]:
return "Triclinic"
elif 3 <= number <= 15:
return "Monoclinic"
elif 16 <= number <= 74:
return "Orthorhombic"
elif 75 <= number <= 142:
return "Tetragonal"
elif 143 <= number <= 167:
return "Trigonal"
elif 168 <= number <= 194:
return "Hexagonal"
elif 195 <= number <= 230:
return "Cubic"
else:
return "Unknown"
def _factorize_to_three_factors(self, n: int, type_sym: str = None) -> List[Dict]:
"""分解为三个因子"""
factors = []
if type_sym == "xyz":
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
for y in range(1, remaining_n + 1):
if remaining_n % y == 0 and y <= x:
z = remaining_n // y
if z <= y:
factors.append({'x': x, 'y': y, 'z': z})
else:
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
for y in range(1, remaining_n + 1):
if remaining_n % y == 0:
z = remaining_n // y
factors.append({'x': x, 'y': y, 'z': z})
# 排序
def sort_key(item):
return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x'])
return sorted(factors, key=sort_key)
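# Example (illustrative): n = 4 yields factorizations such as 2x2x1, 2x1x2, 1x2x2, 1x1x4, ...;
# sorting by x+y+z (then z, y, x) ranks the more isotropic splits (2x2x1) ahead of elongated
# ones (1x1x4), so _expand_and_save keeps the most compact supercells when keep_number is small.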
def _generate_structure_list(
self,
base_structure: Structure,
occupation_list: List[Dict]
) -> List[Structure]:
"""生成结构列表"""
if not occupation_list:
return [base_structure.copy()]
lcm = occupation_list[0]["denominator"]
structure_list = [base_structure.copy() for _ in range(lcm)]
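# Each of the lcm copies gets a definite (occupied / substituted / vacant) version of every
# partially occupied site; _mark_atoms_randomly chooses in which numerator-out-of-denominator
# copies the site is kept, so averaging over the copies reproduces the (rationalized) occupancy.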
for entry in occupation_list:
numerator = entry["numerator"]
denominator = entry["denominator"]
atom_indices = entry["atom_serial"]
for atom_idx in atom_indices:
occupancy_dict = self._mark_atoms_randomly(numerator, denominator)
original_site = base_structure.sites[atom_idx - 1]
element = self._get_first_non_explicit_element(original_site.species_string)
for copy_idx, occupy in occupancy_dict.items():
structure_list[copy_idx].remove_sites([atom_idx - 1])
oxi_state = self._extract_oxi_state(original_site.species_string, element)
if len(entry["split"]) == 1:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
species_dict = {Species(self.target_cation, 1.0): 0.0}
new_site = PeriodicSite(
species=species_dict,
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
else:
new_site = PeriodicSite(
species=Species(entry['split'][1], oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
return structure_list
def _mark_atoms_randomly(self, numerator: int, denominator: int) -> Dict[int, int]:
"""随机标记原子"""
if numerator > denominator:
raise ValueError(f"numerator ({numerator}) 不能超过 denominator ({denominator})")
atom_dice = list(range(denominator))
selected_atoms = random.sample(atom_dice, numerator)
return {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice}
def _get_first_non_explicit_element(self, species_str: str) -> str:
"""获取第一个非目标元素"""
if not species_str.strip():
return ""
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
element_with_charge = part.split(":")[0].strip()
pure_element = ''.join([c for c in element_with_charge if c.isalpha()])
if pure_element not in self.explict_element:
return pure_element
return ""
def _extract_oxi_state(self, species_str: str, element: str) -> int:
"""提取氧化态"""
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
element_with_charge = part.split(":")[0].strip()
if element in element_with_charge:
charge_part = element_with_charge[len(element):]
if not any(c.isdigit() for c in charge_part):
if "+" in charge_part:
return 1
elif "-" in charge_part:
return -1
else:
return 0
sign = 1
if "-" in charge_part:
sign = -1
digits = ""
for c in charge_part:
if c.isdigit():
digits += c
if digits:
return sign * int(digits)
return 0
def _merge_structures(self, structure_list: List[Structure], merge_dict: Dict) -> Structure:
"""合并结构"""
if not structure_list:
raise ValueError("结构列表不能为空")
ref_lattice = structure_list[0].lattice
total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1)
if len(structure_list) != total_merge:
raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配")
a, b, c = ref_lattice.abc
alpha, beta, gamma = ref_lattice.angles
new_a = a * merge_dict.get("x", 1)
new_b = b * merge_dict.get("y", 1)
new_c = c * merge_dict.get("z", 1)
new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma)
all_sites = []
for i, structure in enumerate(structure_list):
x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1)
y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1)
z_offset = i % merge_dict.get("z", 1)
for site in structure:
coords = site.frac_coords.copy()
coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1)
coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1)
coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1)
all_sites.append({"species": site.species, "coords": coords})
return Structure(
new_lattice,
[site["species"] for site in all_sites],
[site["coords"] for site in all_sites]
)
def process_batch(
input_files: List[str],
output_dir: str,
needs_expansion_flags: List[bool] = None,
valence_yaml_path: str = None,
calculate_type: str = 'low',
target_cation: str = "Li",
show_progress: bool = True
) -> List[ProcessingResult]:
"""
批量处理CIF文件
Args:
input_files: 输入文件列表
output_dir: 输出目录
needs_expansion_flags: 是否需要扩胞的标记列表
valence_yaml_path: 化合价配置文件路径
calculate_type: 扩胞计算精度
target_cation: 目标阳离子
show_progress: 是否显示进度
Returns:
处理结果列表
"""
os.makedirs(output_dir, exist_ok=True)
processor = StructureProcessor(
valence_yaml_path=valence_yaml_path,
calculate_type=calculate_type,
target_cation=target_cation
)
if needs_expansion_flags is None:
needs_expansion_flags = [False] * len(input_files)
results = []
total = len(input_files)
for i, (input_file, needs_exp) in enumerate(zip(input_files, needs_expansion_flags)):
if show_progress:
print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
result = processor.process_file(input_file, output_dir, needs_exp)
results.append(result)
if show_progress:
print()
return results
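
A usage sketch of process_batch (illustrative; the import path and the input CIF paths are hypothetical):

# usage sketch -- illustrative only
from src.preprocessing.processor import process_batch  # module path assumed; not shown in this diff

if __name__ == "__main__":
    cifs = ["raw/example_ordered.cif", "raw/example_disordered.cif"]  # hypothetical inputs
    results = process_batch(
        input_files=cifs,
        output_dir="processed",
        needs_expansion_flags=[False, True],   # second structure has partial occupancies
        calculate_type="low",
        target_cation="Li",
    )
    for r in results:
        print(r.input_file, r.success, r.expansion_factor, r.error_message)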

0
src/utils/__init__.py Normal file
View File

0
src/utils/io.py Normal file
View File

0
src/utils/logger.py Normal file
View File

0
src/utils/structure.py Normal file
View File

5
tool/Li/Br+O.yaml Normal file
View File

@@ -0,0 +1,5 @@
SPECIE: Li+
ANION: Br
PERCO_R: 0.45
NEIGHBOR: 1.8
LONG: 2.2

5
tool/Li/Cl+O.yaml Normal file
View File

@@ -0,0 +1,5 @@
SPECIE: Li+
ANION: Cl
PERCO_R: 0.45
NEIGHBOR: 1.8
LONG: 2.2