nep框架重构 03.train
This commit is contained in:
110
src/workflow.py
110
src/workflow.py
@@ -442,3 +442,113 @@ class Workflow:
|
|||||||
res_file = os.path.join(step_dir, "NEPdataset-multiple_frames", "NEP-dataset.xyz")
|
res_file = os.path.join(step_dir, "NEPdataset-multiple_frames", "NEP-dataset.xyz")
|
||||||
if os.path.exists(res_file):
|
if os.path.exists(res_file):
|
||||||
self.new_data_chunk = res_file
|
self.new_data_chunk = res_file
|
||||||
|
# ==========================
|
||||||
|
# Step: 03.train (Training)
|
||||||
|
# ==========================
|
||||||
|
elif step_name == "03.train":
|
||||||
|
step_dir = os.path.join(iter_path, "03.train")
|
||||||
|
task_id_train = f"{iter_name}.03.train"
|
||||||
|
|
||||||
|
if not self.tracker.is_done(task_id_train):
|
||||||
|
self.logger.info("=== Step: 03.train (NEP) ===")
|
||||||
|
os.makedirs(step_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 准备 train.xyz (合并数据)
|
||||||
|
# 逻辑:Current Total Train = Previous Total Train + New Data Chunk
|
||||||
|
self.logger.info("Merging training data...")
|
||||||
|
|
||||||
|
train_xyz_path = os.path.join(step_dir, "train.xyz")
|
||||||
|
|
||||||
|
# 打开目标文件准备写入
|
||||||
|
with open(train_xyz_path, 'w') as outfile:
|
||||||
|
# A. 写入旧数据 (如果存在)
|
||||||
|
if os.path.exists(self.current_train_set):
|
||||||
|
self.logger.info(f"Appending previous data: {self.current_train_set}")
|
||||||
|
with open(self.current_train_set, 'r') as infile:
|
||||||
|
shutil.copyfileobj(infile, outfile)
|
||||||
|
|
||||||
|
# B. 写入新数据 (来自本轮 SCF)
|
||||||
|
# 尝试获取变量,如果变量丢失则尝试从路径推断
|
||||||
|
new_data = getattr(self, 'new_data_chunk', None)
|
||||||
|
if not new_data:
|
||||||
|
# 推断路径: iter_XX/02.scf/NEPdataset-multiple_frames/NEP-dataset.xyz
|
||||||
|
new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames",
|
||||||
|
"NEP-dataset.xyz")
|
||||||
|
|
||||||
|
if new_data and os.path.exists(new_data):
|
||||||
|
self.logger.info(f"Appending new data: {new_data}")
|
||||||
|
with open(new_data, 'r') as infile:
|
||||||
|
shutil.copyfileobj(infile, outfile)
|
||||||
|
else:
|
||||||
|
if iter_id == 0 and not os.path.exists(self.current_train_set):
|
||||||
|
self.logger.error("No training data available (neither previous nor new).")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.logger.warning("No new data found from SCF step. Training on old data only.")
|
||||||
|
|
||||||
|
# 更新全局变量指向最新的 train.xyz,供下一轮使用
|
||||||
|
self.current_train_set = train_xyz_path
|
||||||
|
|
||||||
|
# 2. 准备 nep.in
|
||||||
|
template_nep_in = os.path.join(self.template_dir, "03.train", "nep.in")
|
||||||
|
if os.path.exists(template_nep_in):
|
||||||
|
shutil.copy(template_nep_in, os.path.join(step_dir, "nep.in"))
|
||||||
|
else:
|
||||||
|
self.logger.error(f"nep.in template missing: {template_nep_in}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. 执行训练
|
||||||
|
executor_name = step_conf.get('executor', 'nep_local')
|
||||||
|
# 获取nep命令,比如 "nep" 或者 "/path/to/nep"
|
||||||
|
# 注意:nep 命令通常不需要参数,它会自动读取 nep.in
|
||||||
|
|
||||||
|
self.logger.info(f"Starting NEP training using {executor_name}...")
|
||||||
|
# 这里的 log 文件叫 train_exec.log
|
||||||
|
if not run_cmd_with_log("nep", step_dir,
|
||||||
|
"train_exec.log"): # 假设 cmd 是 nep,如果 machine.yaml 里有特殊定义请调整
|
||||||
|
self.logger.error("NEP training failed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否生成了 nep.txt
|
||||||
|
if os.path.exists(os.path.join(step_dir, "nep.txt")):
|
||||||
|
self.logger.info("Training finished. nep.txt generated.")
|
||||||
|
# 更新全局势函数路径,供下一轮 MD 使用
|
||||||
|
self.current_nep_pot = os.path.join(step_dir, "nep.txt")
|
||||||
|
else:
|
||||||
|
self.logger.error("nep.txt not found after training.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. 后处理:绘图与归档
|
||||||
|
self.logger.info("Generating plots (gpumdkit.sh -plt train)...")
|
||||||
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
cmd_plt = f"{kit_path} -plt train"
|
||||||
|
|
||||||
|
run_cmd_with_log(cmd_plt, step_dir, "plot.log")
|
||||||
|
|
||||||
|
# 检查并移动图片
|
||||||
|
# gpumdkit 通常生成 loss.png, energy.png, force.png 等,或者你说的 train.png
|
||||||
|
# 我们创建一个专门的 output 目录存放这一轮的成果
|
||||||
|
output_dir = os.path.join(self.workspace, "05.output", iter_name)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 移动 loss.png / train.png 到 output
|
||||||
|
# 假设生成的文件名包含 png
|
||||||
|
for file in os.listdir(step_dir):
|
||||||
|
if file.endswith(".png"):
|
||||||
|
src_png = os.path.join(step_dir, file)
|
||||||
|
dst_png = os.path.join(output_dir, file) # 保持原名
|
||||||
|
shutil.copy(src_png, dst_png)
|
||||||
|
self.logger.info(f"Archived plot: {file}")
|
||||||
|
|
||||||
|
# 特别处理:你提到的 train.png
|
||||||
|
if os.path.exists(os.path.join(step_dir, "train.png")):
|
||||||
|
# 如果你需要重命名或者确保它存在
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.tracker.mark_done(task_id_train)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.logger.info("Skipping Train (Already Done).")
|
||||||
|
# 恢复变量状态
|
||||||
|
self.current_nep_pot = os.path.join(step_dir, "nep.txt")
|
||||||
|
self.current_train_set = os.path.join(step_dir, "train.xyz")
|
||||||
|
|||||||
Reference in New Issue
Block a user