nep框架搭建

This commit is contained in:
2025-12-08 22:05:06 +08:00
parent 5057d18e98
commit cba2afb403
9 changed files with 498 additions and 5 deletions

View File

@@ -0,0 +1,105 @@
import shutil
import re
from pathlib import Path
from nep_auto.modules.base_module import BaseModule
class SelectModule(BaseModule):
def __init__(self, driver, iter_id):
super().__init__(driver, iter_id)
self.work_dir = self.iter_dir / "01.select"
self.md_dir = self.iter_dir / "00.md" / "md"
def get_work_dir(self):
return self.work_dir
def get_frame_count(self, xyz_file):
"""读取 xyz 文件帧数 (简单通过 grep 'Lattice' 计数,或用 ASE)"""
if not xyz_file.exists():
return 0
# 简单方法:读取文件统计 Lattice 出现的次数 (ExtXYZ 格式)
try:
with open(xyz_file, 'r') as f:
content = f.read()
return content.count("Lattice=")
except:
return 0
def run(self):
self.logger.info(f"🔍 [Select] Starting Active Learning Selection Iter {self.iter_id}...")
self.initialize()
# 准备数据
src_dump = self.md_dir / "dump.xyz"
train_xyz_prev = self.root / "00.data" / "train.xyz" # 或者是上一轮的 train
# 如果是 iter > 1train.xyz 应该是累积的。这里简化,先假设有一个参考的 train.xyz
# 必须文件dump.xyz, train.xyz, nep.txt
shutil.copy(src_dump, self.work_dir / "dump.xyz")
# 这里的 train.xyz 是给 neptrain_select_structs.py 用作参考的
if self.iter_id == 1:
# 第一轮可以用 data 里的初始文件,或者做一个空的
pass
else:
# 复制上一轮的 train.xyz
pass
# 复制 nep.txt
shutil.copy(self.md_dir / "nep.txt", self.work_dir / "nep.txt")
# 读取参数
cfg = self.config_param['params']['select']
target_min = cfg.get('target_min', 60)
target_max = cfg.get('target_max', 120)
threshold = cfg.get('init_threshold', 0.01)
kit_root = self.driver.config_param['env']['gpumdkit_root']
script = f"{kit_root}/Scripts/sample_structures/neptrain_select_structs.py"
# 循环筛选
max_attempts = 10
attempt = 0
while attempt < max_attempts:
self.logger.info(f" -> Attempt {attempt + 1}: Threshold = {threshold}")
# 构造命令: python script dump.xyz train.xyz nep.txt [options]
# 注意:如果你的脚本不支持命令行传参阈值,需要修改脚本或用 sed 修改
# 假设脚本已经被修改支持 --distance {threshold},或者我们用一种 hack 方式
# 既然原流程是交互式的,这里强烈建议你修改 neptrain_select_structs.py
# 让它支持命令行参数parser.add_argument('--distance', ...)
cmd_args = f"{script} dump.xyz train.xyz nep.txt --distance {threshold} --auto_confirm"
try:
self.runner.run("python_script", cwd=self.work_dir, extra_args=cmd_args)
except Exception as e:
self.logger.warning(f"Select script warning: {e}")
# 检查结果
selected_file = self.work_dir / "selected.xyz"
count = self.get_frame_count(selected_file)
self.logger.info(f" -> Selected {count} structures.")
if target_min <= count <= target_max:
self.logger.info("✅ Selection criteria met!")
break
elif count < target_min:
self.logger.info(" -> Too few, lowering threshold (-0.01)...")
threshold = threshold - 0.01
else:
self.logger.info(" -> Too many, raising threshold (+0.01)...")
threshold = threshold + 0.01
attempt += 1
if attempt >= max_attempts:
self.logger.warning("⚠️ Max attempts reached in selection. Proceeding with current best.")
self.check_done()
def check_done(self):
if (self.work_dir / "selected.xyz").exists():
return True
raise RuntimeError("Selection failed: selected.xyz not found")