import re
import shutil
from pathlib import Path

from nep_auto.modules.base_module import BaseModule


class SelectModule(BaseModule):
    """Active-learning selection step of one iteration.

    Stages the MD trajectory and the NEP model into ``01.select`` and runs the
    external ``neptrain_select_structs.py`` script repeatedly, adjusting the
    distance threshold until the number of selected structures falls inside
    the configured ``[target_min, target_max]`` window.
    """

    # Amount added to / subtracted from the distance threshold per attempt.
    _THRESHOLD_STEP = 0.01
    # Decimal places kept on the threshold so repeated float arithmetic does
    # not leak values such as 0.019999999999999997 into logs and commands.
    _THRESHOLD_DECIMALS = 8

    def __init__(self, driver, iter_id):
        super().__init__(driver, iter_id)
        self.work_dir = self.iter_dir / "01.select"
        self.md_dir = self.iter_dir / "00.md" / "md"

    def get_work_dir(self):
        """Return the working directory of the selection step."""
        return self.work_dir

    def get_frame_count(self, xyz_file) -> int:
        """Count the frames in an ExtXYZ file.

        Frames are counted as occurrences of ``Lattice=``, which appears once
        per frame in the ExtXYZ comment line.

        Args:
            xyz_file: ``pathlib.Path`` of the file to inspect.

        Returns:
            The number of ``Lattice=`` occurrences, or 0 when the file does
            not exist or cannot be read.
        """
        if not xyz_file.exists():
            return 0
        try:
            # Stream line by line: trajectories can be large and the marker
            # never spans a line break. Undecodable bytes are replaced so a
            # stray byte cannot hide the (ASCII) marker count.
            with open(xyz_file, "r", encoding="utf-8", errors="replace") as f:
                return sum(line.count("Lattice=") for line in f)
        except OSError:
            # Best effort: an unreadable file is treated as "no frames".
            return 0

    @staticmethod
    def _lower_threshold(threshold, step):
        """Return a smaller, still strictly positive, threshold.

        A full ``step`` is subtracted when that keeps the value positive;
        otherwise the threshold is halved, because a zero or negative
        distance threshold is meaningless for the selection script.
        """
        lowered = threshold - step
        if lowered <= 0:
            lowered = threshold / 2
        return lowered

    def _stage_inputs(self):
        """Copy dump.xyz, train.xyz and nep.txt into the work directory."""
        src_dump = self.md_dir / "dump.xyz"
        # Reference training set handed to neptrain_select_structs.py.
        # NOTE(review): for iter > 1 this should presumably be the cumulative
        # train.xyz of the previous round; for now the initial data set is
        # used as the reference for every iteration.
        train_xyz_prev = self.root / "00.data" / "train.xyz"

        shutil.copy(src_dump, self.work_dir / "dump.xyz")

        dst_train = self.work_dir / "train.xyz"
        if not dst_train.exists():
            # Do not overwrite a train.xyz that was already placed here.
            if train_xyz_prev.exists():
                shutil.copy(train_xyz_prev, dst_train)
            else:
                self.logger.warning(
                    f"Reference train.xyz not found at {train_xyz_prev}; "
                    "the selection script may fail."
                )

        shutil.copy(self.md_dir / "nep.txt", self.work_dir / "nep.txt")

    def run(self):
        """Run the selection loop and verify that ``selected.xyz`` exists.

        Raises:
            RuntimeError: if no ``selected.xyz`` was produced (raised by
                ``check_done``).
        """
        self.logger.info(f"🔍 [Select] Starting Active Learning Selection Iter {self.iter_id}...")
        self.initialize()

        # Required inputs: dump.xyz, train.xyz, nep.txt
        self._stage_inputs()

        # Read parameters
        cfg = self.config_param['params']['select']
        target_min = cfg.get('target_min', 60)
        target_max = cfg.get('target_max', 120)
        threshold = cfg.get('init_threshold', 0.01)

        kit_root = self.driver.config_param['env']['gpumdkit_root']
        script = f"{kit_root}/Scripts/sample_structures/neptrain_select_structs.py"

        step = self._THRESHOLD_STEP
        max_attempts = 10
        for attempt in range(1, max_attempts + 1):
            self.logger.info(f" -> Attempt {attempt}: Threshold = {threshold}")

            # Command: python script dump.xyz train.xyz nep.txt [options]
            # NOTE: the upstream script is interactive; this assumes it has
            # been patched to accept `--distance <threshold>` and
            # `--auto_confirm` on the command line.
            cmd_args = f"{script} dump.xyz train.xyz nep.txt --distance {threshold} --auto_confirm"

            try:
                self.runner.run("python_script", cwd=self.work_dir, extra_args=cmd_args)
            except Exception as e:
                # Deliberately best effort: a failed attempt is judged by
                # whatever selected.xyz is present afterwards.
                self.logger.warning(f"Select script warning: {e}")

            # Inspect the result
            selected_file = self.work_dir / "selected.xyz"
            count = self.get_frame_count(selected_file)
            self.logger.info(f" -> Selected {count} structures.")

            if target_min <= count <= target_max:
                self.logger.info("✅ Selection criteria met!")
                break

            if count < target_min:
                threshold = round(
                    self._lower_threshold(threshold, step), self._THRESHOLD_DECIMALS
                )
                self.logger.info(f" -> Too few, lowering threshold to {threshold}...")
            else:
                threshold = round(threshold + step, self._THRESHOLD_DECIMALS)
                self.logger.info(f" -> Too many, raising threshold to {threshold}...")
        else:
            # Loop ran out of attempts without hitting the target window.
            self.logger.warning("⚠️ Max attempts reached in selection. Proceeding with current best.")

        self.check_done()

    def check_done(self):
        """Return True when ``selected.xyz`` exists in the work directory.

        Raises:
            RuntimeError: if ``selected.xyz`` is missing.
        """
        if (self.work_dir / "selected.xyz").exists():
            return True
        raise RuntimeError("Selection failed: selected.xyz not found")