import os
import shutil

from nep_auto.modules.base_module import BaseModule


class TrainModule(BaseModule):
    """Iteration step that assembles ``train.xyz`` and runs NEP training.

    The training set of iteration ``i`` is the training set of iteration
    ``i - 1`` (``<root>/iter_{i-1:03d}/03.train/train.xyz``) followed by the
    frames labelled in this iteration (``<iter_dir>/02.scf/NEP-dataset.xyz``).
    For the first iteration, an optional initial training set at
    ``<root>/00.data/train.xyz`` is used as the base instead.
    """

    def __init__(self, driver, iter_id):
        super().__init__(driver, iter_id)
        # Template sub-directory name (presumably consumed by
        # BaseModule.copy_template -- confirm against BaseModule).
        self.template_subdir = "03_train"
        self.work_dir = self.iter_dir / "03.train"

    def get_work_dir(self):
        """Return the working directory of this training step."""
        return self.work_dir

    def run(self):
        """Build ``train.xyz``, copy ``nep.in`` and run ``nep`` in the work dir.

        Returns:
            True when training produced ``nep.txt`` (see ``check_done``).

        Raises:
            FileNotFoundError: if this iteration's ``NEP-dataset.xyz`` is missing.
            RuntimeError: if training finished without producing ``nep.txt``.
        """
        self.logger.info(f"🧠 [Train] Starting Training Iter {self.iter_id}...")
        self.initialize()

        # 1. Prepare train.xyz:
        #    train.xyz (this iteration) = train.xyz (previous iteration)
        #                                 + 02.scf/NEP-dataset.xyz (this iteration)
        self._build_train_xyz()

        # 2. Prepare nep.in
        self.copy_template("nep.in")

        # 3. Run training
        self.logger.info(" -> Running NEP training...")
        self.runner.run("nep", cwd=self.work_dir)
        return self.check_done()

    def _base_train_xyz(self):
        """Return the path of the training set this iteration extends."""
        if self.iter_id == 1:
            # First iteration: optional initial data set provided by the user.
            return self.root / "00.data" / "train.xyz"
        return self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"

    def _build_train_xyz(self):
        """Write ``<work_dir>/train.xyz`` as base training set + new SCF data.

        Raises:
            FileNotFoundError: if this iteration's ``NEP-dataset.xyz`` is missing.
        """
        new_data = self.iter_dir / "02.scf" / "NEP-dataset.xyz"
        # Validate before opening the output so a failure never truncates or
        # leaves behind a partially written train.xyz.
        if not new_data.exists():
            raise FileNotFoundError(
                f"New training data (NEP-dataset.xyz) missing! ({new_data})"
            )

        base = self._base_train_xyz()
        current_train_xyz = self.work_dir / "train.xyz"
        with open(current_train_xyz, "wb") as outfile:
            # A. Previous iteration's training set (or the initial data set).
            if base.exists():
                self._append_file(base, outfile)
            elif self.iter_id == 1:
                self.logger.info(
                    f" -> No initial training set at {base}; using SCF data only."
                )
            else:
                # Losing the accumulated data set is almost certainly unintended.
                self.logger.warning(
                    f" -> Previous training set not found at {base}; "
                    "train.xyz will contain only this iteration's data."
                )
            # B. This iteration's newly labelled data.
            self._append_file(new_data, outfile)

    @staticmethod
    def _append_file(src, outfile):
        """Append the raw bytes of ``src`` to the binary stream ``outfile``.

        If ``src`` is non-empty and does not end with a newline, one is added
        so that concatenated XYZ files never merge two lines into one.
        """
        with open(src, "rb") as infile:
            shutil.copyfileobj(infile, outfile)
            if infile.tell() > 0:
                infile.seek(-1, os.SEEK_END)
                if infile.read(1) != b"\n":
                    outfile.write(b"\n")

    def check_done(self):
        """Confirm that training produced ``nep.txt`` in the work dir.

        Returns:
            True if ``nep.txt`` exists.

        Raises:
            RuntimeError: if ``nep.txt`` was not generated.
        """
        if (self.work_dir / "nep.txt").exists():
            self.logger.info("✅ Training finished.")
            return True
        raise RuntimeError("Training failed: nep.txt not generated")