import os
import shutil
from pathlib import Path

from nep_auto.modules.base_module import BaseModule


class TrainModule(BaseModule):
    """Training stage of one iteration: build train.xyz, copy nep.in, run NEP.

    Work directory is ``<iter_dir>/03.train``; templates are looked up under
    the ``03_train`` template subdirectory.
    """

    def __init__(self, driver, iter_id):
        super().__init__(driver, iter_id)
        # Template subdirectory name; presumably consumed by
        # BaseModule.copy_template -- verify against the base class.
        self.template_subdir = "03_train"
        self.work_dir = self.iter_dir / "03.train"

    def get_work_dir(self):
        """Return the directory in which NEP training is executed."""
        return self.work_dir

    def run(self):
        """Run the training stage for this iteration.

        Steps:
          1. Merge the previous training set (initial data for iteration 1,
             otherwise the previous iteration's 03.train/train.xyz) with this
             iteration's SCF dataset into ``train.xyz``.
          2. Copy ``nep.in`` from the template directory.
          3. Invoke the ``nep`` command via ``self.runner`` (configured
             externally, presumably in machine.yaml) and verify the output.

        Raises:
            FileNotFoundError: if this iteration's NEP-dataset.xyz is missing.
            RuntimeError: if training did not produce nep.txt.
        """
        self.logger.info(f"🧠 [Train] Starting Training Iter {self.iter_id}...")
        self.initialize()

        # ----------------------------------------
        # 1. Prepare train.xyz (merge datasets)
        # ----------------------------------------
        current_train = self.work_dir / "train.xyz"
        sources = self._collect_sources()

        self.logger.info(f" -> Merging {len(sources)} datasets into train.xyz...")
        self._merge_xyz(sources, current_train)

        # ----------------------------------------
        # 2. Prepare nep.in
        # ----------------------------------------
        self.copy_template("nep.in")

        # ----------------------------------------
        # 3. Run training (the "nep" command resolved by the runner)
        # ----------------------------------------
        self.logger.info(" -> Running NEP training...")
        self.runner.run("nep", cwd=self.work_dir)
        self.check_done()

    def _collect_sources(self):
        """Return the ordered list of xyz files that make up train.xyz.

        A missing previous/initial dataset is tolerated (a warning is logged),
        but a missing SCF dataset for this iteration is fatal.

        Raises:
            FileNotFoundError: if this iteration's NEP-dataset.xyz is missing.
        """
        sources = []

        # Source 1: the previous training set (initial data on the first iteration).
        if self.iter_id == 1:
            prev_data = self.root / "00.data" / "train.xyz"
        else:
            prev_data = (
                self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"
            )
        if prev_data.exists():
            sources.append(prev_data)
        else:
            # Kept non-fatal, but silently dropping accumulated data would go
            # unnoticed, so make it visible.
            self.logger.warning(f" -> Previous training data not found: {prev_data}")

        # Source 2: the SCF data computed in this iteration.
        new_data = self.iter_dir / "02.scf" / "NEP-dataset.xyz"
        if not new_data.exists():
            raise FileNotFoundError("New training data (NEP-dataset.xyz) missing!")
        sources.append(new_data)
        return sources

    @staticmethod
    def _merge_xyz(sources, dest):
        """Concatenate *sources* byte-wise into *dest*.

        xyz data is line-oriented: if a source does not end with a newline,
        naive concatenation would glue its last line onto the next file's
        first line and corrupt the dataset, so a newline is inserted there.
        """
        with open(dest, "wb") as outfile:
            for src in sources:
                with open(src, "rb") as infile:
                    shutil.copyfileobj(infile, outfile)
                    # After the copy the position is at EOF, so tell() == size.
                    if infile.tell() > 0:
                        infile.seek(-1, os.SEEK_END)
                        if infile.read(1) != b"\n":
                            outfile.write(b"\n")

    def check_done(self):
        """Verify that training produced ``nep.txt``.

        Only the existence of nep.txt is checked; loss convergence (loss.out)
        or other outputs (e.g. virial.out) are not inspected here.

        Returns:
            True if nep.txt exists in the work directory.

        Raises:
            RuntimeError: if nep.txt was not generated.
        """
        if (self.work_dir / "nep.txt").exists():
            self.logger.info("✅ Training finished.")
            return True
        raise RuntimeError("Training failed: nep.txt not generated")