Files
NEP-auto/nep_auto/modules/m4_train.py
2025-12-08 22:34:02 +08:00

69 lines
2.5 KiB
Python

import shutil
from pathlib import Path
from nep_auto.modules.base_module import BaseModule
class TrainModule(BaseModule):
    """Iteration step 03: merge training data and run NEP training.

    Builds ``<iter_dir>/03.train/train.xyz`` from the previous training set
    (``<root>/00.data/train.xyz`` on iteration 1, otherwise the previous
    iteration's ``03.train/train.xyz``) followed by this iteration's new data
    ``<iter_dir>/02.scf/NEP-dataset.xyz``, copies the ``nep.in`` template,
    runs ``nep`` and checks that ``nep.txt`` was produced.
    """

    def __init__(self, driver, iter_id):
        super().__init__(driver, iter_id)
        # Template sub-directory name consumed by BaseModule (presumably by
        # copy_template to locate nep.in -- confirm in BaseModule).
        self.template_subdir = "03_train"
        self.work_dir = self.iter_dir / "03.train"

    def get_work_dir(self):
        """Return this step's working directory (``<iter_dir>/03.train``)."""
        return self.work_dir

    def run(self):
        """Merge datasets, prepare nep.in, run NEP training and verify output.

        Raises:
            FileNotFoundError: If this iteration's ``02.scf/NEP-dataset.xyz``
                is missing.
            RuntimeError: If training finishes without producing ``nep.txt``.
        """
        self.logger.info(f"🧠 [Train] Starting Training Iter {self.iter_id}...")
        self.initialize()

        # 1. Build train.xyz = previous training set + new SCF data.
        current_train = self.work_dir / "train.xyz"
        sources = self._collect_sources()
        self.logger.info(f" -> Merging {len(sources)} datasets into train.xyz...")
        self._concat_xyz(sources, current_train)

        # 2. Prepare nep.in from the template.
        self.copy_template("nep.in")

        # 3. Run training (the "nep" command as configured in machine.yaml).
        self.logger.info(" -> Running NEP training...")
        self.runner.run("nep", cwd=self.work_dir)
        self.check_done()

    def _collect_sources(self):
        """Return the ordered list of .xyz files to merge into train.xyz.

        The previous training set comes first (initial data on iteration 1),
        followed by this iteration's new SCF dataset.

        Raises:
            FileNotFoundError: If the new SCF dataset is missing.
        """
        sources = []

        if self.iter_id == 1:
            prev_data = self.root / "00.data" / "train.xyz"
        else:
            prev_data = (
                self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"
            )
        if prev_data.exists():
            sources.append(prev_data)
        else:
            # Not fatal (training can proceed on the new data alone), but it
            # means previously accumulated data is dropped -- make it visible.
            self.logger.warning(
                f" -> Previous training data not found: {prev_data}; "
                "training on new data only."
            )

        new_data = self.iter_dir / "02.scf" / "NEP-dataset.xyz"
        if not new_data.exists():
            raise FileNotFoundError("New training data (NEP-dataset.xyz) missing!")
        sources.append(new_data)
        return sources

    @staticmethod
    def _concat_xyz(sources, dest):
        """Concatenate ``sources`` byte-for-byte into ``dest`` (overwriting it).

        After any non-empty source that does not end with a newline, a newline
        is written so that the last line of one file never fuses with the
        first line of the next one.
        """
        with open(dest, "wb") as outfile:
            for src in sources:
                with open(src, "rb") as infile:
                    shutil.copyfileobj(infile, outfile)
                    # After the copy the read position is at EOF, i.e. the size.
                    size = infile.tell()
                    if size:
                        infile.seek(size - 1)
                        if infile.read(1) != b"\n":
                            outfile.write(b"\n")

    def check_done(self):
        """Verify that training produced ``nep.txt`` in the work directory.

        Returns:
            True if ``nep.txt`` exists.

        Raises:
            RuntimeError: If ``nep.txt`` was not generated.
        """
        # Only the presence of nep.txt is checked; loss.out convergence and
        # other output files are not inspected here.
        if not (self.work_dir / "nep.txt").exists():
            raise RuntimeError("Training failed: nep.txt not generated")
        self.logger.info("✅ Training finished.")
        return True