diff --git a/config/machine.yaml b/config/machine.yaml new file mode 100644 index 0000000..4cd5174 --- /dev/null +++ b/config/machine.yaml @@ -0,0 +1,53 @@ +# config/machine.yaml + +# 当前使用的计算系统配置名 +current_system: "interactive_gpu" + +systems: + # --- 配置 1: 交互式 GPU 环境 (当前使用) --- + # 场景: 你已经用 srun/tmux 申请到了资源,直接运行命令即可 + interactive_gpu: + type: "local" # local 表示直接运行 subprocess,不提交 sbatch + + # 路径配置 + gpumdkit_root: "/cluster/home/koko125/tool/GPUMDkit" + + tools: + # 1. GPUMD 配置 + gpumd: + command: "gpumd" + # 运行前需要 source 的环境脚本 + env_setup: "" + gpu_id: 0 + + # 2. NEP 配置 (同上) + nep: + command: "nep" + env_setup: "" + gpu_id: 0 + + # 3. VASP (GPU 版) 配置 + vasp: + # 假设是 GPU 版本,可能不需要 mpirun 或者只需要少量核 + command: "mpirun -np 1 vasp_std" + env_setup: "" + # 即使是 local 模式,有时也需要指定并行度 + n_procs: 1 + + # --- 配置 2: VASP CPU 集群模式 (预留,未来使用) --- + # 场景: 需要生成 submit.slurm 并 sbatch 提交 + slurm_cpu_cluster: + type: "slurm" + + gpumdkit_root: "/cluster/home/koko125/tool/GPUMDkit" + + tools: + vasp: + command: "mpirun -np 4 vasp_std" + env_setup: "module load vasp/6.3-cpu" + + # Slurm 头部参数 + slurm_header: + partition: "cpu_long" + ntasks_per_node: 64 + time: "24:00:00" \ No newline at end of file diff --git a/config/param.yaml b/config/param.yaml index af9eefa..4fbad98 100644 --- a/config/param.yaml +++ b/config/param.yaml @@ -1,70 +1,46 @@ # config/param.yaml -# --- 1. 环境与路径配置 --- -env: - # 可执行文件绝对路径 - vasp_std: "mpirun -np 1 /cluster/home/koko125/vasp/bin_gpu/vasp_std" - gpumd: "/cluster/home/koko125/tool/GPUMD/src/gpumd" - nep: "/cluster/home/koko125/tool/GPUMD/src/nep" - - # GPUMDKit 脚本库根目录 - gpumdkit_root: "/cluster/home/koko125/tool/GPUMDkit" - - # 【修改点】HPC 作业提交配置 (用于填充 submit.slurm 模板) - # 这些变量会被自动替换到 .sh 脚本头部 -# slurm_config: -# partition: "v100" # 队列分区名 -# account: "def-user" # 账户名 (如果有) -# gpu_per_node: 1 # 每节点 GPU 数 -# ntasks_per_node: 32 # 每节点 CPU 核数 -# time_limit: "24:00:00" # 墙钟时间限制 - -# --- 2. 流程控制 --- -# 阶段代号定义 (对应 modules 下的 Python 文件) +# --- 1. 
流程控制 --- stages_def: - p: "preheat" # 00.md/preheat - m: "md" # 00.md/md - s: "select" # 01.select - d: "scf" # 02.scf - t: "train" # 03.train - pr: "predict" # 04.predict (新增:用于性质预测) - o: "output" # 05.output (始终默认执行:整理报告) + p: "preheat" + m: "md" + s: "select" + d: "scf" + t: "train" + pr: "predict" + o: "output" -# 自定义流程调度 -# 注意:'o' (output) 不需要显式写在这里,代码逻辑会强制每轮最后执行它 -schedule: - # 第1轮: 跑完训练,不做预测,看一眼结果 - 1: ["p", "m", "s", "d", "t"] - - # 第2轮: 跑完训练,加入预测步骤 (计算电导/扩散等) - 2: ["p", "m", "s", "d", "t", "pr"] - -# 默认流程 (如果没有定义轮次) +# 默认流程 default_workflow: ["p", "m", "s", "d", "t", "pr"] -# --- 3. 容错与通知 --- +# 自定义调度 +schedule: + 1: ["p", "m", "s", "d", "t", "o"] + +# --- 2. 容错与通知 --- control: - max_retries: 3 # 任务失败自动重启次数 - check_interval: 60 # 状态检查间隔 (秒) + max_retries: 3 + check_interval: 60 notification: enable_log: true log_file: "./logs/sys_runtime.log" - enable_hook: true hook_script: "python ./hooks/send_alert.py" alert_events: ["fail", "finish"] -# --- 4. 模块参数 --- +# --- 3. 各模块具体的物理/算法参数 --- params: preheat: temp: 300 steps: 10000 + # 这里不需要指定 gpumd 路径,只需要指定物理量 + select: target_min: 60 target_max: 120 init_threshold: 0.01 - predict: - # 预测阶段需要的参数,比如计算电导率的温度范围 - temperatures: [300, 400, 500] - script_path: "scripts/calc_conductivity.py" # 具体的计算脚本 \ No newline at end of file + + scf: + # 比如指定用 machine.yaml 里的哪个 tool 配置 + tool_key: "vasp" \ No newline at end of file diff --git a/nep_auto/driver.py b/nep_auto/driver.py index f6dfcb8..be455d4 100644 --- a/nep_auto/driver.py +++ b/nep_auto/driver.py @@ -10,10 +10,14 @@ class NEPDriver: self.logger = logging.getLogger("NEP_Auto") self.root = Path(".") - # 1. 加载配置 + # 1. 加载所有配置 self.config_sys = self._load_yaml("config/system.yaml") self.config_param = self._load_yaml("config/param.yaml") + # 【新增】加载 machine 配置 + self.config_machine = self._load_yaml("config/machine.yaml") + self.logger.info(f"项目名称: {self.config_sys.get('project_name')}") + self.logger.info(f"计算环境: {self.config_machine.get('current_system')}") # 2. 
import logging
import os
import shutil
import subprocess
import time
from pathlib import Path


# ---------------------------------------------------------------------------
# nep_auto/utils/runner.py
# ---------------------------------------------------------------------------

class CommandRunner:
    """Launch the external tools described by one machine.yaml system entry."""

    def __init__(self, machine_config):
        """
        :param machine_config: the mapping found in config/machine.yaml under
            ``systems`` for the system named by ``current_system``
        """
        self.config = machine_config
        self.logger = logging.getLogger("NEP_Auto")
        # Execution mode: "local" (direct subprocess) or "slurm" (reserved).
        self.mode = self.config.get("type", "local")

    def run(self, tool_name, cwd=".", wait=True, extra_args=""):
        """
        Run one configured tool.

        :param tool_name: key under ``tools`` in the machine config
            (e.g. 'gpumd', 'vasp')
        :param cwd: working directory the command is executed in
        :param wait: True blocks until the command exits; False returns
            immediately with the running process
        :param extra_args: text appended after the configured command
        :return: True when a blocking local command exits with code 0; the
            ``subprocess.Popen`` object when ``wait`` is False; False in
            slurm mode, which is not implemented yet
        :raises ValueError: unknown tool, tool without a ``command``, or an
            unsupported system ``type``
        :raises RuntimeError: a blocking local command exited non-zero
        """
        # ``tools:`` may be present but empty (None) in the YAML file.
        tools = self.config.get("tools") or {}
        tool_conf = tools.get(tool_name)
        if not tool_conf:
            self.logger.error(f"❌ 找不到工具配置: {tool_name}")
            raise ValueError(f"Tool {tool_name} not defined in machine.yaml")

        cmd = tool_conf.get("command")
        if not cmd:
            # Without this guard the shell would be asked to run "None".
            self.logger.error(f"❌ Tool {tool_name} has no 'command' configured")
            raise ValueError(f"Tool {tool_name} has no 'command' in machine.yaml")
        env_setup = tool_conf.get("env_setup", "")

        if self.mode == "local":
            return self._run_local(cmd, env_setup, cwd, wait, extra_args)

        if self.mode == "slurm":
            # Reserved: will generate and submit an sbatch script later.
            self.logger.warning("⚠️ Slurm mode not fully implemented yet.")
            return False

        # Fail loudly instead of silently returning None for a typo in `type`.
        raise ValueError(f"Unsupported system type: {self.mode}")

    def _run_local(self, cmd, env_setup, cwd, wait, extra_args):
        """Execute ``cmd`` through bash in ``cwd`` (local mode only)."""
        full_cmd = f"{cmd} {extra_args}".strip()

        # An environment setup snippet (e.g. `source ...`, `module load ...`)
        # is chained in front of the command with &&.
        if env_setup:
            full_cmd = f"{env_setup} && {full_cmd}"

        self.logger.info(f"⚙️ [Local] Executing: {full_cmd}")
        self.logger.info(f" 📂 Workdir: {cwd}")

        try:
            # bash is used explicitly so that `source` works in env_setup.
            process = subprocess.Popen(
                full_cmd,
                shell=True,
                cwd=cwd,
                executable="/bin/bash",
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except (OSError, ValueError) as e:
            # Only launch failures are logged here; a non-zero exit status is
            # reported once, below.
            self.logger.error(f"❌ Runner Error: {str(e)}")
            raise

        if not wait:
            # NOTE: stdout/stderr are pipes. The caller must drain them
            # (e.g. process.communicate()); otherwise a child that prints a
            # lot can block once the pipe buffer is full.
            return process

        _stdout, stderr = process.communicate()
        if process.returncode != 0:
            self.logger.error(f"❌ Execution failed (Code {process.returncode})")
            self.logger.error(f"Stderr: {stderr}")
            raise RuntimeError(f"Command failed: {full_cmd}")
        return True


# ---------------------------------------------------------------------------
# nep_auto/modules/base_module.py
# (in the package layout this module imports CommandRunner from
#  nep_auto.utils.runner)
# ---------------------------------------------------------------------------

class BaseModule:
    """Shared scaffolding for one workflow stage within one iteration."""

    def __init__(self, driver, iter_id):
        """
        :param driver: NEPDriver instance carrying all loaded configuration
        :param iter_id: current iteration number (int)
        """
        self.driver = driver
        self.config_sys = driver.config_sys
        self.config_param = driver.config_param
        machine = driver.config_machine
        # Pick the system entry selected by `current_system`.
        self.machine_config = machine['systems'][machine['current_system']]

        self.iter_id = iter_id
        self.iter_name = f"iter_{iter_id:03d}"
        self.logger = logging.getLogger("NEP_Auto")

        # Command runner bound to the selected system.
        self.runner = CommandRunner(self.machine_config)

        # Directory layout: <root>/workspace/iter_XXX/05.output
        self.root = Path(driver.root) / "workspace"
        self.iter_dir = self.root / self.iter_name
        self.output_dir = self.iter_dir / "05.output"  # shared output area

    def get_work_dir(self):
        """Return this module's working directory; subclasses must override."""
        raise NotImplementedError

    def initialize(self):
        """Create the working directory and the shared output directory."""
        work_dir = self.get_work_dir()
        if not work_dir.exists():
            work_dir.mkdir(parents=True, exist_ok=True)
            self.logger.debug(f"📁 Created dir: {work_dir}")

        # Make sure the shared output directory exists.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def run(self):
        """Main entry point of the stage; subclasses must override."""
        raise NotImplementedError

    def check_done(self):
        """Report whether the stage has finished; subclasses must override."""
        raise NotImplementedError

    # --- shared helpers ---

    def copy_template(self, template_name, target_name=None):
        """
        Copy a file from the template directory into the working directory.

        The source is ``template/<template_subdir>/<template_name>``, where
        ``template_subdir`` is an optional attribute set by the subclass
        ("common" when absent). A missing template only logs a warning.

        :param template_name: file name inside the template sub-directory
        :param target_name: destination file name; defaults to template_name
        """
        if target_name is None:
            target_name = template_name

        subdir = getattr(self, "template_subdir", "common")
        src = Path("template") / subdir / template_name
        dst = self.get_work_dir() / target_name

        if src.exists():
            shutil.copy(src, dst)
        else:
            self.logger.warning(f"⚠️ Template not found: {src}")

    def link_file(self, src_path, dst_name):
        """
        Create (or replace) a symlink in the working directory.

        :param src_path: link target; resolved to an absolute path
        :param dst_name: name of the link inside the working directory
        """
        src = Path(src_path).resolve()
        dst = self.get_work_dir() / dst_name
        # exists() follows symlinks and is False for a dangling link, which
        # would make os.symlink fail with FileExistsError; check is_symlink()
        # as well so stale links are replaced.
        if dst.is_symlink() or dst.exists():
            dst.unlink()
        os.symlink(src, dst)