From 323790ee6661a5d7105aaf96177ee3e9c42a7174 Mon Sep 17 00:00:00 2001 From: koko <1429659362@qq.com> Date: Tue, 9 Dec 2025 21:25:38 +0800 Subject: [PATCH] =?UTF-8?q?nep=E6=A1=86=E6=9E=B6=E9=87=8D=E6=9E=84=2003.tr?= =?UTF-8?q?ain?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/workflow.py | 112 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/src/workflow.py b/src/workflow.py index 66798c3..f5fc37d 100644 --- a/src/workflow.py +++ b/src/workflow.py @@ -441,4 +441,114 @@ class Workflow: # 这里简单推断一下 res_file = os.path.join(step_dir, "NEPdataset-multiple_frames", "NEP-dataset.xyz") if os.path.exists(res_file): - self.new_data_chunk = res_file \ No newline at end of file + self.new_data_chunk = res_file + # ========================== + # Step: 03.train (Training) + # ========================== + elif step_name == "03.train": + step_dir = os.path.join(iter_path, "03.train") + task_id_train = f"{iter_name}.03.train" + + if not self.tracker.is_done(task_id_train): + self.logger.info("=== Step: 03.train (NEP) ===") + os.makedirs(step_dir, exist_ok=True) + + # 1. 准备 train.xyz (合并数据) + # 逻辑:Current Total Train = Previous Total Train + New Data Chunk + self.logger.info("Merging training data...") + + train_xyz_path = os.path.join(step_dir, "train.xyz") + + # 打开目标文件准备写入 + with open(train_xyz_path, 'w') as outfile: + # A. 写入旧数据 (如果存在) + if os.path.exists(self.current_train_set): + self.logger.info(f"Appending previous data: {self.current_train_set}") + with open(self.current_train_set, 'r') as infile: + shutil.copyfileobj(infile, outfile) + + # B. 写入新数据 (来自本轮 SCF) + # 尝试获取变量,如果变量丢失则尝试从路径推断 + new_data = getattr(self, 'new_data_chunk', None) + if not new_data: + # 推断路径: iter_XX/02.scf/NEPdataset-multiple_frames/NEP-dataset.xyz + new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames", + "NEP-dataset.xyz") + + if new_data and os.path.exists(new_data): + self.logger.info(f"Appending new data: {new_data}") + with open(new_data, 'r') as infile: + shutil.copyfileobj(infile, outfile) + else: + if iter_id == 0 and not os.path.exists(self.current_train_set): + self.logger.error("No training data available (neither previous nor new).") + return + else: + self.logger.warning("No new data found from SCF step. Training on old data only.") + + # 更新全局变量指向最新的 train.xyz,供下一轮使用 + self.current_train_set = train_xyz_path + + # 2. 准备 nep.in + template_nep_in = os.path.join(self.template_dir, "03.train", "nep.in") + if os.path.exists(template_nep_in): + shutil.copy(template_nep_in, os.path.join(step_dir, "nep.in")) + else: + self.logger.error(f"nep.in template missing: {template_nep_in}") + return + + # 3. 执行训练 + executor_name = step_conf.get('executor', 'nep_local') + # 获取nep命令,比如 "nep" 或者 "/path/to/nep" + # 注意:nep 命令通常不需要参数,它会自动读取 nep.in + + self.logger.info(f"Starting NEP training using {executor_name}...") + # 这里的 log 文件叫 train_exec.log + if not run_cmd_with_log("nep", step_dir, + "train_exec.log"): # 假设 cmd 是 nep,如果 machine.yaml 里有特殊定义请调整 + self.logger.error("NEP training failed.") + return + + # 检查是否生成了 nep.txt + if os.path.exists(os.path.join(step_dir, "nep.txt")): + self.logger.info("Training finished. nep.txt generated.") + # 更新全局势函数路径,供下一轮 MD 使用 + self.current_nep_pot = os.path.join(step_dir, "nep.txt") + else: + self.logger.error("nep.txt not found after training.") + return + + # 4. 后处理:绘图与归档 + self.logger.info("Generating plots (gpumdkit.sh -plt train)...") + kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh') + cmd_plt = f"{kit_path} -plt train" + + run_cmd_with_log(cmd_plt, step_dir, "plot.log") + + # 检查并移动图片 + # gpumdkit 通常生成 loss.png, energy.png, force.png 等,或者你说的 train.png + # 我们创建一个专门的 output 目录存放这一轮的成果 + output_dir = os.path.join(self.workspace, "05.output", iter_name) + os.makedirs(output_dir, exist_ok=True) + + # 移动 loss.png / train.png 到 output + # 假设生成的文件名包含 png + for file in os.listdir(step_dir): + if file.endswith(".png"): + src_png = os.path.join(step_dir, file) + dst_png = os.path.join(output_dir, file) # 保持原名 + shutil.copy(src_png, dst_png) + self.logger.info(f"Archived plot: {file}") + + # 特别处理:你提到的 train.png + if os.path.exists(os.path.join(step_dir, "train.png")): + # 如果你需要重命名或者确保它存在 + pass + + self.tracker.mark_done(task_id_train) + + else: + self.logger.info("Skipping Train (Already Done).") + # 恢复变量状态 + self.current_nep_pot = os.path.join(step_dir, "nep.txt") + self.current_train_set = os.path.join(step_dir, "train.xyz")