148 lines
5.9 KiB
Python
148 lines
5.9 KiB
Python
import shutil
|
||
import subprocess
|
||
from pathlib import Path
|
||
from nep_auto.modules.base_module import BaseModule
|
||
|
||
|
||
class SelectModule(BaseModule):
    """Active-learning selection step driven by ``gpumdkit``.

    The step copies the MD dump, the potential file and a reference training
    set into ``<iter_dir>/01.select`` and then runs gpumdkit repeatedly,
    adjusting the selection threshold until the number of selected frames
    falls inside the configured ``[target_min, target_max]`` window (or the
    attempt budget is exhausted).

    NOTE(review): ``self.iter_dir``, ``self.root``, ``self.logger``,
    ``self.config_param``, ``self.machine_config`` and ``self.initialize()``
    are used here but defined elsewhere (presumably in ``BaseModule``) --
    confirm against the base class.
    """

    def __init__(self, driver, iter_id):
        """Derive the selection work directory and the MD result directory."""
        super().__init__(driver, iter_id)
        self.work_dir = self.iter_dir / "01.select"
        self.md_dir = self.iter_dir / "00.md" / "md"

    def get_work_dir(self):
        """Return the directory in which the selection step runs."""
        return self.work_dir

    def get_frame_count(self, xyz_file):
        """Count the frames of an xyz file by counting its 'Lattice' lines.

        Args:
            xyz_file: Path-like object (must provide ``exists()``) pointing at
                the xyz file to inspect.

        Returns:
            The number of lines containing ``Lattice`` as reported by
            ``grep -c``; ``0`` if the file does not exist, grep cannot be
            executed, or its output is not an integer.
        """
        if not xyz_file.exists():
            return 0
        try:
            # grep -c is used instead of reading the file in Python so that
            # large trajectories are never loaded into memory.  The command is
            # passed as an argv list (no shell) so paths containing spaces or
            # shell metacharacters are handled safely.
            result = subprocess.run(
                ["grep", "-c", "-e", "Lattice", "--", str(xyz_file)],
                stdout=subprocess.PIPE,
                text=True,
            )
            return int(result.stdout.strip())
        except (OSError, ValueError):
            # OSError: grep could not be started; ValueError: empty or
            # non-numeric output.  Both are treated as "no frames".
            return 0

    def _stage_inputs(self):
        """Copy dump.xyz, nep.txt and a reference train.xyz into the work dir.

        Raises:
            FileNotFoundError: If a required source file is missing.
        """
        # A. Data to be screened (taken from the MD results).
        src_dump = self.md_dir / "dump.xyz"
        if not src_dump.exists():
            raise FileNotFoundError(f"MD dump missing: {src_dump}")
        shutil.copy(src_dump, self.work_dir / "dump.xyz")

        # B. Potential file (taken from the MD results).
        src_nep = self.md_dir / "nep.txt"
        if not src_nep.exists():
            raise FileNotFoundError(f"MD nep.txt missing: {src_nep}")
        shutil.copy(src_nep, self.work_dir / "nep.txt")

        # C. Historical training set (used for comparison).  gpumdkit needs
        # this file to exist, so the first iteration must provide some
        # initial train.xyz.
        target_train_xyz = self.work_dir / "train.xyz"

        if self.iter_id == 1:
            # Prefer the initial training set from the data directory.
            init_train = self.root / "00.data" / "train.xyz"
            if init_train.exists():
                shutil.copy(init_train, target_train_xyz)
                return

            # Nothing available: use the initial structure as a placeholder
            # so that the selection script does not fail.
            self.logger.warning("No initial train.xyz found, using model.xyz as placeholder.")
            placeholder = self.md_dir / "model.xyz"
            if not placeholder.exists():
                raise FileNotFoundError(f"Placeholder model.xyz missing: {placeholder}")
            shutil.copy(placeholder, target_train_xyz)
            return

        # Later iterations: use the training set accumulated by the previous one.
        prev_train = self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"
        if not prev_train.exists():
            raise FileNotFoundError(f"Previous train.xyz missing: {prev_train}")
        shutil.copy(prev_train, target_train_xyz)

    def _run_gpumdkit(self, gpumdkit_cmd, threshold):
        """Run gpumdkit once in the work dir, feeding its menu answers via stdin.

        Args:
            gpumdkit_cmd: Shell command that launches gpumdkit
                (e.g. "gpumdkit.sh"), taken from the machine configuration.
            threshold: Selection threshold written as the last menu answer.

        Raises:
            RuntimeError: If gpumdkit exits with a non-zero status.
        """
        # Menu answers: 203 -> file names -> 1 (distance mode) -> threshold.
        # They are passed through ``input=`` straight to stdin, so no
        # ``echo ... |`` pipeline is needed.
        input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}\n"
        self.logger.debug(f" Input string: {repr(input_str)}")

        try:
            # The command comes from the machine configuration and is run
            # through bash on purpose; a generic runner interface would not
            # allow feeding stdin.
            process = subprocess.run(
                gpumdkit_cmd,
                input=input_str,
                cwd=self.work_dir,
                shell=True,
                executable="/bin/bash",
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except Exception as e:
            self.logger.error(f"Execution error: {e}")
            raise

        # Checked outside the try block so that a failure is logged only once.
        if process.returncode != 0:
            self.logger.error(f"gpumdkit execution failed: {process.stderr}")
            raise RuntimeError("gpumdkit failed")

    def run(self):
        """Stage inputs and iterate gpumdkit selection until the count fits.

        The threshold starts at ``init_threshold`` (default 0.01) and is
        multiplied by 0.8 when too few frames are selected or by 1.2 when too
        many are selected, for at most 10 attempts.  Targets default to
        ``target_min=60`` and ``target_max=120``.

        Raises:
            FileNotFoundError: If a required input file is missing.
            RuntimeError: If gpumdkit fails or no selected.xyz was produced.
        """
        self.logger.info(f"🔍 [Select] Starting Active Learning Selection Iter {self.iter_id}...")
        self.initialize()

        # 1. Prepare the required files.
        self._stage_inputs()

        # 2. Selection loop (threshold adjustment).
        cfg = self.config_param['params']['select']
        target_min = cfg.get('target_min', 60)
        target_max = cfg.get('target_max', 120)
        threshold = cfg.get('init_threshold', 0.01)
        max_attempts = 10

        gpumdkit_cmd = self.machine_config['tools']['gpumdkit']['command']  # e.g. "gpumdkit.sh"
        selected_file = self.work_dir / "selected.xyz"

        for attempt in range(1, max_attempts + 1):
            self.logger.info(f" -> Attempt {attempt}: Threshold = {threshold:.5f}")
            self._run_gpumdkit(gpumdkit_cmd, threshold)

            count = self.get_frame_count(selected_file)
            self.logger.info(f" -> Selected {count} structures.")

            if target_min <= count <= target_max:
                self.logger.info(f"✅ Selection success! ({count} frames)")
                break

            if count < target_min:
                self.logger.info(" -> Too few, lowering threshold (x0.8)...")
                threshold *= 0.8
            else:
                self.logger.info(" -> Too many, raising threshold (x1.2)...")
                threshold *= 1.2
            # NOTE(review): selected.xyz is assumed to be overwritten by the
            # next gpumdkit run, so it is not removed between attempts --
            # confirm against gpumdkit's behavior.
        else:
            # Loop finished without hitting the target window.
            self.logger.warning("⚠️ Max attempts reached. Proceeding with current best.")

        self.check_done()

    def check_done(self):
        """Return True if selected.xyz exists in the work dir, else raise.

        Raises:
            RuntimeError: If ``selected.xyz`` is not present.
        """
        if (self.work_dir / "selected.xyz").exists():
            return True
        raise RuntimeError("Selection failed: selected.xyz not found")