Preprocessing: add parallel computation

2025-12-14 15:53:11 +08:00
parent 1fee324c90
commit f27fd3e3ce
2 changed files with 123 additions and 56 deletions

main.py (+12)

```diff
@@ -47,6 +47,18 @@ def get_user_input(env: dict):
             break
         print(f"❌ Path does not exist: {db_path}")
 
+    # Detect the current Conda environment path
+    conda_env_path = env.get('conda_prefix', '')
+    if not conda_env_path:
+        conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
+    print(f"\nDetected Conda environment: {conda_env_path}")
+    custom_env = input(f"Use this environment? [Y/n] or enter another path: ").strip()
+    if custom_env.lower() == 'n':
+        conda_env_path = input("Enter the full Conda environment path: ").strip()
+    elif custom_env and custom_env.lower() != 'y':
+        conda_env_path = custom_env
+
     # Target cation
     cation = input("🎯 Enter the target cation [default: Li]: ").strip() or "Li"
```
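The prompt's fallback mirrors the `conda_prefix` detection added to the scheduler below: if `CONDA_PREFIX` is unset, it assumes the project's default `screen` environment. A minimal standalone sketch of that resolution logic (the function name `resolve_conda_env` is made up for illustration):

```python
import os

def resolve_conda_env(env: dict) -> str:
    # Prefer the environment detected via CONDA_PREFIX;
    # otherwise fall back to the commit's default path.
    conda_env_path = env.get('conda_prefix', '')
    if not conda_env_path:
        conda_env_path = os.path.expanduser("~/anaconda3/envs/screen")
    return conda_env_path

# Same dict shape that detect_environment() produces below.
print(resolve_conda_env({'conda_prefix': os.environ.get('CONDA_PREFIX', '')}))
```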
```diff
@@ -15,35 +15,34 @@ from .progress import ProgressManager
 class ExecutionMode(Enum):
     """Execution mode"""
-    LOCAL_SINGLE = "local_single"    # single thread
-    LOCAL_MULTI = "local_multi"      # local multiprocessing
-    SLURM_SINGLE = "slurm_single"    # single SLURM node
-    SLURM_ARRAY = "slurm_array"      # SLURM job array
+    LOCAL_SINGLE = "local_single"
+    LOCAL_MULTI = "local_multi"
+    SLURM_SINGLE = "slurm_single"
+    SLURM_ARRAY = "slurm_array"
 
 
 @dataclass
 class ResourceConfig:
     """Resource configuration"""
-    max_cores: int = 4               # maximum usable cores
-    cores_per_worker: int = 1        # cores per worker
-    memory_per_core: str = "4G"      # memory per core
-    partition: str = "cpu"           # SLURM partition
-    time_limit: str = "7-00:00:00"   # time limit
+    max_cores: int = 4
+    cores_per_worker: int = 1
+    memory_per_core: str = "4G"
+    partition: str = "cpu"
+    time_limit: str = "7-00:00:00"
+    conda_env_path: str = "~/anaconda3/envs/screen"  # new: Conda environment path
 
     @property
     def num_workers(self) -> int:
-        """Compute the worker count"""
         return max(1, self.max_cores // self.cores_per_worker)
 
 
 class ParallelScheduler:
     """Parallel scheduler"""
-    # Recommended core counts by task complexity
     COMPLEXITY_CORES = {
-        'low': 1,     # simple I/O operations
-        'medium': 2,  # structure parsing + basic checks
-        'high': 4,    # heavy computation (supercell expansion, etc.)
+        'low': 1,
+        'medium': 2,
+        'high': 4,
     }
 
     def __init__(self, resource_config: ResourceConfig = None):
```
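`num_workers` is floor division clamped to a minimum of one, so a `cores_per_worker` larger than `max_cores` degrades to a single worker instead of zero. A quick standalone check (the dataclass is re-declared here, trimmed to the two relevant fields, so the snippet runs on its own):

```python
from dataclasses import dataclass

@dataclass
class ResourceConfig:
    max_cores: int = 4
    cores_per_worker: int = 1

    @property
    def num_workers(self) -> int:
        # Floor division, clamped so we never end up with zero workers.
        return max(1, self.max_cores // self.cores_per_worker)

assert ResourceConfig(max_cores=32, cores_per_worker=4).num_workers == 8
assert ResourceConfig(max_cores=2, cores_per_worker=4).num_workers == 1
```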
```diff
@@ -59,6 +58,7 @@ class ParallelScheduler:
             'has_slurm': False,
             'slurm_partitions': [],
             'available_nodes': 0,
+            'conda_prefix': os.environ.get('CONDA_PREFIX', ''),
         }
 
         # Detect SLURM
```
```diff
@@ -91,63 +91,60 @@ class ParallelScheduler:
     @staticmethod
     def recommend_config(
             num_tasks: int,
             task_complexity: str = 'medium',
-            max_cores: int = None
+            max_cores: int = None,
+            conda_env_path: str = None
     ) -> ResourceConfig:
         """Recommend a configuration based on task count and complexity"""
         env = ParallelScheduler.detect_environment()
 
-        # Default maximum core count
         if max_cores is None:
-            max_cores = min(env['total_cores'], 32)  # cap at 32 cores
+            max_cores = min(env['total_cores'], 32)
 
-        # Cores per worker
         cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2)
 
-        # Compute the optimal worker count
-        # Rule: workers = min(task count, usable cores / cores per worker)
         max_workers = max_cores // cores_per_worker
         optimal_workers = min(num_tasks, max_workers)
 
-        # Re-allocate the core budget
         actual_cores = optimal_workers * cores_per_worker
 
+        # Auto-detect the Conda environment path
+        if conda_env_path is None:
+            conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen')
+
         config = ResourceConfig(
             max_cores=actual_cores,
             cores_per_worker=cores_per_worker,
+            conda_env_path=conda_env_path,
         )
         return config
 
     def run_local(
             self,
             tasks: List[Any],
             worker_func: Callable,
             desc: str = "Processing"
     ) -> List[Any]:
         """Run locally with multiple processes"""
         num_workers = self.config.num_workers
         total = len(tasks)
 
-        print(f"\n{'=' * 60}")
+        print(f"\n{'='*60}")
         print(f"Parallel configuration:")
         print(f"  Total tasks: {total}")
         print(f"  Workers: {num_workers}")
         print(f"  Cores per worker: {self.config.cores_per_worker}")
         print(f"  Total cores in use: {self.config.max_cores}")
-        print(f"{'=' * 60}\n")
+        print(f"{'='*60}\n")
 
-        # Initialize the progress manager
        self.progress_manager = ProgressManager(total, desc)
         self.progress_manager.start()
 
         results = []
         if num_workers == 1:
-            # Single-process mode
             for task in tasks:
                 try:
                     result = worker_func(task)
```
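The sizing rule is workers = min(num_tasks, max_cores // cores_per_worker), with the core budget then trimmed to what those workers actually use. Worked through with assumed numbers (16 detected cores, 'high' complexity; plain arithmetic, no project imports):

```python
# Assumed inputs: env['total_cores'] == 16, task_complexity == 'high'.
num_tasks = 100
max_cores = min(16, 32)                        # detected cores, capped at 32 -> 16
cores_per_worker = 4                           # COMPLEXITY_CORES['high']
max_workers = max_cores // cores_per_worker    # 4
optimal_workers = min(num_tasks, max_workers)  # 4
actual_cores = optimal_workers * cores_per_worker
print(optimal_workers, actual_cores)           # 4 16
```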
```diff
@@ -158,9 +155,7 @@ class ParallelScheduler:
                     self.progress_manager.update(success=False)
                 self.progress_manager.display()
         else:
-            # Multi-process mode
             with Pool(processes=num_workers) as pool:
-                # Use imap_unordered for better throughput
                 for result in pool.imap_unordered(worker_func, tasks, chunksize=10):
                     if result is not None:
                         results.append(result)
```
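`imap_unordered` yields results in completion order rather than submission order, and `chunksize=10` ships tasks to workers in batches of ten, which cuts inter-process overhead for short tasks. The same pattern in isolation, with a toy worker standing in for the project's `worker_func`:

```python
from multiprocessing import Pool

def square(x: int) -> int:  # stand-in for worker_func
    return x * x

if __name__ == "__main__":
    tasks = list(range(100))
    with Pool(processes=4) as pool:
        # Results stream back as workers finish, ten tasks per dispatch.
        results = [r for r in pool.imap_unordered(square, tasks, chunksize=10)
                   if r is not None]
    print(len(results))  # 100
```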
```diff
@@ -174,13 +169,21 @@ class ParallelScheduler:
         return results
 
     def generate_slurm_script(
             self,
             tasks_file: str,
             worker_script: str,
             output_dir: str,
             job_name: str = "analysis"
     ) -> str:
-        """Generate a SLURM job script"""
+        """Generate a SLURM job script (fixed version)"""
+        # Absolute path of the current working (submit) directory
+        submit_dir = os.getcwd()
+
+        # Convert inputs to absolute paths
+        abs_tasks_file = os.path.abspath(tasks_file)
+        abs_worker_script = os.path.abspath(worker_script)
+        abs_output_dir = os.path.abspath(output_dir)
 
         script = f"""#!/bin/bash
 #SBATCH --job-name={job_name}
@@ -190,24 +193,74 @@ class ParallelScheduler:
 #SBATCH --cpus-per-task={self.config.max_cores}
 #SBATCH --mem-per-cpu={self.config.memory_per_core}
 #SBATCH --time={self.config.time_limit}
-#SBATCH --output={output_dir}/slurm_%j.log
-#SBATCH --error={output_dir}/slurm_%j.err
+#SBATCH --output={abs_output_dir}/slurm_%j.log
+#SBATCH --error={abs_output_dir}/slurm_%j.err
 
-# Environment setup
-source $(conda info --base)/etc/profile.d/conda.sh
-conda activate screen
-
-# Set the Python path
-cd $SLURM_SUBMIT_DIR
-export PYTHONPATH=$(pwd):$PYTHONPATH
-
-# Run the analysis
-python {worker_script} \\
-    --tasks-file {tasks_file} \\
-    --output-dir {output_dir} \\
+# ============================================
+# SLURM job script - auto-generated
+# ============================================
+echo "===== Job info ====="
+echo "Job ID: $SLURM_JOB_ID"
+echo "Nodes: $SLURM_NODELIST"
+echo "CPUs: $SLURM_CPUS_PER_TASK"
+echo "Start time: $(date)"
+echo "===================="
+
+# ============ Environment initialization ============
+# Critical: make sure .bashrc gets loaded
+if [ -f ~/.bashrc ]; then
+    source ~/.bashrc
+fi
+
+# Initialize Conda
+if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then
+    source ~/anaconda3/etc/profile.d/conda.sh
+elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then
+    source /opt/anaconda3/etc/profile.d/conda.sh
+else
+    echo "Error: conda.sh not found"
+    exit 1
+fi
+
+# Activate the environment (using the full path)
+conda activate {self.config.conda_env_path}
+
+# Verify the environment
+echo ""
+echo "===== Environment check ====="
+echo "Conda environment: $CONDA_DEFAULT_ENV"
+echo "Python path: $(which python)"
+echo "Python version: $(python --version 2>&1)"
+python -c "import pymatgen; print(f'pymatgen version: {{pymatgen.__version__}}')" 2>/dev/null || echo "Warning: pymatgen is not installed"
+echo "===================="
+echo ""
+
+# Set the working directory
+cd {submit_dir}
+export PYTHONPATH={submit_dir}:$PYTHONPATH
+echo "Working directory: $(pwd)"
+echo "PYTHONPATH: $PYTHONPATH"
+echo ""
+
+# ============ Run the analysis ============
+echo "Starting analysis tasks..."
+python {abs_worker_script} \\
+    --tasks-file {abs_tasks_file} \\
+    --output-dir {abs_output_dir} \\
     --num-workers {self.config.num_workers}
+EXIT_CODE=$?
 
-echo "Job completed at $(date)"
+# ============ Done ============
+echo ""
+echo "===== Job finished ====="
+echo "End time: $(date)"
+echo "Exit code: $EXIT_CODE"
+echo "===================="
+
+exit $EXIT_CODE
 """
         return script
```
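The motivation for the abspath conversion: the batch job does not necessarily start in the directory the script was submitted from, so paths that were relative at submit time can resolve against the wrong location at run time. A small standalone demonstration of the hazard (paths are illustrative):

```python
import os

rel_log = "results/slurm_1.log"      # illustrative relative path
abs_log = os.path.abspath(rel_log)   # pinned to the current directory

os.chdir("/tmp")                     # a job may start somewhere else entirely
print(os.path.abspath(rel_log))      # now resolves under /tmp
print(abs_log)                       # still points at the submit directory
```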
```diff
@@ -223,6 +276,9 @@ echo "Job completed at $(date)"
         os.chmod(script_path, 0o755)
 
+        # Print the script location for debugging
+        print(f"\nGenerated SLURM script saved to: {script_path}")
+
         result = subprocess.run(
             ['sbatch', script_path],
             capture_output=True, text=True
@@ -230,7 +286,6 @@ echo "Job completed at $(date)"
         if result.returncode == 0:
             job_id = result.stdout.strip().split()[-1]
-            print(f"Job submitted: {job_id}")
             return job_id
         else:
-            raise RuntimeError(f"Submission failed: {result.stderr}")
+            raise RuntimeError(f"SLURM submission failed:\nstdout: {result.stdout}\nstderr: {result.stderr}")
```
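The `split()[-1]` parse works because, with default options, `sbatch` confirms submission with a single line of the form `Submitted batch job <id>`. A sketch of the parse against a sample of that output:

```python
sample_stdout = "Submitted batch job 123456\n"  # sample sbatch confirmation
job_id = sample_stdout.strip().split()[-1]
assert job_id == "123456"
```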