105 lines
4.0 KiB
Python
105 lines
4.0 KiB
Python
import shutil
|
||
import re
|
||
from pathlib import Path
|
||
from nep_auto.modules.base_module import BaseModule
|
||
|
||
|
||
class SelectModule(BaseModule):
    """Active-learning structure selection step of one iteration.

    The module stages the MD trajectory (``dump.xyz``), the potential
    (``nep.txt``) and a reference training set (``train.xyz``) in the
    ``01.select`` directory of the iteration, then repeatedly runs
    ``neptrain_select_structs.py`` with an adjusted distance threshold until
    the number of frames in ``selected.xyz`` falls inside the configured
    ``[target_min, target_max]`` window or the attempts are exhausted.

    Attributes such as ``iter_dir``, ``root``, ``logger``, ``runner``,
    ``driver``, ``config_param`` and the ``initialize()`` method are expected
    to be provided by ``BaseModule`` (not visible in this file).
    """

    def __init__(self, driver, iter_id):
        """Set up the selection work directory and the MD source directory."""
        super().__init__(driver, iter_id)
        self.work_dir = self.iter_dir / "01.select"
        self.md_dir = self.iter_dir / "00.md" / "md"

    def get_work_dir(self):
        """Return the directory in which the selection step runs."""
        return self.work_dir

    def get_frame_count(self, xyz_file):
        """Count the frames of an extended-XYZ file.

        A frame is counted for every occurrence of ``Lattice=`` in the file.

        Args:
            xyz_file: ``pathlib.Path`` of the file to inspect.

        Returns:
            The number of ``Lattice=`` occurrences, or 0 when the file does
            not exist or cannot be read.
        """
        if not xyz_file.exists():
            return 0
        try:
            # Stream the file line by line: trajectories can be large and the
            # marker never spans a line break. Undecodable bytes are replaced
            # so that a stray byte cannot hide the ASCII marker.
            with open(xyz_file, "r", errors="replace") as handle:
                return sum(line.count("Lattice=") for line in handle)
        except OSError:
            # Best effort: an unreadable file is reported as "no frames".
            return 0

    def _stage_inputs(self):
        """Copy the files read by the selection script into ``work_dir``."""
        shutil.copy(self.md_dir / "dump.xyz", self.work_dir / "dump.xyz")
        shutil.copy(self.md_dir / "nep.txt", self.work_dir / "nep.txt")

        # train.xyz is the reference set handed to neptrain_select_structs.py.
        # NOTE(review): for iterations after the first, the accumulated
        # training set of the previous round should presumably be used; its
        # location is not defined in this module, so the initial data set
        # under 00.data serves as the reference for every iteration.
        train_dst = self.work_dir / "train.xyz"
        if train_dst.exists():
            return
        train_src = self.root / "00.data" / "train.xyz"
        if train_src.exists():
            shutil.copy(train_src, train_dst)
        else:
            self.logger.warning(
                f"Reference train.xyz not found at {train_src}; "
                "the selection script may fail."
            )

    @staticmethod
    def _lower_threshold(threshold, step):
        """Return a smaller threshold that is guaranteed to stay positive."""
        lowered = round(threshold - step, 10)
        # A zero or negative distance is meaningless; halve instead once a
        # full step no longer fits.
        return lowered if lowered > 0 else threshold / 2

    def run(self):
        """Run the selection loop and verify that ``selected.xyz`` exists.

        Reads ``target_min`` (default 60), ``target_max`` (default 120),
        ``init_threshold`` (default 0.01), ``threshold_step`` (default 0.01)
        and ``max_attempts`` (default 10) from
        ``config_param['params']['select']``.

        Raises:
            RuntimeError: if no ``selected.xyz`` exists after the loop.
        """
        self.logger.info(f"🔍 [Select] Starting Active Learning Selection Iter {self.iter_id}...")
        self.initialize()

        # Required inputs: dump.xyz, train.xyz, nep.txt
        self._stage_inputs()

        # Selection parameters
        cfg = self.config_param['params']['select']
        target_min = cfg.get('target_min', 60)
        target_max = cfg.get('target_max', 120)
        threshold = cfg.get('init_threshold', 0.01)
        step = cfg.get('threshold_step', 0.01)
        max_attempts = cfg.get('max_attempts', 10)

        kit_root = self.driver.config_param['env']['gpumdkit_root']
        script = f"{kit_root}/Scripts/sample_structures/neptrain_select_structs.py"
        selected_file = self.work_dir / "selected.xyz"

        for attempt in range(1, max_attempts + 1):
            self.logger.info(f" -> Attempt {attempt}: Threshold = {threshold}")

            # Command: python script dump.xyz train.xyz nep.txt [options]
            # The upstream script is interactive; it must have been patched
            # to accept --distance and --auto_confirm on the command line.
            cmd_args = f"{script} dump.xyz train.xyz nep.txt --distance {threshold} --auto_confirm"

            try:
                self.runner.run("python_script", cwd=self.work_dir, extra_args=cmd_args)
            except Exception as e:
                # Best effort: a failed run is logged and whatever
                # selected.xyz is present (possibly from an earlier attempt)
                # is still evaluated below.
                self.logger.warning(f"Select script warning: {e}")

            # Inspect the result
            count = self.get_frame_count(selected_file)
            self.logger.info(f" -> Selected {count} structures.")

            if target_min <= count <= target_max:
                self.logger.info("✅ Selection criteria met!")
                break

            if count < target_min:
                threshold = self._lower_threshold(threshold, step)
                self.logger.info(f" -> Too few, lowering threshold to {threshold}...")
            else:
                # Rounded so float noise does not leak into the command line.
                threshold = round(threshold + step, 10)
                self.logger.info(f" -> Too many, raising threshold to {threshold}...")
        else:
            # Loop finished without hitting the target window.
            self.logger.warning("⚠️ Max attempts reached in selection. Proceeding with current best.")

        self.check_done()

    def check_done(self):
        """Return True if ``selected.xyz`` exists in ``work_dir``.

        Raises:
            RuntimeError: if the file is missing.
        """
        if (self.work_dir / "selected.xyz").exists():
            return True
        raise RuntimeError("Selection failed: selected.xyz not found")