预处理增加并行计算

This commit is contained in:
2025-12-14 15:53:11 +08:00
parent 1fee324c90
commit f27fd3e3ce
2 changed files with 123 additions and 56 deletions

12
main.py
View File

@@ -47,6 +47,18 @@ def get_user_input(env: dict):
break
print(f"❌ 路径不存在: {db_path}")
# 检测当前Conda环境路径
conda_env_path = env.get('conda_prefix', '')
if not conda_env_path:
conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
print(f"\n检测到Conda环境: {conda_env_path}")
custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
if custom_env.lower() == 'n':
conda_env_path = input("请输入Conda环境完整路径: ").strip()
elif custom_env and custom_env.lower() != 'y':
conda_env_path = custom_env
# 目标阳离子
cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"

View File

@@ -15,35 +15,34 @@ from .progress import ProgressManager
class ExecutionMode(Enum):
"""执行模式"""
LOCAL_SINGLE = "local_single" # 单线程
LOCAL_MULTI = "local_multi" # 本地多进程
SLURM_SINGLE = "slurm_single" # SLURM单节点
SLURM_ARRAY = "slurm_array" # SLURM作业数组
LOCAL_SINGLE = "local_single"
LOCAL_MULTI = "local_multi"
SLURM_SINGLE = "slurm_single"
SLURM_ARRAY = "slurm_array"
@dataclass
class ResourceConfig:
"""资源配置"""
max_cores: int = 4 # 最大可用核数
cores_per_worker: int = 1 # 每个worker使用的核数
memory_per_core: str = "4G" # 每核内存
partition: str = "cpu" # SLURM分区
time_limit: str = "7-00:00:00" # 时间限制
max_cores: int = 4
cores_per_worker: int = 1
memory_per_core: str = "4G"
partition: str = "cpu"
time_limit: str = "7-00:00:00"
conda_env_path: str = "~/anaconda3/envs/screen" # 新增Conda环境路径
@property
def num_workers(self) -> int:
"""计算worker数量"""
return max(1, self.max_cores // self.cores_per_worker)
class ParallelScheduler:
"""并行调度器"""
# 根据任务复杂度推荐的核数配置
COMPLEXITY_CORES = {
'low': 1, # 简单IO操作
'medium': 2, # 结构解析+基础检查
'high': 4, # 复杂计算(扩胞等)
'low': 1,
'medium': 2,
'high': 4,
}
def __init__(self, resource_config: ResourceConfig = None):
@@ -59,6 +58,7 @@ class ParallelScheduler:
'has_slurm': False,
'slurm_partitions': [],
'available_nodes': 0,
'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
}
# 检测SLURM
@@ -93,30 +93,29 @@ class ParallelScheduler:
def recommend_config(
num_tasks: int,
task_complexity: str = 'medium',
max_cores: int = None
max_cores: int = None,
conda_env_path: str = None
) -> ResourceConfig:
"""根据任务量和复杂度推荐配置"""
env = ParallelScheduler.detect_environment()
# 默认最大核数
if max_cores is None:
max_cores = min(env['total_cores'], 32) # 最多32核
max_cores = min(env['total_cores'], 32)
# 每个worker的核数
cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
# 计算最优worker数
# 原则worker数 = min(任务数, 可用核数/每worker核数)
max_workers = max_cores // cores_per_worker
optimal_workers = min(num_tasks, max_workers)
# 重新分配核数
actual_cores = optimal_workers * cores_per_worker
# 自动检测Conda环境路径
if conda_env_path is None:
conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen')
config = ResourceConfig(
max_cores=actual_cores,
cores_per_worker=cores_per_worker,
conda_env_path=conda_env_path,
)
return config
@@ -140,14 +139,12 @@ class ParallelScheduler:
print(f" 总使用核数: {self.config.max_cores}")
print(f"{'='*60}\n")
# 初始化进度管理器
self.progress_manager = ProgressManager(total, desc)
self.progress_manager.start()
results = []
if num_workers == 1:
# 单进程模式
for task in tasks:
try:
result = worker_func(task)
@@ -158,9 +155,7 @@ class ParallelScheduler:
self.progress_manager.update(success=False)
self.progress_manager.display()
else:
# 多进程模式
with Pool(processes=num_workers) as pool:
# 使用imap_unordered获取更好的性能
for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
if result is not None:
results.append(result)
@@ -180,7 +175,15 @@ class ParallelScheduler:
output_dir: str,
job_name: str = "analysis"
) -> str:
"""生成SLURM作业脚本"""
"""生成SLURM作业脚本(修复版)"""
# 获取当前工作目录的绝对路径
submit_dir = os.getcwd()
# 转换为绝对路径
abs_tasks_file = os.path.abspath(tasks_file)
abs_worker_script = os.path.abspath(worker_script)
abs_output_dir = os.path.abspath(output_dir)
script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
@@ -190,24 +193,74 @@ class ParallelScheduler:
#SBATCH --cpus-per-task={self.config.max_cores}
#SBATCH --mem-per-cpu={self.config.memory_per_core}
#SBATCH --time={self.config.time_limit}
#SBATCH --output={output_dir}/slurm_%j.log
#SBATCH --error={output_dir}/slurm_%j.err
#SBATCH --output={abs_output_dir}/slurm_%j.log
#SBATCH --error={abs_output_dir}/slurm_%j.err
# 环境设置
source $(conda info --base)/etc/profile.d/conda.sh
conda activate screen
# ============================================
# SLURM作业脚本 - 自动生成
# ============================================
# 设置Python路径
cd $SLURM_SUBMIT_DIR
export PYTHONPATH=$(pwd):$PYTHONPATH
echo "===== 作业信息 ====="
echo "作业ID: $SLURM_JOB_ID"
echo "节点: $SLURM_NODELIST"
echo "CPU数: $SLURM_CPUS_PER_TASK"
echo "开始时间: $(date)"
echo "===================="
# 运行分析
python {worker_script} \\
--tasks-file {tasks_file} \\
--output-dir {output_dir} \\
# ============ 环境初始化 ============
# 关键确保bashrc被加载
if [ -f ~/.bashrc ]; then
source ~/.bashrc
fi
# 初始化Conda
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
source ~/anaconda3/etc/profile.d/conda.sh
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
source /opt/anaconda3/etc/profile.d/conda.sh
else
echo "错误: 找不到conda.sh"
exit 1
fi
# 激活环境 (使用完整路径)
conda activate {self.config.conda_env_path}
# 验证环境
echo ""
echo "===== 环境检查 ====="
echo "Conda环境: $CONDA_DEFAULT_ENV"
echo "Python路径: $(which python)"
echo "Python版本: $(python --version 2>&1)"
python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
echo "===================="
echo ""
# 设置工作目录
cd {submit_dir}
export PYTHONPATH={submit_dir}:$PYTHONPATH
echo "工作目录: $(pwd)"
echo "PYTHONPATH: $PYTHONPATH"
echo ""
# ============ 运行分析 ============
echo "开始执行分析任务..."
python {abs_worker_script} \\
--tasks-file {abs_tasks_file} \\
--output-dir {abs_output_dir} \\
--num-workers {self.config.num_workers}
echo "Job completed at $(date)"
EXIT_CODE=$?
# ============ 完成 ============
echo ""
echo "===== 作业完成 ====="
echo "结束时间: $(date)"
echo "退出代码: $EXIT_CODE"
echo "===================="
exit $EXIT_CODE
"""
return script
@@ -223,6 +276,9 @@ echo "Job completed at $(date)"
os.chmod(script_path, 0o755)
# 打印脚本内容用于调试
print(f"\n生成的SLURM脚本保存到: {script_path}")
result = subprocess.run(
['sbatch', script_path],
capture_output=True, text=True
@@ -230,7 +286,6 @@ echo "Job completed at $(date)"
if result.returncode == 0:
job_id = result.stdout.strip().split()[-1]
print(f"作业已提交: {job_id}")
return job_id
else:
raise RuntimeError(f"提交失败: {result.stderr}")
raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")