screen/main.py
"""
高通量筛选与扩胞项目 - 主入口(支持断点续做)
"""
import os
import sys
import json
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.analysis.database_analyzer import DatabaseAnalyzer
from src.analysis.report_generator import ReportGenerator
from src.core.executor import TaskExecutor
from src.preprocessing.processor import StructureProcessor
from src.computation.workspace_manager import WorkspaceManager
from src.computation.zeo_executor import ZeoExecutor, ZeoConfig
from src.computation.result_processor import ResultProcessor, FilterCriteria
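# Note: these imports assume main.py sits at the project root next to a local
# `src` package with analysis, core, preprocessing and computation subpackages,
# so the script is intended to be launched from that root (e.g. `python main.py`).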
def print_banner():
    print("""
╔═══════════════════════════════════════════════════════════════════╗
║     High-Throughput Screening & Supercell Expansion Tool v2.2      ║
║       Supports checkpoint resumption and parallel computation      ║
╚═══════════════════════════════════════════════════════════════════╝
""")
def detect_and_show_environment():
    """Detect and display environment information."""
    env = TaskExecutor.detect_environment()
    print("[Runtime environment]")
    print(f"  Hostname: {env['hostname']}")
    print(f"  Local CPU cores: {env['total_cores']}")
    print(f"  SLURM cluster: {'✅ available' if env['has_slurm'] else '❌ not available'}")
    if env['has_slurm'] and env['slurm_partitions']:
        print(f"  Available partitions: {', '.join(env['slurm_partitions'])}")
    if env['conda_env']:
        print(f"  Current Conda env: {env['conda_env']}")
    return env
def detect_workflow_status(workspace_path: str = "workspace", target_cation: str = "Li") -> dict:
    """
    Detect the workflow status to determine which step can be resumed.

    Returns:
        dict: {
            'has_processed_data': bool,   # supercell-processed data exists
            'has_zeo_results': bool,      # Zeo++ results exist
            'total_structures': int,      # total number of structures
            'structures_with_log': int,   # structures that have a log.txt
            'workspace_info': object,     # workspace information
            'ws_manager': object          # workspace manager
        }
    """
    status = {
        'has_processed_data': False,
        'has_zeo_results': False,
        'total_structures': 0,
        'structures_with_log': 0,
        'workspace_info': None,
        'ws_manager': None
    }
    # Create the workspace manager
    ws_manager = WorkspaceManager(
        workspace_path=workspace_path,
        tool_dir="tool",
        target_cation=target_cation
    )
    status['ws_manager'] = ws_manager
    # Check whether processed data exists
    data_dir = os.path.join(workspace_path, "data")
    if os.path.exists(data_dir):
        existing = ws_manager.check_existing_workspace()
        if existing and existing.total_structures > 0:
            status['has_processed_data'] = True
            status['total_structures'] = existing.total_structures
            status['workspace_info'] = existing
            # Count how many structures have a log.txt (i.e. finished Zeo++ runs)
            log_count = 0
            for anion_key in os.listdir(data_dir):
                anion_dir = os.path.join(data_dir, anion_key)
                if not os.path.isdir(anion_dir):
                    continue
                for struct_name in os.listdir(anion_dir):
                    struct_dir = os.path.join(anion_dir, struct_name)
                    if os.path.isdir(struct_dir):
                        log_path = os.path.join(struct_dir, "log.txt")
                        if os.path.exists(log_path):
                            log_count += 1
            status['structures_with_log'] = log_count
            # If most structures have a log.txt, treat the Zeo++ stage as complete
            if log_count > 0 and log_count >= status['total_structures'] * 0.5:
                status['has_zeo_results'] = True
    return status
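# Resumption note: workflow state is inferred entirely from the filesystem
# (processed structures under workspace/data and per-structure log.txt files),
# so the script can pick up after an interruption without a separate state file.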
def print_workflow_status(status: dict):
    """Print the workflow status."""
    print("\n" + "=" * 50)
    print("[Workflow status]")
    print("=" * 50)
    if not status['has_processed_data']:
        print("  Step 1 (supercell + valences): ❌ not done")
        print("  Step 2-4 (Zeo++ computation):  ❌ not done")
        print("  Step 5 (result filtering):     ❌ not done")
    else:
        print(f"  Step 1 (supercell + valences): ✅ done ({status['total_structures']} structures)")
        if status['has_zeo_results']:
            print(f"  Step 2-4 (Zeo++ computation):  ✅ done ({status['structures_with_log']}/{status['total_structures']} with logs)")
            print("  Step 5 (result filtering):     ⏳ ready to run")
        else:
            print(f"  Step 2-4 (Zeo++ computation):  ⏳ ready to run ({status['structures_with_log']}/{status['total_structures']} with logs)")
            print("  Step 5 (result filtering):     ❌ requires Zeo++ results first")
    print("=" * 50)
def get_user_choice(status: dict) -> str:
    """
    Ask the user which step to run, based on the workflow status.

    Returns:
        'step1': start from scratch (database analysis + supercell expansion)
        'step2': start from the Zeo++ computation
        'step5': go straight to result filtering
        'exit': quit
    """
    print("\nPlease choose an action:")
    options = []
    if status['has_zeo_results']:
        options.append(('5', 'Go straight to result filtering (Step 5)'))
        options.append(('2', 'Re-run the Zeo++ computation (Step 2-4)'))
        options.append(('1', 'Start from scratch (database analysis + supercell expansion)'))
    elif status['has_processed_data']:
        options.append(('2', 'Run the Zeo++ computation (Step 2-4)'))
        options.append(('1', 'Start from scratch (database analysis + supercell expansion)'))
    else:
        options.append(('1', 'Start from scratch (database analysis + supercell expansion)'))
    options.append(('q', 'Quit'))
    for key, desc in options:
        print(f"  {key}. {desc}")
    choice = input("\nYour choice [default: " + options[0][0] + "]: ").strip().lower()
    if not choice:
        choice = options[0][0]
    if choice == 'q':
        return 'exit'
    elif choice == '5':
        return 'step5'
    elif choice == '2':
        return 'step2'
    else:
        return 'step1'
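# The token returned above drives the dispatch in main(); any unrecognized
# input falls back to 'step1' (start from scratch).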
def run_step1_database_analysis(env: dict, cation: str) -> dict:
    """
    Step 1: database analysis and supercell processing.

    Returns:
        dict of processing parameters, or None if the user cancels.
    """
    print("\n" + "=" * 60)
    print("[Step 1: database analysis and supercell processing]")
    print("=" * 60)
    # Database path
    while True:
        db_path = input("\n📂 Database path: ").strip()
        if os.path.exists(db_path):
            break
        print(f"❌ Path does not exist: {db_path}")
    # Detect the current Conda environment path
    default_conda = "/cluster/home/koko125/anaconda3/envs/screen"
    conda_env_path = env.get('conda_env', '') or default_conda
    print(f"\nDetected Conda environment: {conda_env_path}")
    custom_env = input("Use this environment? [Y/n] or enter another path: ").strip()
    if custom_env.lower() == 'n':
        conda_env_path = input("Full path to the Conda environment: ").strip()
    elif custom_env and custom_env.lower() != 'y':
        conda_env_path = custom_env
    # Target anions
    anion_input = input("🎯 Target anions (comma separated) [default: O,S,Cl,Br]: ").strip()
    anions = set(a.strip() for a in anion_input.split(',')) if anion_input else {'O', 'S', 'Cl', 'Br'}
    # Anion mode
    print("\nAnion mode:")
    print("  1. Single anions only")
    print("  2. Mixed anions only")
    print("  3. All (default)")
    mode_choice = input("Your choice [1/2/3]: ").strip()
    anion_mode = {'1': 'single', '2': 'mixed', '3': 'all', '': 'all'}.get(mode_choice, 'all')
    # Parallel configuration
    print("\n" + "=" * 50)
    print("[Parallel computation settings]")
    default_cores = min(env['total_cores'], 32)
    cores_input = input(f"💻 Maximum cores/workers [default: {default_cores}]: ").strip()
    max_workers = int(cores_input) if cores_input.isdigit() else default_cores
    params = {
        'database_path': db_path,
        'target_cation': cation,
        'target_anions': anions,
        'anion_mode': anion_mode,
        'max_workers': max_workers,
        'conda_env': conda_env_path,
    }
    print("\n" + "=" * 60)
    print("Starting database analysis...")
    print("=" * 60)
    # Create the analyzer
    analyzer = DatabaseAnalyzer(
        database_path=params['database_path'],
        target_cation=params['target_cation'],
        target_anions=params['target_anions'],
        anion_mode=params['anion_mode'],
        max_cores=params['max_workers'],
        task_complexity='medium'
    )
    print(f"\nFound {len(analyzer.cif_files)} CIF files")
    # Run the analysis
    report = analyzer.analyze(show_progress=True)
    # Print the report
    ReportGenerator.print_report(report, detailed=True)
    # Save options
    save_choice = input("\nSave the report? [y/N]: ").strip().lower()
    if save_choice == 'y':
        output_path = input("Report path [default: analysis_report.json]: ").strip()
        output_path = output_path or "analysis_report.json"
        report.save(output_path)
        print(f"✅ Report saved to: {output_path}")
    # CSV export
    csv_choice = input("Export a detailed CSV? [y/N]: ").strip().lower()
    if csv_choice == 'y':
        csv_path = input("CSV path [default: analysis_details.csv]: ").strip()
        csv_path = csv_path or "analysis_details.csv"
        ReportGenerator.export_to_csv(report, csv_path)
    # Generate the final database
    process_choice = input("\nGenerate the final usable database (supercell expansion + valences)? [Y/n]: ").strip().lower()
    if process_choice == 'n':
        print("\nSupercell processing skipped; you can resume it later")
        return params
    # Output directory settings
    print("\nOutput directories:")
    flat_dir = input("  Flat-format output directory [default: workspace/processed]: ").strip()
    flat_dir = flat_dir or "workspace/processed"
    analysis_dir = input("  Analysis-format output directory [default: workspace/data]: ").strip()
    analysis_dir = analysis_dir or "workspace/data"
    # Number of supercell structures to keep
    keep_input = input("\nNumber of supercell structures to keep [default: 1]: ").strip()
    keep_number = int(keep_input) if keep_input.isdigit() and int(keep_input) > 0 else 1
    # Supercell calculation precision
    print("\nSupercell calculation precision:")
    print("  1. High (exact fractions)")
    print("  2. Normal (denominator ≤ 100)")
    print("  3. Low (denominator ≤ 10) [default]")
    print("  4. Very low (denominator ≤ 5)")
    precision_choice = input("Your choice [1/2/3/4]: ").strip()
    calculate_type = {
        '1': 'high', '2': 'normal', '3': 'low', '4': 'very_low', '': 'low'
    }.get(precision_choice, 'low')
    # Collect the processable files
    processable = report.get_processable_files(include_needs_expansion=True)
    if not processable:
        print("⚠️ No processable files found")
        return params
    print(f"\nFound {len(processable)} processable files")
    # Prepare the file list and expansion flags
    input_files = [info.file_path for info in processable]
    needs_expansion_flags = [info.needs_expansion for info in processable]
    anion_types_list = [info.anion_types for info in processable]
    direct_count = sum(1 for f in needs_expansion_flags if not f)
    expansion_count = sum(1 for f in needs_expansion_flags if f)
    print(f"  - Directly processable: {direct_count}")
    print(f"  - Need supercell expansion: {expansion_count}")
    print(f"  - Supercells kept per structure: {keep_number}")
    confirm = input("\nStart processing? [Y/n]: ").strip().lower()
    if confirm == 'n':
        print("\nProcessing cancelled")
        return params
    print("\n" + "=" * 60)
    print("Processing structure files...")
    print("=" * 60)
    # Create the processor
    processor = StructureProcessor(
        calculate_type=calculate_type,
        keep_number=keep_number,
        target_cation=params['target_cation']
    )
    # Create the output directories
    os.makedirs(flat_dir, exist_ok=True)
    os.makedirs(analysis_dir, exist_ok=True)
    results = []
    total = len(input_files)
    import shutil
    for i, (input_file, needs_exp, anion_types) in enumerate(
        zip(input_files, needs_expansion_flags, anion_types_list)
    ):
        print(f"\rProgress: {i+1}/{total} - {os.path.basename(input_file)}", end="")
        # Process the file into the flat-format directory
        result = processor.process_file(input_file, flat_dir, needs_exp)
        results.append(result)
        if result.success:
            # Also save into the analysis-format directory,
            # grouped into subdirectories by anion type
            anion_key = '+'.join(sorted(anion_types)) if anion_types else 'other'
            anion_dir = os.path.join(analysis_dir, anion_key)
            os.makedirs(anion_dir, exist_ok=True)
            # Base file name
            base_name = os.path.splitext(os.path.basename(input_file))[0]
            # Create a subdirectory named after the file
            file_dir = os.path.join(anion_dir, base_name)
            os.makedirs(file_dir, exist_ok=True)
            # Copy the generated files into the analysis-format directory
            for output_file in result.output_files:
                dst_path = os.path.join(file_dir, os.path.basename(output_file))
                shutil.copy2(output_file, dst_path)
    print()
    # Summarize the results
    success_count = sum(1 for r in results if r.success)
    fail_count = sum(1 for r in results if not r.success)
    total_output = sum(len(r.output_files) for r in results if r.success)
    print("\n" + "-" * 60)
    print("[Processing summary]")
    print("-" * 60)
    print(f"  Processed successfully: {success_count}")
    print(f"  Failed: {fail_count}")
    print(f"  Files generated: {total_output}")
    print(f"\n  Flat-format directory: {flat_dir}")
    print(f"  Analysis-format directory: {analysis_dir}")
    print("  └── structures: data/<anion type>/<file name>/<file name>.cif")
    # Show the failed files
    if fail_count > 0:
        print("\nFailed files:")
        for r in results:
            if not r.success:
                print(f"  - {os.path.basename(r.input_file)}: {r.error_message}")
    print("\n✅ Step 1 complete!")
    return params
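# Step 1 leaves two copies of each processed structure: a flat copy under the
# flat-format directory (default workspace/processed) and an analysis layout
# under <analysis dir>/<anion type>/<structure name>/, the same layout that
# detect_workflow_status() scans for log.txt files and that Step 2 consumes.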
def run_step2_zeo_analysis(params: dict, ws_manager: WorkspaceManager = None, workspace_info=None):
    """
    Step 2-4: Zeo++ Voronoi analysis.

    Args:
        params: parameter dict containing target_cation, etc.
        ws_manager: workspace manager (optional; reused if already available)
        workspace_info: workspace information (optional; reused if already available)
    """
    print("\n" + "=" * 60)
    print("[Step 2-4: Zeo++ Voronoi analysis]")
    print("=" * 60)
    # Create a workspace manager if none was passed in
    if ws_manager is None:
        # Workspace paths
        workspace_path = input("\nWorkspace path [default: workspace]: ").strip() or "workspace"
        tool_dir = input("Tool directory path [default: tool]: ").strip() or "tool"
        # Create the workspace manager
        ws_manager = WorkspaceManager(
            workspace_path=workspace_path,
            tool_dir=tool_dir,
            target_cation=params.get('target_cation', 'Li')
        )
    # If no workspace_info was passed in, inspect the existing workspace
    if workspace_info is None:
        existing = ws_manager.check_existing_workspace()
        if existing and existing.total_structures > 0:
            ws_manager.print_workspace_summary(existing)
            workspace_info = existing
        else:
            print("⚠️ Workspace data directory is missing or empty")
            print(f"   Run Step 1 first to generate processed data in: {ws_manager.data_dir}")
            return
    # Check whether symlinks still need to be created
    if workspace_info.linked_structures < workspace_info.total_structures:
        print("\nCreating config file symlinks...")
        workspace_info = ws_manager.setup_workspace(force_relink=False)
    # Collect the computation tasks
    tasks = ws_manager.get_computation_tasks(workspace_info)
    if not tasks:
        print("⚠️ No computable tasks found")
        return
    print(f"\nFound {len(tasks)} computation tasks")
    # Zeo++ environment configuration
    print("\n" + "-" * 50)
    print("[Zeo++ computation settings]")
    default_zeo_env = "/cluster/home/koko125/anaconda3/envs/zeo"
    zeo_env = input(f"Zeo++ Conda environment [default: {default_zeo_env}]: ").strip()
    zeo_env = zeo_env or default_zeo_env
    # SLURM configuration
    partition = input("SLURM partition [default: cpu]: ").strip() or "cpu"
    max_concurrent_input = input("Maximum concurrent jobs [default: 50]: ").strip()
    max_concurrent = int(max_concurrent_input) if max_concurrent_input.isdigit() else 50
    time_limit = input("Per-job time limit [default: 2:00:00]: ").strip() or "2:00:00"
    # Build the configuration
    zeo_config = ZeoConfig(
        conda_env=zeo_env,
        partition=partition,
        max_concurrent=max_concurrent,
        time_limit=time_limit
    )
    # Confirm before submitting
    print("\n" + "-" * 50)
    print("[Job submission summary]")
    print(f"  Total tasks: {len(tasks)}")
    print(f"  Conda environment: {zeo_config.conda_env}")
    print(f"  SLURM partition: {zeo_config.partition}")
    print(f"  Max concurrent: {zeo_config.max_concurrent}")
    print(f"  Time limit: {zeo_config.time_limit}")
    confirm = input("\nSubmit the computation jobs? [Y/n]: ").strip().lower()
    if confirm == 'n':
        print("Cancelled")
        return
    # Create the executor and run the batch
    executor = ZeoExecutor(zeo_config)
    log_dir = os.path.join(ws_manager.workspace_path, "slurm_logs")
    results = executor.run_batch(tasks, output_dir=log_dir)
    # Print a results summary
    executor.print_results_summary(results)
    # Save the results
    results_file = os.path.join(ws_manager.workspace_path, "zeo_results.json")
    results_data = [
        {
            'task_id': r.task_id,
            'structure_name': r.structure_name,
            'cif_path': r.cif_path,
            'success': r.success,
            'output_files': r.output_files,
            'error_message': r.error_message
        }
        for r in results
    ]
    with open(results_file, 'w') as f:
        json.dump(results_data, f, indent=2)
    print(f"\nResults saved to: {results_file}")
    print("\n✅ Step 2-4 complete!")
def run_step5_result_processing(workspace_path: str = "workspace"):
    """
    Step 5: result processing and filtering.

    Args:
        workspace_path: workspace path
    """
    print("\n" + "=" * 60)
    print("[Step 5: result processing and filtering]")
    print("=" * 60)
    # Create the result processor
    processor = ResultProcessor(workspace_path=workspace_path)
    # Gather the filter criteria
    print("\nFilter criteria:")
    print("  (press Enter for the default; enter 0 for no limit)")
    # Minimum percolation diameter (default changed to 1.0)
    perc_input = input("  Minimum percolation diameter (Å) [default: 1.0]: ").strip()
    min_percolation = float(perc_input) if perc_input else 1.0
    # Minimum d value
    d_input = input("  Minimum d value [default: 2.0]: ").strip()
    min_d = float(d_input) if d_input else 2.0
    # Maximum node length
    node_input = input("  Maximum node length (Å) [default: no limit]: ").strip()
    max_node = float(node_input) if node_input else float('inf')
    # Build the filter criteria
    criteria = FilterCriteria(
        min_percolation_diameter=min_percolation,
        min_d_value=min_d,
        max_node_length=max_node
    )
    # Run the processing
    results, stats = processor.process_and_filter(
        criteria=criteria,
        save_csv=True,
        copy_passed=True
    )
    # Print a summary
    processor.print_summary(results, stats)
    # Show the output locations
    print("\nOutput locations:")
    print(f"  Summary CSV: {workspace_path}/results/summary.csv")
    print(f"  Per-anion CSV: {workspace_path}/results/<anion type>/<anion type>.csv")
    print(f"  Passed structures: {workspace_path}/passed/<anion type>/<structure name>/")
    print("\n✅ Step 5 complete!")
def main():
    """Main entry point."""
    print_banner()
    # Environment detection
    env = detect_and_show_environment()
    # Ask for the target cation
    cation = input("\n🎯 Target cation [default: Li]: ").strip() or "Li"
    # Detect the workflow status
    status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
    print_workflow_status(status)
    # Ask the user what to do
    choice = get_user_choice(status)
    if choice == 'exit':
        print("\n👋 Goodbye!")
        return
    # Run the chosen steps
    params = {'target_cation': cation}
    if choice == 'step1':
        # Start from scratch
        params = run_step1_database_analysis(env, cation)
        if params is None:
            return
        # Ask whether to continue
        continue_choice = input("\nContinue with the Zeo++ computation? [Y/n]: ").strip().lower()
        if continue_choice == 'n':
            print("\nYou can rerun the program later to resume the Zeo++ computation")
            return
        # Re-detect the workflow status
        status = detect_workflow_status(workspace_path="workspace", target_cation=cation)
        run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
        # Ask whether to continue with filtering
        continue_choice = input("\nContinue with result filtering? [Y/n]: ").strip().lower()
        if continue_choice == 'n':
            print("\nYou can rerun the program later to resume result filtering")
            return
        run_step5_result_processing("workspace")
    elif choice == 'step2':
        # Start from the Zeo++ computation
        run_step2_zeo_analysis(params, status['ws_manager'], status['workspace_info'])
        # Ask whether to continue with filtering
        continue_choice = input("\nContinue with result filtering? [Y/n]: ").strip().lower()
        if continue_choice == 'n':
            print("\nYou can rerun the program later to resume result filtering")
            return
        run_step5_result_processing("workspace")
    elif choice == 'step5':
        # Go straight to result filtering
        run_step5_result_processing("workspace")
    print("\n✅ All done!")


if __name__ == "__main__":
    main()
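# Typical usage: run interactively from the project root with `python main.py`;
# the script detects existing workspace/ contents and offers to resume from
# Step 2-4 (Zeo++) or Step 5 (filtering) instead of starting over.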