一阶段高筛制作完成

This commit is contained in:
koko
2025-12-16 11:36:49 +08:00
parent 6ea96c81d6
commit f78298e803
22 changed files with 3276 additions and 223 deletions

627
main.py
View File

@@ -1,45 +1,188 @@
"""
高通量筛选与扩胞项目 - 主入口(支持并行
高通量筛选与扩胞项目 - 主入口(支持断点续做
"""
import os
import sys
import json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
from src.core.scheduler import ParallelScheduler
from src.core.executor import TaskExecutor
from src.preprocessing.processor import StructureProcessor
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
from src.computation.result_processor import ResultProcessor, FilterCriteria
def print_banner():
print("""
╔═══════════════════════════════════════════════════════════════════╗
║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.0
支持高性能并行计算
║ 高通量筛选与扩胞项目 - 数据库分析工具 v2.2
支持断点续做与高性能并行计算
╚═══════════════════════════════════════════════════════════════════╝
""")
def detect_and_show_environment():
"""检测并显示环境信息"""
env = ParallelScheduler.detect_environment()
env = TaskExecutor.detect_environment()
print("【运行环境检测】")
print(f" 主机名: {env['hostname']}")
print(f" 本地CPU核数: {env['total_cores']}")
print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}")
if env['has_slurm'] and env['slurm_partitions']:
print(f" 可用分区:")
for p in env['slurm_partitions']:
print(f" - {p['name']}: {p['nodes']}节点, {p['cores_per_node']}核/节点")
print(f" 可用分区: {', '.join(env['slurm_partitions'])}")
if env['conda_env']:
print(f" 当前Conda: {env['conda_env']}")
return env
def get_user_input(env: dict):
"""获取用户输入"""
def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict:
"""
检测工作流程状态,确定可以从哪一步继续
Returns:
dict: {
'has_processed_data': bool, # 是否有扩胞处理后的数据
'has_zeo_results': bool, # 是否有 Zeo++ 计算结果
'total_structures': int, # 总结构数
'structures_with_log': int, # 有 log.txt 的结构数
'workspace_info': object, # 工作区信息
'ws_manager': object # 工作区管理器
}
"""
status = {
'has_processed_data': False,
'has_zeo_results': False,
'total_structures': 0,
'structures_with_log': 0,
'workspace_info': None,
'ws_manager': None
}
# 创建工作区管理器
ws_manager = WorkspaceManager(
workspace_path=workspace_path,
tool_dir="tool",
target_cation=target_cation
)
status['ws_manager'] = ws_manager
# 检查是否有处理后的数据
data_dir = os.path.join(workspace_path, "data")
if os.path.exists(data_dir):
existing = ws_manager.check_existing_workspace()
if existing and existing.total_structures > 0:
status['has_processed_data'] = True
status['total_structures'] = existing.total_structures
status['workspace_info'] = existing
# 检查有多少结构有 log.txt即已完成 Zeo++ 计算)
log_count = 0
for anion_key in os.listdir(data_dir):
anion_dir = os.path.join(data_dir, anion_key)
if not os.path.isdir(anion_dir):
continue
for struct_name in os.listdir(anion_dir):
struct_dir = os.path.join(anion_dir, struct_name)
if os.path.isdir(struct_dir):
log_path = os.path.join(struct_dir, "log.txt")
if os.path.exists(log_path):
log_count += 1
status['structures_with_log'] = log_count
# 如果大部分结构都有 log.txt认为 Zeo++ 计算已完成
if log_count > 0 and log_count >= status['total_structures'] * 0.5:
status['has_zeo_results'] = True
return status
def print_workflow_status(status: dict):
"""打印工作流程状态"""
print("\n" + "" * 50)
print("【工作流程状态检测】")
print("" * 50)
if not status['has_processed_data']:
print(" Step 1 (扩胞+化合价): ❌ 未完成")
print(" Step 2-4 (Zeo++ 计算): ❌ 未完成")
print(" Step 5 (结果筛选): ❌ 未完成")
else:
print(f" Step 1 (扩胞+化合价): ✅ 已完成 ({status['total_structures']} 个结构)")
if status['has_zeo_results']:
print(f" Step 2-4 (Zeo++ 计算): ✅ 已完成 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
print(" Step 5 (结果筛选): ⏳ 可执行")
else:
print(f" Step 2-4 (Zeo++ 计算): ⏳ 可执行 ({status['structures_with_log']}/{status['total_structures']} 有日志)")
print(" Step 5 (结果筛选): ❌ 需先完成 Zeo++ 计算")
print("" * 50)
def get_user_choice(status: dict) -> str:
"""
根据工作流程状态获取用户选择
Returns:
'step1': 从头开始(数据库分析 + 扩胞)
'step2': 从 Zeo++ 计算开始
'step5': 直接进行结果筛选
'exit': 退出
"""
print("\n请选择操作:")
options = []
if status['has_zeo_results']:
options.append(('5', '直接进行结果筛选 (Step 5)'))
options.append(('2', '重新运行 Zeo++ 计算 (Step 2-4)'))
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
elif status['has_processed_data']:
options.append(('2', '运行 Zeo++ 计算 (Step 2-4)'))
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
else:
options.append(('1', '从头开始 (数据库分析 + 扩胞)'))
options.append(('q', '退出'))
for key, desc in options:
print(f" {key}. {desc}")
choice = input("\n请选择 [默认: " + options[0][0] + "]: ").strip().lower()
if not choice:
choice = options[0][0]
if choice == 'q':
return 'exit'
elif choice == '5':
return 'step5'
elif choice == '2':
return 'step2'
else:
return 'step1'
def run_step1_database_analysis(env: dict, cation: str) -> dict:
"""
Step 1: 数据库分析与扩胞处理
Returns:
处理参数字典,如果用户取消则返回 None
"""
print("\n" + "" * 60)
print("【Step 1: 数据库分析与扩胞处理】")
print("" * 60)
# 数据库路径
while True:
db_path = input("\n📂 请输入数据库路径: ").strip()
@@ -48,9 +191,8 @@ def get_user_input(env: dict):
print(f"❌ 路径不存在: {db_path}")
# 检测当前Conda环境路径
conda_env_path = env.get('conda_prefix', '')
if not conda_env_path:
conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
default_conda = "/cluster/home/koko125/anaconda3/envs/screen"
conda_env_path = env.get('conda_env', '') or default_conda
print(f"\n检测到Conda环境: {conda_env_path}")
custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
@@ -59,9 +201,6 @@ def get_user_input(env: dict):
elif custom_env and custom_env.lower() != 'y':
conda_env_path = custom_env
# 目标阳离子
cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
# 目标阴离子
anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ").strip()
anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
@@ -79,43 +218,18 @@ def get_user_input(env: dict):
print("【并行计算配置】")
default_cores = min(env['total_cores'], 32)
cores_input = input(f"💻 最大可用核数 [默认: {default_cores}]: ").strip()
max_cores = int(cores_input) if cores_input.isdigit() else default_cores
cores_input = input(f"💻 最大可用核数/Worker数 [默认: {default_cores}]: ").strip()
max_workers = int(cores_input) if cores_input.isdigit() else default_cores
print("\n任务复杂度 (影响每个Worker分配的核数):")
print(" 1. 低 (1核/Worker) - 简单IO操作")
print(" 2. 中 (2核/Worker) - 结构解析+检查 [默认]")
print(" 3. 高 (4核/Worker) - 复杂计算")
complexity_choice = input("请选择 [1/2/3]: ").strip()
complexity = {'1': 'low', '2': 'medium', '3': 'high', '': 'medium'}.get(complexity_choice, 'medium')
# 执行模式
use_slurm = False
if env['has_slurm']:
slurm_choice = input("\n是否使用SLURM提交作业? [y/N]: ").strip().lower()
use_slurm = slurm_choice == 'y'
return {
params = {
'database_path': db_path,
'target_cation': cation,
'target_anions': anions,
'anion_mode': anion_mode,
'max_cores': max_cores,
'task_complexity': complexity,
'use_slurm': use_slurm
'max_workers': max_workers,
'conda_env': conda_env_path,
}
def main():
"""主函数"""
print_banner()
# 环境检测
env = detect_and_show_environment()
# 获取用户输入
params = get_user_input(env)
print("\n" + "" * 60)
print("开始数据库分析...")
print("" * 60)
@@ -126,43 +240,404 @@ def main():
target_cation=params['target_cation'],
target_anions=params['target_anions'],
anion_mode=params['anion_mode'],
max_cores=params['max_cores'],
task_complexity=params['task_complexity']
max_cores=params['max_workers'],
task_complexity='medium'
)
print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件")
if params['use_slurm']:
# SLURM模式
output_dir = input("输出目录 [默认: ./slurm_output]: ").strip() or "./slurm_output"
job_id = analyzer.analyze_slurm(output_dir=output_dir)
print(f"\n✅ SLURM作业已提交: {job_id}")
print(f" 使用 'squeue -j {job_id}' 查看状态")
print(f" 结果将保存到: {output_dir}")
else:
# 本地模式
report = analyzer.analyze(show_progress=True)
# 执行分析
report = analyzer.analyze(show_progress=True)
# 打印报告
ReportGenerator.print_report(report, detailed=True)
# 打印报告
ReportGenerator.print_report(report, detailed=True)
# 保存选项
save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
if save_choice == 'y':
output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
output_path = output_path or "analysis_report.json"
report.save(output_path)
print(f"✅ 报告已保存到: {output_path}")
# 保存选项
save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
if save_choice == 'y':
output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
output_path = output_path or "analysis_report.json"
report.save(output_path)
print(f"✅ 报告已保存到: {output_path}")
# CSV导出
csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
if csv_choice == 'y':
csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
csv_path = csv_path or "analysis_details.csv"
ReportGenerator.export_to_csv(report, csv_path)
# CSV导出
csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
if csv_choice == 'y':
csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
csv_path = csv_path or "analysis_details.csv"
ReportGenerator.export_to_csv(report, csv_path)
print("\n✅ 分析完成!")
# 生成最终数据库
process_choice = input("\n是否生成最终可用的数据库(扩胞+添加化合价)? [Y/n]: ").strip().lower()
if process_choice == 'n':
print("\n已跳过扩胞处理,可稍后继续")
return params
# 输出目录设置
print("\n输出目录设置:")
flat_dir = input(" 原始格式输出目录 [默认: workspace/processed]: ").strip()
flat_dir = flat_dir or "workspace/processed"
analysis_dir = input(" 分析格式输出目录 [默认: workspace/data]: ").strip()
analysis_dir = analysis_dir or "workspace/data"
# 扩胞保存数量
keep_input = input("\n扩胞结构保存数量 [默认: 1]: ").strip()
keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1
# 扩胞精度选择
print("\n扩胞计算精度:")
print(" 1. 高精度 (精确分数)")
print(" 2. 普通精度 (分母≤100)")
print(" 3. 低精度 (分母≤10) [默认]")
print(" 4. 极低精度 (分母≤5)")
precision_choice = input("请选择 [1/2/3/4]: ").strip()
calculate_type = {
'1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low'
}.get(precision_choice, 'low')
# 获取可处理文件
processable = report.get_processable_files(include_needs_expansion=True)
if not processable:
print("⚠️ 没有可处理的文件")
return params
print(f"\n发现 {len(processable)} 个可处理的文件")
# 准备文件列表和扩胞标记
input_files = [info.file_path for info in processable]
needs_expansion_flags = [info.needs_expansion for info in processable]
anion_types_list = [info.anion_types for info in processable]
direct_count = sum(1 for f in needs_expansion_flags if not f)
expansion_count = sum(1 for f in needs_expansion_flags if f)
print(f" - 可直接处理: {direct_count}")
print(f" - 需要扩胞: {expansion_count}")
print(f" - 扩胞保存数: {keep_number}")
confirm = input("\n确认开始处理? [Y/n]: ").strip().lower()
if confirm == 'n':
print("\n已取消处理")
return params
print("\n" + "" * 60)
print("开始处理结构文件...")
print("" * 60)
# 创建处理器
processor = StructureProcessor(
calculate_type=calculate_type,
keep_number=keep_number,
target_cation=params['target_cation']
)
# 创建输出目录
os.makedirs(flat_dir, exist_ok=True)
os.makedirs(analysis_dir, exist_ok=True)
results = []
total = len(input_files)
import shutil
for i, (input_file, needs_exp, anion_types) in enumerate(
zip(input_files, needs_expansion_flags, anion_types_list)
):
print(f"\r处理进度: {i+1}/{total} - {os.path.basename(input_file)}", end="")
# 处理文件到原始格式目录
result = processor.process_file(input_file, flat_dir, needs_exp)
results.append(result)
if result.success:
# 同时保存到分析格式目录
# 按阴离子类型创建子目录
anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other'
anion_dir = os.path.join(analysis_dir, anion_key)
os.makedirs(anion_dir, exist_ok=True)
# 获取基础文件名
base_name = os.path.splitext(os.path.basename(input_file))[0]
# 创建以文件名命名的子目录
file_dir = os.path.join(anion_dir, base_name)
os.makedirs(file_dir, exist_ok=True)
# 复制生成的文件到分析格式目录
for output_file in result.output_files:
dst_path = os.path.join(file_dir, os.path.basename(output_file))
shutil.copy2(output_file, dst_path)
print()
# 统计结果
success_count = sum(1 for r in results if r.success)
fail_count = sum(1 for r in results if not r.success)
total_output = sum(len(r.output_files) for r in results if r.success)
print("\n" + "-" * 60)
print("【处理结果统计】")
print("-" * 60)
print(f" 成功处理: {success_count}")
print(f" 处理失败: {fail_count}")
print(f" 生成文件: {total_output}")
print(f"\n 原始格式目录: {flat_dir}")
print(f" 分析格式目录: {analysis_dir}")
print(f" └── 结构: data/阴离子类型/文件名/文件名.cif")
# 显示失败的文件
if fail_count > 0:
print("\n失败的文件:")
for r in results:
if not r.success:
print(f" - {os.path.basename(r.input_file)}: {r.error_message}")
print("\n✅ Step 1 完成!")
return params
def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info = None):
"""
Step 2-4: Zeo++ Voronoi 分析
Args:
params: 参数字典,包含 target_cation 等
ws_manager: 工作区管理器(可选,如果已有则直接使用)
workspace_info: 工作区信息(可选,如果已有则直接使用)
"""
print("\n" + "" * 60)
print("【Step 2-4: Zeo++ Voronoi 分析】")
print("" * 60)
# 如果没有传入 ws_manager则创建新的
if ws_manager is None:
# 工作区路径
workspace_path = input("\n工作区路径 [默认: workspace]: ").strip() or "workspace"
tool_dir = input("工具目录路径 [默认: tool]: ").strip() or "tool"
# 创建工作区管理器
ws_manager = WorkspaceManager(
workspace_path=workspace_path,
tool_dir=tool_dir,
target_cation=params.get('target_cation', 'Li')
)
# 如果没有传入 workspace_info则检查现有工作区
if workspace_info is None:
existing = ws_manager.check_existing_workspace()
if existing and existing.total_structures > 0:
ws_manager.print_workspace_summary(existing)
workspace_info = existing
else:
print("⚠️ 工作区数据目录不存在或为空")
print(f" 请先运行 Step 1 生成处理后的数据到: {ws_manager.data_dir}")
return
# 检查是否需要创建软链接
if workspace_info.linked_structures < workspace_info.total_structures:
print("\n正在创建配置文件软链接...")
workspace_info = ws_manager.setup_workspace(force_relink=False)
# 获取计算任务
tasks = ws_manager.get_computation_tasks(workspace_info)
if not tasks:
print("⚠️ 没有找到可计算的任务")
return
print(f"\n发现 {len(tasks)} 个计算任务")
# Zeo++ 环境配置
print("\n" + "-" * 50)
print("【Zeo++ 计算配置】")
default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo"
zeo_env = input(f"Zeo++ Conda环境 [默认: {default_zeo_env}]: ").strip()
zeo_env = zeo_env or default_zeo_env
# SLURM 配置
partition = input("SLURM分区 [默认: cpu]: ").strip() or "cpu"
max_concurrent_input = input("最大并发任务数 [默认: 50]: ").strip()
max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50
time_limit = input("单任务时间限制 [默认: 2:00:00]: ").strip() or "2:00:00"
# 创建配置
zeo_config = ZeoConfig(
conda_env=zeo_env,
partition=partition,
max_concurrent=max_concurrent,
time_limit=time_limit
)
# 确认执行
print("\n" + "-" * 50)
print("【计算任务确认】")
print(f" 总任务数: {len(tasks)}")
print(f" Conda环境: {zeo_config.conda_env}")
print(f" SLURM分区: {zeo_config.partition}")
print(f" 最大并发: {zeo_config.max_concurrent}")
print(f" 时间限制: {zeo_config.time_limit}")
confirm = input("\n确认提交计算任务? [Y/n]: ").strip().lower()
if confirm == 'n':
print("已取消")
return
# 创建执行器并运行
executor = ZeoExecutor(zeo_config)
log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs")
results = executor.run_batch(tasks, output_dir=log_dir)
# 打印结果摘要
executor.print_results_summary(results)
# 保存结果
results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json")
results_data = [
{
'task_id': r.task_id,
'structure_name': r.structure_name,
'cif_path': r.cif_path,
'success': r.success,
'output_files': r.output_files,
'error_message': r.error_message
}
for r in results
]
with open(results_file, 'w') as f:
json.dump(results_data, f, indent=2)
print(f"\n结果已保存到: {results_file}")
print("\n✅ Step 2-4 完成!")
def run_step5_result_processing(workspace_path: str = "workspace"):
"""
Step 5: 结果处理与筛选
Args:
workspace_path: 工作区路径
"""
print("\n" + "" * 60)
print("【Step 5: 结果处理与筛选】")
print("" * 60)
# 创建结果处理器
processor = ResultProcessor(workspace_path=workspace_path)
# 获取筛选条件
print("\n设置筛选条件:")
print(" (直接回车使用默认值,输入 0 表示不限制)")
# 最小渗透直径(默认改为 1.0
perc_input = input(" 最小渗透直径 (Å) [默认: 1.0]: ").strip()
min_percolation = float(perc_input) if perc_input else 1.0
# 最小 d 值
d_input = input(" 最小 d 值 [默认: 2.0]: ").strip()
min_d = float(d_input) if d_input else 2.0
# 最大节点长度
node_input = input(" 最大节点长度 (Å) [默认: 不限制]: ").strip()
max_node = float(node_input) if node_input else float('inf')
# 创建筛选条件
criteria = FilterCriteria(
min_percolation_diameter=min_percolation,
min_d_value=min_d,
max_node_length=max_node
)
# 执行处理
results, stats = processor.process_and_filter(
criteria=criteria,
save_csv=True,
copy_passed=True
)
# 打印摘要
processor.print_summary(results, stats)
# 显示输出位置
print("\n输出文件位置:")
print(f" 汇总 CSV: {workspace_path}/results/summary.csv")
print(f" 分类 CSV: {workspace_path}/results/阴离子类型/阴离子类型.csv")
print(f" 通过筛选: {workspace_path}/passed/阴离子类型/结构名/")
print("\n✅ Step 5 完成!")
def main():
"""主函数"""
print_banner()
# 环境检测
env = detect_and_show_environment()
# 询问目标阳离子
cation = input("\n🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
# 检测工作流程状态
status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
print_workflow_status(status)
# 获取用户选择
choice = get_user_choice(status)
if choice == 'exit':
print("\n👋 再见!")
return
# 根据选择执行相应步骤
params = {'target_cation': cation}
if choice == 'step1':
# 从头开始
params = run_step1_database_analysis(env, cation)
if params is None:
return
# 询问是否继续
continue_choice = input("\n是否继续进行 Zeo++ 计算? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续 Zeo++ 计算")
return
# 重新检测状态
status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
# 询问是否继续筛选
continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续结果筛选")
return
run_step5_result_processing("workspace")
elif choice == 'step2':
# 从 Zeo++ 计算开始
run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
# 询问是否继续筛选
continue_choice = input("\n是否继续进行结果筛选? [Y/n]: ").strip().lower()
if continue_choice == 'n':
print("\n可稍后运行程序继续结果筛选")
return
run_step5_result_processing("workspace")
elif choice == 'step5':
# 直接进行结果筛选
run_step5_result_processing("workspace")
print("\n✅ 全部完成!")
if __name__ == "__main__":
main()
main()