From f27fd3e3ce2273e0d293f425e344d6c736c0244f Mon Sep 17 00:00:00 2001 From: koko <1429659362@qq.com> Date: Sun, 14 Dec 2025 15:53:11 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A2=84=E5=A4=84=E7=90=86=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=B9=B6=E8=A1=8C=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 12 +++ src/core/scheduler.py | 167 ++++++++++++++++++++++++++++-------------- 2 files changed, 123 insertions(+), 56 deletions(-) diff --git a/main.py b/main.py index 01c083e..4fce51d 100644 --- a/main.py +++ b/main.py @@ -47,6 +47,18 @@ def get_user_input(env: dict): break print(f"❌ 路径不存在: {db_path}") + # 检测当前Conda环境路径 + conda_env_path = env.get('conda_prefix', '') + if not conda_env_path: + conda_env_path = os.path.expanduser("~/anaconda3/envs/screen") + + print(f"\n检测到Conda环境: {conda_env_path}") + custom_env = input(f"使用此环境? [Y/n] 或输入其他路径: ").strip() + if custom_env.lower() == 'n': + conda_env_path = input("请输入Conda环境完整路径: ").strip() + elif custom_env and custom_env.lower() != 'y': + conda_env_path = custom_env + # 目标阳离子 cation = input("🎯 请输入目标阳离子 [默认: Li]: ").strip() or "Li" diff --git a/src/core/scheduler.py b/src/core/scheduler.py index f584a4f..ce67664 100644 --- a/src/core/scheduler.py +++ b/src/core/scheduler.py @@ -15,35 +15,34 @@ from .progress import ProgressManager class ExecutionMode(Enum): """执行模式""" - LOCAL_SINGLE = "local_single" # 单线程 - LOCAL_MULTI = "local_multi" # 本地多进程 - SLURM_SINGLE = "slurm_single" # SLURM单节点 - SLURM_ARRAY = "slurm_array" # SLURM作业数组 + LOCAL_SINGLE = "local_single" + LOCAL_MULTI = "local_multi" + SLURM_SINGLE = "slurm_single" + SLURM_ARRAY = "slurm_array" @dataclass class ResourceConfig: """资源配置""" - max_cores: int = 4 # 最大可用核数 - cores_per_worker: int = 1 # 每个worker使用的核数 - memory_per_core: str = "4G" # 每核内存 - partition: str = "cpu" # SLURM分区 - time_limit: str = "7-00:00:00" # 时间限制 + max_cores: int = 4 + cores_per_worker: int = 1 + memory_per_core: str = "4G" + partition: str = "cpu" + time_limit: str = "7-00:00:00" + conda_env_path: str = "~/anaconda3/envs/screen" # 新增:Conda环境路径 @property def num_workers(self) -> int: - """计算worker数量""" return max(1, self.max_cores // self.cores_per_worker) class ParallelScheduler: """并行调度器""" - # 根据任务复杂度推荐的核数配置 COMPLEXITY_CORES = { - 'low': 1, # 简单IO操作 - 'medium': 2, # 结构解析+基础检查 - 'high': 4, # 复杂计算(扩胞等) + 'low': 1, + 'medium': 2, + 'high': 4, } def __init__(self, resource_config: ResourceConfig = None): @@ -59,6 +58,7 @@ class ParallelScheduler: 'has_slurm': False, 'slurm_partitions': [], 'available_nodes': 0, + 'conda_prefix': os.environ.get('CONDA_PREFIX', ''), } # 检测SLURM @@ -91,63 +91,60 @@ class ParallelScheduler: @staticmethod def recommend_config( - num_tasks: int, - task_complexity: str = 'medium', - max_cores: int = None + num_tasks: int, + task_complexity: str = 'medium', + max_cores: int = None, + conda_env_path: str = None ) -> ResourceConfig: """根据任务量和复杂度推荐配置""" env = ParallelScheduler.detect_environment() - # 默认最大核数 if max_cores is None: - max_cores = min(env['total_cores'], 32) # 最多32核 + max_cores = min(env['total_cores'], 32) - # 每个worker的核数 cores_per_worker = ParallelScheduler.COMPLEXITY_CORES.get(task_complexity, 2) - - # 计算最优worker数 - # 原则:worker数 = min(任务数, 可用核数/每worker核数) max_workers = max_cores // cores_per_worker optimal_workers = min(num_tasks, max_workers) - - # 重新分配核数 actual_cores = optimal_workers * cores_per_worker + # 自动检测Conda环境路径 + if conda_env_path is None: + conda_env_path = env.get('conda_prefix', '~/anaconda3/envs/screen') + config = ResourceConfig( max_cores=actual_cores, cores_per_worker=cores_per_worker, + conda_env_path=conda_env_path, ) return config def run_local( - self, - tasks: List[Any], - worker_func: Callable, - desc: str = "Processing" + self, + tasks: List[Any], + worker_func: Callable, + desc: str = "Processing" ) -> List[Any]: """本地多进程执行""" num_workers = self.config.num_workers total = len(tasks) - print(f"\n{'=' * 60}") + print(f"\n{'='*60}") print(f"并行配置:") print(f" 总任务数: {total}") print(f" Worker数: {num_workers}") print(f" 每Worker核数: {self.config.cores_per_worker}") print(f" 总使用核数: {self.config.max_cores}") - print(f"{'=' * 60}\n") + print(f"{'='*60}\n") - # 初始化进度管理器 self.progress_manager = ProgressManager(total, desc) self.progress_manager.start() results = [] if num_workers == 1: - # 单进程模式 for task in tasks: try: result = worker_func(task) @@ -158,9 +155,7 @@ class ParallelScheduler: self.progress_manager.update(success=False) self.progress_manager.display() else: - # 多进程模式 with Pool(processes=num_workers) as pool: - # 使用imap_unordered获取更好的性能 for result in pool.imap_unordered(worker_func, tasks, chunksize=10): if result is not None: results.append(result) @@ -174,13 +169,21 @@ class ParallelScheduler: return results def generate_slurm_script( - self, - tasks_file: str, - worker_script: str, - output_dir: str, - job_name: str = "analysis" + self, + tasks_file: str, + worker_script: str, + output_dir: str, + job_name: str = "analysis" ) -> str: - """生成SLURM作业脚本""" + """生成SLURM作业脚本(修复版)""" + + # 获取当前工作目录的绝对路径 + submit_dir = os.getcwd() + + # 转换为绝对路径 + abs_tasks_file = os.path.abspath(tasks_file) + abs_worker_script = os.path.abspath(worker_script) + abs_output_dir = os.path.abspath(output_dir) script = f"""#!/bin/bash #SBATCH --job-name={job_name} @@ -190,24 +193,74 @@ class ParallelScheduler: #SBATCH --cpus-per-task={self.config.max_cores} #SBATCH --mem-per-cpu={self.config.memory_per_core} #SBATCH --time={self.config.time_limit} -#SBATCH --output={output_dir}/slurm_%j.log -#SBATCH --error={output_dir}/slurm_%j.err +#SBATCH --output={abs_output_dir}/slurm_%j.log +#SBATCH --error={abs_output_dir}/slurm_%j.err -# 环境设置 -source $(conda info --base)/etc/profile.d/conda.sh -conda activate screen +# ============================================ +# SLURM作业脚本 - 自动生成 +# ============================================ -# 设置Python路径 -cd $SLURM_SUBMIT_DIR -export PYTHONPATH=$(pwd):$PYTHONPATH +echo "===== 作业信息 =====" +echo "作业ID: $SLURM_JOB_ID" +echo "节点: $SLURM_NODELIST" +echo "CPU数: $SLURM_CPUS_PER_TASK" +echo "开始时间: $(date)" +echo "====================" -# 运行分析 -python {worker_script} \\ - --tasks-file {tasks_file} \\ - --output-dir {output_dir} \\ +# ============ 环境初始化 ============ +# 关键:确保bashrc被加载 +if [ -f ~/.bashrc ]; then + source ~/.bashrc +fi + +# 初始化Conda +if [ -f ~/anaconda3/etc/profile.d/conda.sh ]; then + source ~/anaconda3/etc/profile.d/conda.sh +elif [ -f /opt/anaconda3/etc/profile.d/conda.sh ]; then + source /opt/anaconda3/etc/profile.d/conda.sh +else + echo "错误: 找不到conda.sh" + exit 1 +fi + +# 激活环境 (使用完整路径) +conda activate {self.config.conda_env_path} + +# 验证环境 +echo "" +echo "===== 环境检查 =====" +echo "Conda环境: $CONDA_DEFAULT_ENV" +echo "Python路径: $(which python)" +echo "Python版本: $(python --version 2>&1)" +python -c "import pymatgen; print(f'pymatgen版本: {{pymatgen.__version__}}')" 2>/dev/null || echo "警告: pymatgen未安装" +echo "====================" +echo "" + +# 设置工作目录 +cd {submit_dir} +export PYTHONPATH={submit_dir}:$PYTHONPATH + +echo "工作目录: $(pwd)" +echo "PYTHONPATH: $PYTHONPATH" +echo "" + +# ============ 运行分析 ============ +echo "开始执行分析任务..." +python {abs_worker_script} \\ + --tasks-file {abs_tasks_file} \\ + --output-dir {abs_output_dir} \\ --num-workers {self.config.num_workers} -echo "Job completed at $(date)" +EXIT_CODE=$? + +# ============ 完成 ============ +echo "" +echo "===== 作业完成 =====" +echo "结束时间: $(date)" +echo "退出代码: $EXIT_CODE" +echo "====================" + +exit $EXIT_CODE """ return script @@ -223,6 +276,9 @@ echo "Job completed at $(date)" os.chmod(script_path, 0o755) + # 打印脚本内容用于调试 + print(f"\n生成的SLURM脚本保存到: {script_path}") + result = subprocess.run( ['sbatch', script_path], capture_output=True, text=True @@ -230,7 +286,6 @@ echo "Job completed at $(date)" if result.returncode == 0: job_id = result.stdout.strip().split()[-1] - print(f"作业已提交: {job_id}") return job_id else: - raise RuntimeError(f"提交失败: {result.stderr}") \ No newline at end of file + raise RuntimeError(f"SLURM提交失败:\nstdout: {result.stdout}\nstderr: {result.stderr}") \ No newline at end of file