预处理增加并行计算
This commit is contained in:
157
main.py
157
main.py
@@ -1,71 +1,112 @@
|
||||
"""
|
||||
高通量筛选与扩胞项目 - 主入口
|
||||
交互式命令行界面
|
||||
高通量筛选与扩胞项目 - 主入口(支持并行)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加 src 到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||
|
||||
from src.analysis.database_analyzer import DatabaseAnalyzer
|
||||
from src.analysis.report_generator import ReportGenerator
|
||||
from analysis.database_analyzer import DatabaseAnalyzer
|
||||
from analysis.report_generator import ReportGenerator
|
||||
from core.scheduler import ParallelScheduler
|
||||
|
||||
|
||||
def get_user_input():
|
||||
def print_banner():
    """Print the start-up banner of the database analysis tool.

    The banner is a box drawn with box-drawing characters that contains two
    centred title rows.  Padding is computed from the display width of each
    title rather than hard-coded, so the right-hand border stays aligned even
    though CJK characters occupy two terminal columns each.
    """
    # Local import: display-width lookup is only needed by this function.
    import unicodedata

    top = "╔═══════════════════════════════════════════════════════════════════╗"
    bottom = "╚═══════════════════════════════════════════════════════════════════╝"
    titles = (
        "高通量筛选与扩胞项目 - 数据库分析工具 v2.0",
        "支持高性能并行计算",
    )

    # Width available between the two vertical borders.
    inner_width = len(top) - 2

    rows = [top]
    for title in titles:
        # Wide ("W") and full-width ("F") characters take two columns;
        # everything else (including the box-drawing glyphs) is counted as one.
        shown = sum(
            2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1
            for ch in title
        )
        slack = max(inner_width - shown, 0)
        left = slack // 2
        rows.append("║" + " " * left + title + " " * (slack - left) + "║")
    rows.append(bottom)

    # Blank line before and after the box, as in a triple-quoted banner.
    print("\n" + "\n".join(rows) + "\n")
|
||||
|
||||
|
||||
def detect_and_show_environment():
    """Detect the runtime environment, print a summary, and return it.

    Returns:
        The mapping produced by ``ParallelScheduler.detect_environment()``.
        This function reads the keys ``hostname``, ``total_cores``,
        ``has_slurm`` and ``slurm_partitions`` from it; every partition
        entry is read for ``name``, ``nodes`` and ``cores_per_node``.
    """
    env = ParallelScheduler.detect_environment()

    print("【运行环境检测】")
    print(f" 主机名: {env['hostname']}")
    print(f" 本地CPU核数: {env['total_cores']}")
    print(f" SLURM集群: {'✅ 可用' if env['has_slurm'] else '❌ 不可用'}")

    # Partitions are listed only when SLURM is available; a missing or empty
    # partition list is skipped instead of raising KeyError.
    partitions = env.get('slurm_partitions') if env['has_slurm'] else None
    if partitions:
        print(" 可用分区:")
        for p in partitions:
            print(f" - {p['name']}: {p['nodes']}节点, {p['cores_per_node']}核/节点")

    return env
|
||||
|
||||
|
||||
def get_user_input(env: dict):
    """Interactively collect the analysis and parallel-execution settings.

    Args:
        env: Environment description; the keys ``total_cores`` (used to
            derive the default core count) and ``has_slurm`` (decides whether
            the SLURM question is asked) are read here.

    Returns:
        A dict with the keys ``database_path``, ``target_cation``,
        ``target_anions`` (a set of element symbols), ``anion_mode``
        (``'single'``, ``'mixed'`` or ``'all'``), ``max_cores``,
        ``task_complexity`` (``'low'``, ``'medium'`` or ``'high'``) and
        ``use_slurm``.
    """
    # Database path: keep asking until an existing path is entered.
    while True:
        db_path = input("\n📂 请输入数据库路径: ").strip()
        if os.path.exists(db_path):
            break
        print(f"❌ 路径不存在: {db_path}")

    # Target cation.
    cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"

    # Target anions.  Full-width commas (typical of a Chinese input method)
    # are accepted as separators, and empty tokens such as the one in "O,,S"
    # are dropped so that no empty element symbol reaches the analyzer.
    anion_input = input("🎯 请输入目标阴离子 (逗号分隔) [默认: O,S,Cl,Br]: ")
    tokens = anion_input.replace('，', ',').split(',')
    anions = {tok.strip() for tok in tokens if tok.strip()}
    if not anions:
        anions = {'O', 'S', 'Cl', 'Br'}

    # Anion mode; anything other than 1 or 2 means "all".
    print("\n阴离子模式:")
    print(" 1. 仅单一阴离子")
    print(" 2. 仅复合阴离子")
    print(" 3. 全部 (默认)")
    mode_choice = input("请选择 [1/2/3]: ").strip()
    anion_mode = {'1': 'single', '2': 'mixed', '3': 'all'}.get(mode_choice, 'all')

    # Parallel configuration.
    print("\n" + "─" * 50)
    print("【并行计算配置】")

    # The default is capped at 32 cores and never drops below one.
    default_cores = max(1, min(env['total_cores'], 32))
    cores_input = input(f"💻 最大可用核数 [默认: {default_cores}]: ").strip()
    max_cores = _parse_positive_int(cores_input, default_cores)

    print("\n任务复杂度 (影响每个Worker分配的核数):")
    print(" 1. 低 (1核/Worker) - 简单IO操作")
    print(" 2. 中 (2核/Worker) - 结构解析+检查 [默认]")
    print(" 3. 高 (4核/Worker) - 复杂计算")
    complexity_choice = input("请选择 [1/2/3]: ").strip()
    complexity = {'1': 'low', '2': 'medium', '3': 'high'}.get(complexity_choice, 'medium')

    # Execution mode: SLURM submission is offered only when it is available.
    use_slurm = False
    if env['has_slurm']:
        slurm_choice = input("\n是否使用SLURM提交作业? [y/N]: ").strip().lower()
        use_slurm = slurm_choice == 'y'

    return {
        'database_path': db_path,
        'target_cation': cation,
        'target_anions': anions,
        'anion_mode': anion_mode,
        'max_cores': max_cores,
        'task_complexity': complexity,
        'use_slurm': use_slurm,
    }


def _parse_positive_int(text, default):
    """Return ``text`` as a positive int, or ``default`` if it is not one.

    ``str.isdigit()`` is not a safe guard for ``int()``: it is true for
    characters such as ``²`` that ``int()`` rejects, and it lets ``0``
    through.  Converting and catching ``ValueError`` covers both cases.
    """
    try:
        value = int(text)
    except ValueError:
        return default
    return value if value > 0 else default
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
# 获取用户输入
|
||||
params = get_user_input()
|
||||
print_banner()
|
||||
|
||||
print("\n" + "-" * 70)
|
||||
print("开始分析数据库...")
|
||||
print("-" * 70)
|
||||
# 环境检测
|
||||
env = detect_and_show_environment()
|
||||
|
||||
# 获取用户输入
|
||||
params = get_user_input(env)
|
||||
|
||||
print("\n" + "═" * 60)
|
||||
print("开始数据库分析...")
|
||||
print("═" * 60)
|
||||
|
||||
# 创建分析器
|
||||
analyzer = DatabaseAnalyzer(
|
||||
@@ -73,30 +114,42 @@ def main():
|
||||
target_cation=params['target_cation'],
|
||||
target_anions=params['target_anions'],
|
||||
anion_mode=params['anion_mode'],
|
||||
n_jobs=params['n_jobs']
|
||||
max_cores=params['max_cores'],
|
||||
task_complexity=params['task_complexity']
|
||||
)
|
||||
|
||||
# 执行分析
|
||||
report = analyzer.analyze(show_progress=True)
|
||||
print(f"\n发现 {len(analyzer.cif_files)} 个CIF文件")
|
||||
|
||||
# 打印报告
|
||||
ReportGenerator.print_report(report, detailed=True)
|
||||
if params['use_slurm']:
|
||||
# SLURM模式
|
||||
output_dir = input("输出目录 [默认: ./slurm_output]: ").strip() or "./slurm_output"
|
||||
job_id = analyzer.analyze_slurm(output_dir=output_dir)
|
||||
print(f"\n✅ SLURM作业已提交: {job_id}")
|
||||
print(f" 使用 'squeue -j {job_id}' 查看状态")
|
||||
print(f" 结果将保存到: {output_dir}")
|
||||
else:
|
||||
# 本地模式
|
||||
report = analyzer.analyze(show_progress=True)
|
||||
|
||||
# 询问是否导出
|
||||
export = input("\n是否导出详细结果到CSV? [y/N]: ").strip().lower()
|
||||
if export == 'y':
|
||||
output_path = input("输出文件路径 [默认: analysis_report.csv]: ").strip()
|
||||
output_path = output_path or "analysis_report.csv"
|
||||
ReportGenerator.export_to_csv(report, output_path)
|
||||
# 打印报告
|
||||
ReportGenerator.print_report(report, detailed=True)
|
||||
|
||||
# 询问是否继续处理
|
||||
print("\n" + "-" * 70)
|
||||
proceed = input("是否继续进行预处理? [y/N]: ").strip().lower()
|
||||
if proceed == 'y':
|
||||
print("预处理功能将在下一阶段实现...")
|
||||
# TODO: 调用预处理模块
|
||||
# 保存选项
|
||||
save_choice = input("\n是否保存报告? [y/N]: ").strip().lower()
|
||||
if save_choice == 'y':
|
||||
output_path = input("报告路径 [默认: analysis_report.json]: ").strip()
|
||||
output_path = output_path or "analysis_report.json"
|
||||
report.save(output_path)
|
||||
print(f"✅ 报告已保存到: {output_path}")
|
||||
|
||||
print("\n分析完成!")
|
||||
# CSV导出
|
||||
csv_choice = input("是否导出详细CSV? [y/N]: ").strip().lower()
|
||||
if csv_choice == 'y':
|
||||
csv_path = input("CSV路径 [默认: analysis_details.csv]: ").strip()
|
||||
csv_path = csv_path or "analysis_details.csv"
|
||||
ReportGenerator.export_to_csv(report, csv_path)
|
||||
|
||||
print("\n✅ 分析完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user