Files
NEP-auto/nep_auto/modules/m2_select.py
2025-12-08 22:05:06 +08:00

105 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import shutil
import re
from pathlib import Path
from nep_auto.modules.base_module import BaseModule
class SelectModule(BaseModule):
    """Active-learning selection step of one iteration.

    Stages the MD trajectory (``dump.xyz``), a reference training set
    (``train.xyz``) and the current potential (``nep.txt``) into the
    ``01.select`` work directory, then repeatedly runs the external
    ``neptrain_select_structs.py`` script, nudging the distance threshold
    until the number of selected structures falls inside
    ``[target_min, target_max]`` or the attempt budget is exhausted.
    """

    # Amount the distance threshold is moved by between two attempts.
    _THRESHOLD_STEP = 0.01
    # Maximum number of times the selection script is run.
    _MAX_ATTEMPTS = 10

    def __init__(self, driver, iter_id):
        """Set up the work directory and the MD directory for this iteration."""
        super().__init__(driver, iter_id)
        self.work_dir = self.iter_dir / "01.select"
        self.md_dir = self.iter_dir / "00.md" / "md"

    def get_work_dir(self):
        """Return the directory in which the selection step runs."""
        return self.work_dir

    def get_frame_count(self, xyz_file):
        """Return the number of frames in an extended-XYZ file.

        Frames are counted as occurrences of ``Lattice=``. A missing or
        unreadable file counts as zero frames.
        """
        if not xyz_file.exists():
            return 0
        try:
            # Stream line by line: trajectories can be large and the marker
            # never spans a line break, so the total is the same as counting
            # over the whole content at once.
            with open(xyz_file, 'r') as f:
                return sum(line.count("Lattice=") for line in f)
        except (OSError, UnicodeDecodeError):
            # Only I/O and decoding problems mean "no usable frames";
            # anything else (e.g. KeyboardInterrupt) must propagate.
            return 0

    def _stage_inputs(self):
        """Copy dump.xyz, train.xyz and nep.txt into the work directory."""
        shutil.copy(self.md_dir / "dump.xyz", self.work_dir / "dump.xyz")

        # train.xyz is the reference set handed to the selection script.
        # Simplification kept from the original design: a single reference
        # file under 00.data is assumed.
        # TODO: for iter > 1 this should presumably be the accumulated
        # training set of the previous iteration - confirm and wire it in.
        train_src = self.root / "00.data" / "train.xyz"
        train_dst = self.work_dir / "train.xyz"
        if not train_dst.exists():
            if train_src.exists():
                shutil.copy(train_src, train_dst)
            else:
                self.logger.warning(
                    f"Reference train.xyz not found at {train_src}; "
                    "the selection script may fail."
                )

        shutil.copy(self.md_dir / "nep.txt", self.work_dir / "nep.txt")

    def _lowered_threshold(self, threshold):
        """Return a smaller threshold that is still strictly positive."""
        # round() strips the float noise that repeated +/- 0.01 accumulates.
        lowered = round(threshold - self._THRESHOLD_STEP, 10)
        if lowered > 0:
            self.logger.info(" -> Too few, lowering threshold (-0.01)...")
            return lowered
        # A fixed step would hit zero or go negative, which is meaningless
        # as a distance; halve instead so the search can keep shrinking.
        self.logger.info(" -> Too few, halving threshold to keep it positive...")
        return threshold / 2

    def run(self):
        """Run the selection loop and verify that ``selected.xyz`` exists.

        Raises:
            RuntimeError: if no ``selected.xyz`` was produced (via
                ``check_done``).
        """
        self.logger.info(f"🔍 [Select] Starting Active Learning Selection Iter {self.iter_id}...")
        self.initialize()

        # Required inputs in the work directory: dump.xyz, train.xyz, nep.txt
        self._stage_inputs()

        # Read parameters.
        # NOTE(review): this reads self.config_param while the toolkit root
        # below is read from self.driver.config_param - confirm both refer
        # to the same configuration.
        cfg = self.config_param['params']['select']
        target_min = cfg.get('target_min', 60)
        target_max = cfg.get('target_max', 120)
        threshold = cfg.get('init_threshold', 0.01)
        kit_root = self.driver.config_param['env']['gpumdkit_root']
        script = f"{kit_root}/Scripts/sample_structures/neptrain_select_structs.py"

        selected_file = self.work_dir / "selected.xyz"
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            self.logger.info(f" -> Attempt {attempt}: Threshold = {threshold}")

            # Command: python script dump.xyz train.xyz nep.txt [options]
            # The upstream workflow is interactive; this call assumes the
            # script has been adapted to accept --distance and --auto_confirm
            # on the command line (e.g. via argparse).
            cmd_args = f"{script} dump.xyz train.xyz nep.txt --distance {threshold} --auto_confirm"
            try:
                self.runner.run("python_script", cwd=self.work_dir, extra_args=cmd_args)
            except Exception as e:
                # Best effort: a failed attempt is logged and judged by
                # whatever selected.xyz is present afterwards.
                self.logger.warning(f"Select script warning: {e}")

            # Inspect the result.
            count = self.get_frame_count(selected_file)
            self.logger.info(f" -> Selected {count} structures.")

            if target_min <= count <= target_max:
                self.logger.info("✅ Selection criteria met!")
                break

            if count < target_min:
                threshold = self._lowered_threshold(threshold)
            else:
                self.logger.info(" -> Too many, raising threshold (+0.01)...")
                threshold = round(threshold + self._THRESHOLD_STEP, 10)
        else:
            # Loop finished without a break: the budget is exhausted.
            self.logger.warning("⚠️ Max attempts reached in selection. Proceeding with current best.")

        self.check_done()

    def check_done(self):
        """Return True if ``selected.xyz`` exists in the work directory.

        Raises:
            RuntimeError: if the file is missing.
        """
        if (self.work_dir / "selected.xyz").exists():
            return True
        raise RuntimeError("Selection failed: selected.xyz not found")