Add parallel computation to preprocessing (预处理增加并行计算)
main.py
@@ -47,6 +47,18 @@ def get_user_input(env: dict):
             break
         print(f"❌ 路径不存在: {db_path}")
 
+    # 检测当前Conda环境路径
+    conda_env_path = env.get('conda_prefix', '')
+    if not conda_env_path:
+        conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
+
+    print(f"\n检测到Conda环境: {conda_env_path}")
+    custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip()
+    if custom_env.lower() == 'n':
+        conda_env_path = input("请输入Conda环境完整路径: ").strip()
+    elif custom_env and custom_env.lower() != 'y':
+        conda_env_path = custom_env
+
     # 目标阳离子
     cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li"
 
@@ -15,35 +15,34 @@ from .progress import ProgressManager
 
 class ExecutionMode(Enum):
     """执行模式"""
-    LOCAL_SINGLE = "local_single"  # 单线程
-    LOCAL_MULTI = "local_multi"  # 本地多进程
-    SLURM_SINGLE = "slurm_single"  # SLURM单节点
-    SLURM_ARRAY = "slurm_array"  # SLURM作业数组
+    LOCAL_SINGLE = "local_single"
+    LOCAL_MULTI = "local_multi"
+    SLURM_SINGLE = "slurm_single"
+    SLURM_ARRAY = "slurm_array"
 
 
 @dataclass
 class ResourceConfig:
     """资源配置"""
-    max_cores: int = 4  # 最大可用核数
-    cores_per_worker: int = 1  # 每个worker使用的核数
-    memory_per_core: str = "4G"  # 每核内存
-    partition: str = "cpu"  # SLURM分区
-    time_limit: str = "7-00:00:00"  # 时间限制
+    max_cores: int = 4
+    cores_per_worker: int = 1
+    memory_per_core: str = "4G"
+    partition: str = "cpu"
+    time_limit: str = "7-00:00:00"
+    conda_env_path: str = "~/anaconda3/envs/screen"  # 新增:Conda环境路径
 
     @property
     def num_workers(self) -> int:
-        """计算worker数量"""
         return max(1, self.max_cores // self.cores_per_worker)
 
 
 class ParallelScheduler:
     """并行调度器"""
 
-    # 根据任务复杂度推荐的核数配置
     COMPLEXITY_CORES = {
-        'low': 1,  # 简单IO操作
-        'medium': 2,  # 结构解析+基础检查
-        'high': 4,  # 复杂计算(扩胞等)
+        'low': 1,
+        'medium': 2,
+        'high': 4,
     }
 
     def __init__(self, resource_config: ResourceConfig = None):
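For orientation on the ResourceConfig above: num_workers is plain integer division of max_cores by cores_per_worker, clamped to at least 1. A minimal sketch using only the fields shown in this hunk (the values passed in are illustrative):

    from dataclasses import dataclass

    @dataclass
    class ResourceConfig:
        max_cores: int = 4
        cores_per_worker: int = 1

        @property
        def num_workers(self) -> int:
            # integer division, never below one worker
            return max(1, self.max_cores // self.cores_per_worker)

    print(ResourceConfig().num_workers)                                  # 4
    print(ResourceConfig(max_cores=32, cores_per_worker=4).num_workers)  # 8
    print(ResourceConfig(max_cores=1, cores_per_worker=4).num_workers)   # 1 (clamped)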
@@ -59,6 +58,7 @@ class ParallelScheduler:
             'has_slurm': False,
             'slurm_partitions': [],
             'available_nodes': 0,
+            'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
         }
 
         # 检测SLURM
@@ -93,30 +93,29 @@ class ParallelScheduler:
     def recommend_config(
         num_tasks: int,
         task_complexity: str = 'medium',
-        max_cores: int = None
+        max_cores: int = None,
+        conda_env_path: str = None
     ) -> ResourceConfig:
         """根据任务量和复杂度推荐配置"""
 
         env = ParallelScheduler.detect_environment()
 
-        # 默认最大核数
         if max_cores is None:
-            max_cores = min(env['total_cores'], 32)  # 最多32核
-
-        # 每个worker的核数
+            max_cores = min(env['total_cores'], 32)
         cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
 
-        # 计算最优worker数
-        # 原则:worker数 = min(任务数, 可用核数/每worker核数)
         max_workers = max_cores // cores_per_worker
         optimal_workers = min(num_tasks, max_workers)
-
-        # 重新分配核数
         actual_cores = optimal_workers * cores_per_worker
 
+        # 自动检测Conda环境路径
+        if conda_env_path is None:
+            conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen')
+
         config = ResourceConfig(
            max_cores=actual_cores,
            cores_per_worker=cores_per_worker,
+            conda_env_path=conda_env_path,
         )
 
         return config
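Worked example of the sizing rule in recommend_config (numbers are illustrative): with the 32-core cap and task_complexity='medium' (2 cores per worker), 100 tasks give min(100, 32 // 2) = 16 workers using all 32 cores, while 5 tasks give min(5, 16) = 5 workers and only 10 cores. A standalone sketch of the same arithmetic:

    COMPLEXITY_CORES = {'low': 1, 'medium': 2, 'high': 4}

    def plan(num_tasks: int, task_complexity: str = 'medium', max_cores: int = 32):
        # same worker/core arithmetic as recommend_config, for illustration only
        cores_per_worker = COMPLEXITY_CORES.get(task_complexity, 2)
        optimal_workers = min(num_tasks, max_cores // cores_per_worker)
        return optimal_workers, optimal_workers * cores_per_worker

    print(plan(100))        # (16, 32)
    print(plan(5))          # (5, 10)
    print(plan(8, 'high'))  # (8, 32)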
@@ -140,14 +139,12 @@ class ParallelScheduler:
         print(f" 总使用核数: {self.config.max_cores}")
         print(f"{'='*60}\n")
 
-        # 初始化进度管理器
         self.progress_manager = ProgressManager(total, desc)
         self.progress_manager.start()
 
         results = []
 
         if num_workers == 1:
-            # 单进程模式
             for task in tasks:
                 try:
                     result = worker_func(task)
@@ -158,9 +155,7 @@
                     self.progress_manager.update(success=False)
                     self.progress_manager.display()
         else:
-            # 多进程模式
             with Pool(processes=num_workers) as pool:
-                # 使用imap_unordered获取更好的性能
                 for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
                     if result is not None:
                         results.append(result)
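The multi-process branch relies on Pool.imap_unordered, which yields each result as soon as any worker finishes (arrival order is not submission order) and hands tasks out in batches of chunksize to cut inter-process overhead. A minimal self-contained sketch of the same pattern, with square standing in for worker_func:

    from multiprocessing import Pool

    def square(x):  # stand-in for worker_func
        return x * x

    if __name__ == "__main__":
        tasks = range(100)
        results = []
        with Pool(processes=4) as pool:
            for result in pool.imap_unordered(square, tasks, chunksize=10):
                if result is not None:
                    results.append(result)
        print(len(results))  # 100, but collected in completion order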
@@ -180,7 +175,15 @@ class ParallelScheduler:
         output_dir: str,
         job_name: str = "analysis"
     ) -> str:
-        """生成SLURM作业脚本"""
+        """生成SLURM作业脚本(修复版)"""
 
+        # 获取当前工作目录的绝对路径
+        submit_dir = os.getcwd()
+
+        # 转换为绝对路径
+        abs_tasks_file = os.path.abspath(tasks_file)
+        abs_worker_script = os.path.abspath(worker_script)
+        abs_output_dir = os.path.abspath(output_dir)
+
         script = f"""#!/bin/bash
 #SBATCH --job-name={job_name}
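The os.path.abspath conversions matter because a relative path is resolved against the current working directory at the moment it is used, and a SLURM job does not necessarily start in the directory the script was generated from. A tiny illustration (the file name is made up):

    import os

    tasks_file = "tasks.json"
    print(os.path.abspath(tasks_file))
    # -> <directory you ran this from>/tasks.json; freezing the absolute path at
    #    generation time keeps the batch job independent of where it later starts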
@@ -190,24 +193,74 @@ class ParallelScheduler:
 #SBATCH --cpus-per-task={self.config.max_cores}
 #SBATCH --mem-per-cpu={self.config.memory_per_core}
 #SBATCH --time={self.config.time_limit}
-#SBATCH --output={output_dir}/slurm_%j.log
-#SBATCH --error={output_dir}/slurm_%j.err
+#SBATCH --output={abs_output_dir}/slurm_%j.log
+#SBATCH --error={abs_output_dir}/slurm_%j.err
 
-# 环境设置
-source $(conda info --base)/etc/profile.d/conda.sh
-conda activate screen
+# ============================================
+# SLURM作业脚本 - 自动生成
+# ============================================
 
-# 设置Python路径
-cd $SLURM_SUBMIT_DIR
-export PYTHONPATH=$(pwd):$PYTHONPATH
+echo "===== 作业信息 ====="
+echo "作业ID: $SLURM_JOB_ID"
+echo "节点: $SLURM_NODELIST"
+echo "CPU数: $SLURM_CPUS_PER_TASK"
+echo "开始时间: $(date)"
+echo "===================="
 
-# 运行分析
-python {worker_script} \\
-    --tasks-file {tasks_file} \\
-    --output-dir {output_dir} \\
+# ============ 环境初始化 ============
+# 关键:确保bashrc被加载
+if [ -f ~/.bashrc ]; then
+    source ~/.bashrc
+fi
+
+# 初始化Conda
+if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
+    source ~/anaconda3/etc/profile.d/conda.sh
+elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
+    source /opt/anaconda3/etc/profile.d/conda.sh
+else
+    echo "错误: 找不到conda.sh"
+    exit 1
+fi
+
+# 激活环境 (使用完整路径)
+conda activate {self.config.conda_env_path}
+
+# 验证环境
+echo ""
+echo "===== 环境检查 ====="
+echo "Conda环境: $CONDA_DEFAULT_ENV"
+echo "Python路径: $(which python)"
+echo "Python版本: $(python --version 2>&1)"
+python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装"
+echo "===================="
+echo ""
+
+# 设置工作目录
+cd {submit_dir}
+export PYTHONPATH={submit_dir}:$PYTHONPATH
+
+echo "工作目录: $(pwd)"
+echo "PYTHONPATH: $PYTHONPATH"
+echo ""
+
+# ============ 运行分析 ============
+echo "开始执行分析任务..."
+python {abs_worker_script} \\
+    --tasks-file {abs_tasks_file} \\
+    --output-dir {abs_output_dir} \\
     --num-workers {self.config.num_workers}
 
-echo "Job completed at $(date)"
+EXIT_CODE=$?
+
+# ============ 完成 ============
+echo ""
+echo "===== 作业完成 ====="
+echo "结束时间: $(date)"
+echo "退出代码: $EXIT_CODE"
+echo "===================="
+
+exit $EXIT_CODE
 """
         return script
 
@@ -223,6 +276,9 @@ echo "Job completed at $(date)"
 
         os.chmod(script_path, 0o755)
 
+        # 打印脚本内容用于调试
+        print(f"\n生成的SLURM脚本保存到: {script_path}")
+
         result = subprocess.run(
             ['sbatch', script_path],
             capture_output=True, text=True
@@ -230,7 +286,6 @@ echo "Job completed at $(date)"
 
         if result.returncode == 0:
             job_id = result.stdout.strip().split()[-1]
-            print(f"作业已提交: {job_id}")
             return job_id
         else:
-            raise RuntimeError(f"提交失败: {result.stderr}")
+            raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}")
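On the job-id parsing kept above: sbatch normally prints a single line of the form "Submitted batch job <id>", so taking the last whitespace-separated token of stdout recovers the numeric id. A minimal sketch (the sample output string is illustrative):

    stdout = "Submitted batch job 123456\n"  # typical sbatch output
    job_id = stdout.strip().split()[-1]
    print(job_id)  # 123456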