nep框架搭建
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from nep_auto.modules.base_module import BaseModule
|
||||
|
||||
|
||||
@@ -15,42 +16,53 @@ class TrainModule(BaseModule):
|
||||
self.logger.info(f"🧠 [Train] Starting Training Iter {self.iter_id}...")
|
||||
self.initialize()
|
||||
|
||||
# 1. 准备 train.xyz
|
||||
# 逻辑:当前 train.xyz = 上一轮 train.xyz + 本轮 scf/NEP-dataset.xyz
|
||||
current_train_xyz = self.work_dir / "train.xyz"
|
||||
# ----------------------------------------
|
||||
# 1. 准备 train.xyz (合并)
|
||||
# ----------------------------------------
|
||||
# 目标文件
|
||||
current_train = self.work_dir / "train.xyz"
|
||||
|
||||
# 打开输出文件
|
||||
with open(current_train_xyz, 'wb') as outfile:
|
||||
# A. 写入上一轮数据 (或初始数据)
|
||||
if self.iter_id == 1:
|
||||
# 第一轮,看是否有初始训练集,如果没有则只用本轮的 SCF 数据
|
||||
# 这里假设 iter_000 是个虚拟的,或者直接去 00.data 里找
|
||||
init_data = self.root / "00.data" / "train.xyz" # 预留位置
|
||||
pass
|
||||
else:
|
||||
prev_train = self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"
|
||||
if prev_train.exists():
|
||||
with open(prev_train, 'rb') as infile:
|
||||
shutil.copyfileobj(infile, outfile)
|
||||
# 来源 1: 上一轮的 train.xyz (如果是第一轮,找初始数据)
|
||||
sources = []
|
||||
if self.iter_id == 1:
|
||||
init_data = self.root / "00.data" / "train.xyz"
|
||||
if init_data.exists():
|
||||
sources.append(init_data)
|
||||
else:
|
||||
prev_train = self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"
|
||||
if prev_train.exists():
|
||||
sources.append(prev_train)
|
||||
|
||||
# B. 写入本轮新数据
|
||||
new_data = self.iter_dir / "02.scf" / "NEP-dataset.xyz"
|
||||
if new_data.exists():
|
||||
with open(new_data, 'rb') as infile:
|
||||
# 来源 2: 本轮新算的 SCF 数据
|
||||
new_data = self.iter_dir / "02.scf" / "NEP-dataset.xyz"
|
||||
if new_data.exists():
|
||||
sources.append(new_data)
|
||||
else:
|
||||
raise FileNotFoundError("New training data (NEP-dataset.xyz) missing!")
|
||||
|
||||
# 执行合并
|
||||
self.logger.info(f" -> Merging {len(sources)} datasets into train.xyz...")
|
||||
with open(current_train, 'wb') as outfile:
|
||||
for src in sources:
|
||||
with open(src, 'rb') as infile:
|
||||
shutil.copyfileobj(infile, outfile)
|
||||
else:
|
||||
raise FileNotFoundError("New training data (NEP-dataset.xyz) missing!")
|
||||
|
||||
# ----------------------------------------
|
||||
# 2. 准备 nep.in
|
||||
# ----------------------------------------
|
||||
self.copy_template("nep.in")
|
||||
|
||||
# 3. 运行训练
|
||||
# ----------------------------------------
|
||||
# 3. 运行训练 (调用 machine.yaml 里的 nep)
|
||||
# ----------------------------------------
|
||||
self.logger.info(" -> Running NEP training...")
|
||||
self.runner.run("nep", cwd=self.work_dir)
|
||||
|
||||
self.check_done()
|
||||
|
||||
def check_done(self):
|
||||
# 检查是否生成了 nep.txt
|
||||
# 通常还会检查 loss.out 是否收敛,或者生成了 virials.out 等
|
||||
if (self.work_dir / "nep.txt").exists():
|
||||
self.logger.info("✅ Training finished.")
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user