预处理增加并行计算
This commit is contained in:
12
main.py
12
main.py
@@ -47,6 +47,18 @@ def get_user_input(env: dict):
|
||||
break
|
||||
print(f"❌ 路径不存在: {db_path}")
|
||||
|
||||
# 检测当前Conda环境路径
|
||||
conda_env_path = env.get('conda_prefix', '')
|
||||
if not conda_env_path:
|
||||
conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
|
||||
|
||||
print(f"\n检测到Conda环境: {conda_env_path}")
|
||||
custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
|
||||
if custom_env.lower() == 'n':
|
||||
conda_env_path = input("请输入Conda环境完整路径: ").strip()
|
||||
elif custom_env and custom_env.lower() != 'y':
|
||||
conda_env_path = custom_env
|
||||
|
||||
# 目标阳离子
|
||||
cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
|
||||
|
||||
|
||||
@@ -15,35 +15,34 @@ from .progress import ProgressManager
|
||||
|
||||
class ExecutionMode(Enum):
|
||||
"""执行模式"""
|
||||
LOCAL_SINGLE = "local_single" # 单线程
|
||||
LOCAL_MULTI = "local_multi" # 本地多进程
|
||||
SLURM_SINGLE = "slurm_single" # SLURM单节点
|
||||
SLURM_ARRAY = "slurm_array" # SLURM作业数组
|
||||
LOCAL_SINGLE = "local_single"
|
||||
LOCAL_MULTI = "local_multi"
|
||||
SLURM_SINGLE = "slurm_single"
|
||||
SLURM_ARRAY = "slurm_array"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceConfig:
|
||||
"""资源配置"""
|
||||
max_cores: int = 4 # 最大可用核数
|
||||
cores_per_worker: int = 1 # 每个worker使用的核数
|
||||
memory_per_core: str = "4G" # 每核内存
|
||||
partition: str = "cpu" # SLURM分区
|
||||
time_limit: str = "7-00:00:00" # 时间限制
|
||||
max_cores: int = 4
|
||||
cores_per_worker: int = 1
|
||||
memory_per_core: str = "4G"
|
||||
partition: str = "cpu"
|
||||
time_limit: str = "7-00:00:00"
|
||||
conda_env_path: str = "~/anaconda3/envs/screen" # 新增:Conda环境路径
|
||||
|
||||
@property
|
||||
def num_workers(self) -> int:
|
||||
"""计算worker数量"""
|
||||
return max(1, self.max_cores // self.cores_per_worker)
|
||||
|
||||
|
||||
class ParallelScheduler:
|
||||
"""并行调度器"""
|
||||
|
||||
# 根据任务复杂度推荐的核数配置
|
||||
COMPLEXITY_CORES = {
|
||||
'low': 1, # 简单IO操作
|
||||
'medium': 2, # 结构解析+基础检查
|
||||
'high': 4, # 复杂计算(扩胞等)
|
||||
'low': 1,
|
||||
'medium': 2,
|
||||
'high': 4,
|
||||
}
|
||||
|
||||
def __init__(self, resource_config: ResourceConfig = None):
|
||||
@@ -59,6 +58,7 @@ class ParallelScheduler:
|
||||
'has_slurm': False,
|
||||
'slurm_partitions': [],
|
||||
'available_nodes': 0,
|
||||
'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
|
||||
}
|
||||
|
||||
# 检测SLURM
|
||||
@@ -93,30 +93,29 @@ class ParallelScheduler:
|
||||
def recommend_config(
|
||||
num_tasks: int,
|
||||
task_complexity: str = 'medium',
|
||||
max_cores: int = None
|
||||
max_cores: int = None,
|
||||
conda_env_path: str = None
|
||||
) -> ResourceConfig:
|
||||
"""根据任务量和复杂度推荐配置"""
|
||||
|
||||
env = ParallelScheduler.detect_environment()
|
||||
|
||||
# 默认最大核数
|
||||
if max_cores is None:
|
||||
max_cores = min(env['total_cores'], 32) # 最多32核
|
||||
max_cores = min(env['total_cores'], 32)
|
||||
|
||||
# 每个worker的核数
|
||||
cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
|
||||
|
||||
# 计算最优worker数
|
||||
# 原则:worker数 = min(任务数, 可用核数/每worker核数)
|
||||
max_workers = max_cores // cores_per_worker
|
||||
optimal_workers = min(num_tasks, max_workers)
|
||||
|
||||
# 重新分配核数
|
||||
actual_cores = optimal_workers * cores_per_worker
|
||||
|
||||
# 自动检测Conda环境路径
|
||||
if conda_env_path is None:
|
||||
conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen')
|
||||
|
||||
config = ResourceConfig(
|
||||
max_cores=actual_cores,
|
||||
cores_per_worker=cores_per_worker,
|
||||
conda_env_path=conda_env_path,
|
||||
)
|
||||
|
||||
return config
|
||||
@@ -140,14 +139,12 @@ class ParallelScheduler:
|
||||
print(f" 总使用核数: {self.config.max_cores}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 初始化进度管理器
|
||||
self.progress_manager = ProgressManager(total, desc)
|
||||
self.progress_manager.start()
|
||||
|
||||
results = []
|
||||
|
||||
if num_workers == 1:
|
||||
# 单进程模式
|
||||
for task in tasks:
|
||||
try:
|
||||
result = worker_func(task)
|
||||
@@ -158,9 +155,7 @@ class ParallelScheduler:
|
||||
self.progress_manager.update(success=False)
|
||||
self.progress_manager.display()
|
||||
else:
|
||||
# 多进程模式
|
||||
with Pool(processes=num_workers) as pool:
|
||||
# 使用imap_unordered获取更好的性能
|
||||
for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
|
||||
if result is not None:
|
||||
results.append(result)
|
||||
@@ -180,7 +175,15 @@ class ParallelScheduler:
|
||||
output_dir: str,
|
||||
job_name: str = "analysis"
|
||||
) -> str:
|
||||
"""生成SLURM作业脚本"""
|
||||
"""生成SLURM作业脚本(修复版)"""
|
||||
|
||||
# 获取当前工作目录的绝对路径
|
||||
submit_dir = os.getcwd()
|
||||
|
||||
# 转换为绝对路径
|
||||
abs_tasks_file = os.path.abspath(tasks_file)
|
||||
abs_worker_script = os.path.abspath(worker_script)
|
||||
abs_output_dir = os.path.abspath(output_dir)
|
||||
|
||||
script = f"""#!/bin/bash
|
||||
#SBATCH --job-name={job_name}
|
||||
@@ -190,24 +193,74 @@ class ParallelScheduler:
|
||||
#SBATCH --cpus-per-task={self.config.max_cores}
|
||||
#SBATCH --mem-per-cpu={self.config.memory_per_core}
|
||||
#SBATCH --time={self.config.time_limit}
|
||||
#SBATCH --output={output_dir}/slurm_%j.log
|
||||
#SBATCH --error={output_dir}/slurm_%j.err
|
||||
#SBATCH --output={abs_output_dir}/slurm_%j.log
|
||||
#SBATCH --error={abs_output_dir}/slurm_%j.err
|
||||
|
||||
# 环境设置
|
||||
source $(conda info --base)/etc/profile.d/conda.sh
|
||||
conda activate screen
|
||||
# ============================================
|
||||
# SLURM作业脚本 - 自动生成
|
||||
# ============================================
|
||||
|
||||
# 设置Python路径
|
||||
cd $SLURM_SUBMIT_DIR
|
||||
export PYTHONPATH=$(pwd):$PYTHONPATH
|
||||
echo "===== 作业信息 ====="
|
||||
echo "作业ID: $SLURM_JOB_ID"
|
||||
echo "节点: $SLURM_NODELIST"
|
||||
echo "CPU数: $SLURM_CPUS_PER_TASK"
|
||||
echo "开始时间: $(date)"
|
||||
echo "===================="
|
||||
|
||||
# 运行分析
|
||||
python {worker_script} \\
|
||||
--tasks-file {tasks_file} \\
|
||||
--output-dir {output_dir} \\
|
||||
# ============ 环境初始化 ============
|
||||
# 关键:确保bashrc被加载
|
||||
if [ -f ~/.bashrc ]; then
|
||||
source ~/.bashrc
|
||||
fi
|
||||
|
||||
# 初始化Conda
|
||||
if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
|
||||
source ~/anaconda3/etc/profile.d/conda.sh
|
||||
elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
|
||||
source /opt/anaconda3/etc/profile.d/conda.sh
|
||||
else
|
||||
echo "错误: 找不到conda.sh"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 激活环境 (使用完整路径)
|
||||
conda activate {self.config.conda_env_path}
|
||||
|
||||
# 验证环境
|
||||
echo ""
|
||||
echo "===== 环境检查 ====="
|
||||
echo "Conda环境: $CONDA_DEFAULT_ENV"
|
||||
echo "Python路径: $(which python)"
|
||||
echo "Python版本: $(python --version 2>&1)"
|
||||
python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
|
||||
echo "===================="
|
||||
echo ""
|
||||
|
||||
# 设置工作目录
|
||||
cd {submit_dir}
|
||||
export PYTHONPATH={submit_dir}:$PYTHONPATH
|
||||
|
||||
echo "工作目录: $(pwd)"
|
||||
echo "PYTHONPATH: $PYTHONPATH"
|
||||
echo ""
|
||||
|
||||
# ============ 运行分析 ============
|
||||
echo "开始执行分析任务..."
|
||||
python {abs_worker_script} \\
|
||||
--tasks-file {abs_tasks_file} \\
|
||||
--output-dir {abs_output_dir} \\
|
||||
--num-workers {self.config.num_workers}
|
||||
|
||||
echo "Job completed at $(date)"
|
||||
EXIT_CODE=$?
|
||||
|
||||
# ============ 完成 ============
|
||||
echo ""
|
||||
echo "===== 作业完成 ====="
|
||||
echo "结束时间: $(date)"
|
||||
echo "退出代码: $EXIT_CODE"
|
||||
echo "===================="
|
||||
|
||||
exit $EXIT_CODE
|
||||
"""
|
||||
return script
|
||||
|
||||
@@ -223,6 +276,9 @@ echo "Job completed at $(date)"
|
||||
|
||||
os.chmod(script_path, 0o755)
|
||||
|
||||
# 打印脚本内容用于调试
|
||||
print(f"\n生成的SLURM脚本保存到: {script_path}")
|
||||
|
||||
result = subprocess.run(
|
||||
['sbatch', script_path],
|
||||
capture_output=True, text=True
|
||||
@@ -230,7 +286,6 @@ echo "Job completed at $(date)"
|
||||
|
||||
if result.returncode == 0:
|
||||
job_id = result.stdout.strip().split()[-1]
|
||||
print(f"作业已提交: {job_id}")
|
||||
return job_id
|
||||
else:
|
||||
raise RuntimeError(f"提交失败: {result.stderr}")
|
||||
raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")
|
||||
Reference in New Issue
Block a user