NEP framework refactor: 03.train

2025-12-09 21:25:38 +08:00
parent f19d8ac4f0
commit 323790ee66


@@ -441,4 +441,114 @@ class Workflow:
                # Simple inference of the result path here
                res_file = os.path.join(step_dir, "NEPdataset-multiple_frames", "NEP-dataset.xyz")
                if os.path.exists(res_file):
                    self.new_data_chunk = res_file
        # ==========================
        # Step: 03.train (Training)
        # ==========================
        elif step_name == "03.train":
            step_dir = os.path.join(iter_path, "03.train")
            task_id_train = f"{iter_name}.03.train"
            if not self.tracker.is_done(task_id_train):
                self.logger.info("=== Step: 03.train (NEP) ===")
                os.makedirs(step_dir, exist_ok=True)
                # 1. Prepare train.xyz (merge the data)
                # Logic: current total train = previous total train + new data chunk
                self.logger.info("Merging training data...")
                train_xyz_path = os.path.join(step_dir, "train.xyz")
                # Open the target file for writing
                with open(train_xyz_path, 'w') as outfile:
                    # A. Write the previous data (if it exists)
                    if os.path.exists(self.current_train_set):
                        self.logger.info(f"Appending previous data: {self.current_train_set}")
                        with open(self.current_train_set, 'r') as infile:
                            shutil.copyfileobj(infile, outfile)
                    # B. Write the new data (from this iteration's SCF)
                    # Try the instance attribute first; if it is missing, infer the path
                    new_data = getattr(self, 'new_data_chunk', None)
                    if not new_data:
                        # Inferred path: iter_XX/02.scf/NEPdataset-multiple_frames/NEP-dataset.xyz
                        new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames",
                                                "NEP-dataset.xyz")
                    if new_data and os.path.exists(new_data):
                        self.logger.info(f"Appending new data: {new_data}")
                        with open(new_data, 'r') as infile:
                            shutil.copyfileobj(infile, outfile)
                    else:
                        if iter_id == 0 and not os.path.exists(self.current_train_set):
                            self.logger.error("No training data available (neither previous nor new).")
                            return
                        else:
                            self.logger.warning("No new data found from SCF step. Training on old data only.")
                # Update the global state to point at the latest train.xyz for the next iteration
                self.current_train_set = train_xyz_path
                # 2. Prepare nep.in
                template_nep_in = os.path.join(self.template_dir, "03.train", "nep.in")
                if os.path.exists(template_nep_in):
                    shutil.copy(template_nep_in, os.path.join(step_dir, "nep.in"))
                else:
                    self.logger.error(f"nep.in template missing: {template_nep_in}")
                    return
                # 3. Run the training
                executor_name = step_conf.get('executor', 'nep_local')
                # Get the nep command, e.g. "nep" or "/path/to/nep"
                # Note: the nep command usually takes no arguments; it reads nep.in automatically
                self.logger.info(f"Starting NEP training using {executor_name}...")
                # The training output is logged to train_exec.log
                # Assumes the command is "nep"; adjust it if machine.yaml defines something else
                if not run_cmd_with_log("nep", step_dir, "train_exec.log"):
                    self.logger.error("NEP training failed.")
                    return
                # Check whether nep.txt was generated
                if os.path.exists(os.path.join(step_dir, "nep.txt")):
                    self.logger.info("Training finished. nep.txt generated.")
                    # Update the global potential path used by the next MD iteration
                    self.current_nep_pot = os.path.join(step_dir, "nep.txt")
                else:
                    self.logger.error("nep.txt not found after training.")
                    return
                # 4. Post-processing: plotting and archiving
                self.logger.info("Generating plots (gpumdkit.sh -plt train)...")
                kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
                cmd_plt = f"{kit_path} -plt train"
                run_cmd_with_log(cmd_plt, step_dir, "plot.log")
                # Collect the plot images
                # gpumdkit usually generates loss.png, energy.png, force.png, etc., or the train.png mentioned earlier
                # Create a dedicated output directory for this iteration's results
                output_dir = os.path.join(self.workspace, "05.output", iter_name)
                os.makedirs(output_dir, exist_ok=True)
                # Copy loss.png / train.png into the output directory
                # Assumes the generated plots carry a .png extension
                for file in os.listdir(step_dir):
                    if file.endswith(".png"):
                        src_png = os.path.join(step_dir, file)
                        dst_png = os.path.join(output_dir, file)  # keep the original filename
                        shutil.copy(src_png, dst_png)
                        self.logger.info(f"Archived plot: {file}")
                # Special handling: the train.png mentioned earlier
                if os.path.exists(os.path.join(step_dir, "train.png")):
                    # Rename it here if needed, or just confirm that it exists
                    pass
                self.tracker.mark_done(task_id_train)
            else:
                self.logger.info("Skipping Train (Already Done).")
                # Restore the state variables
                self.current_nep_pot = os.path.join(step_dir, "nep.txt")
                self.current_train_set = os.path.join(step_dir, "train.xyz")
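
The step above relies on a run_cmd_with_log helper that is defined elsewhere in the workflow and is not part of this diff. A minimal sketch of what such a helper could look like, assuming the (cmd, workdir, log_name) signature and the boolean success return implied by the calls above; the actual implementation in the repository may differ:

import os
import subprocess

def run_cmd_with_log(cmd: str, workdir: str, log_name: str) -> bool:
    """Hypothetical helper: run a shell command in workdir, redirect stdout and
    stderr to log_name inside that directory, and return True on a zero exit code."""
    log_path = os.path.join(workdir, log_name)
    with open(log_path, "w") as log_file:
        proc = subprocess.run(cmd, shell=True, cwd=workdir,
                              stdout=log_file, stderr=subprocess.STDOUT)
    return proc.returncode == 0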