Files
NEP-auto/nep_auto/modules/m2_select.py
2025-12-08 22:34:02 +08:00

148 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import shutil
import subprocess
from pathlib import Path
from nep_auto.modules.base_module import BaseModule
class SelectModule(BaseModule):
    """Active-learning structure selection step of one iteration.

    The module stages ``dump.xyz``, ``nep.txt`` and a reference ``train.xyz``
    inside ``<iter_dir>/01.select`` and then runs the configured ``gpumdkit``
    command repeatedly, adjusting the selection threshold until the number of
    frames in ``selected.xyz`` falls inside the configured target window.

    NOTE(review): ``iter_dir``, ``root``, ``logger``, ``config_param``,
    ``machine_config``, ``iter_id`` and ``initialize()`` are not defined in
    this class; they are presumably provided by ``BaseModule`` -- confirm.
    """

    def __init__(self, driver, iter_id):
        """Set up the working directory and the MD result directory.

        Args:
            driver: Forwarded unchanged to ``BaseModule.__init__``.
            iter_id: Iteration number, forwarded to ``BaseModule.__init__``.
        """
        super().__init__(driver, iter_id)
        self.work_dir = self.iter_dir / "01.select"
        self.md_dir = self.iter_dir / "00.md" / "md"

    def get_work_dir(self):
        """Return the directory in which the selection step runs."""
        return self.work_dir

    def get_frame_count(self, xyz_file: Path) -> int:
        """Count the frames of an xyz file by counting lines with 'Lattice'.

        ``grep -c`` is tried first because it is fast on large trajectories.
        If ``grep`` cannot be started or prints nothing parseable, the file is
        streamed line by line in Python instead, so memory use stays bounded.

        Args:
            xyz_file: Path of the xyz file to inspect.

        Returns:
            The number of lines containing ``Lattice``; ``0`` when the file
            does not exist or cannot be read.
        """
        if not xyz_file.exists():
            return 0

        try:
            # Argument list + no shell: paths containing spaces or shell
            # metacharacters are passed through verbatim.
            result = subprocess.run(
                ["grep", "-c", "--", "Lattice", str(xyz_file)],
                stdout=subprocess.PIPE,
                text=True,
            )
            # grep -c prints "0" (with exit status 1) when nothing matches,
            # so the exit status is deliberately not checked here.
            return int(result.stdout.strip())
        except (OSError, ValueError):
            # grep is missing or produced no count; fall back to Python.
            pass

        try:
            with open(xyz_file, "rb") as handle:
                return sum(1 for line in handle if b"Lattice" in line)
        except OSError:
            return 0

    def _stage_input_files(self):
        """Copy dump.xyz, nep.txt and a reference train.xyz into work_dir.

        Raises:
            FileNotFoundError: If the MD dump is missing, or if the previous
                iteration's ``train.xyz`` is missing for ``iter_id != 1``.
        """
        # A. Candidate structures produced by the MD step.
        src_dump = self.md_dir / "dump.xyz"
        if not src_dump.exists():
            raise FileNotFoundError(f"MD dump missing: {src_dump}")
        shutil.copy(src_dump, self.work_dir / "dump.xyz")

        # B. Potential file, taken from the MD directory as well.
        shutil.copy(self.md_dir / "nep.txt", self.work_dir / "nep.txt")

        # C. Reference training set the candidates are compared against.
        #    gpumdkit needs this file to exist, even in the first iteration.
        target_train_xyz = self.work_dir / "train.xyz"
        if self.iter_id == 1:
            init_train = self.root / "00.data" / "train.xyz"
            if init_train.exists():
                shutil.copy(init_train, target_train_xyz)
            else:
                # No initial training set: use the initial structure as a
                # placeholder so that the selection script does not fail.
                self.logger.warning("No initial train.xyz found, using model.xyz as placeholder.")
                shutil.copy(self.md_dir / "model.xyz", target_train_xyz)
            return

        # Later iterations reuse the training set accumulated so far.
        prev_train = self.root / f"iter_{self.iter_id - 1:03d}" / "03.train" / "train.xyz"
        if not prev_train.exists():
            raise FileNotFoundError(f"Previous train.xyz missing: {prev_train}")
        shutil.copy(prev_train, target_train_xyz)

    def _run_gpumdkit(self, gpumdkit_cmd, threshold):
        """Run one gpumdkit selection pass in work_dir with the given threshold.

        The answers to gpumdkit's interactive prompts are fed through stdin:
        ``203`` -> file names -> ``1`` (distance mode) -> threshold.

        Raises:
            RuntimeError: If the command exits with a non-zero status.
            OSError: If the command cannot be started.
        """
        input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}\n"
        self.logger.debug(f" Input string: {repr(input_str)}")

        # Only the process launch is guarded, so that a failure is logged
        # exactly once (either here or by the return-code check below).
        try:
            process = subprocess.run(
                gpumdkit_cmd,
                input=input_str,
                cwd=self.work_dir,
                # The command comes from the machine configuration and may be
                # a full shell command line, hence the shell.
                shell=True,
                executable="/bin/bash",
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
            )
        except (OSError, subprocess.SubprocessError) as e:
            self.logger.error(f"Execution error: {e}")
            raise

        if process.returncode != 0:
            self.logger.error(f"gpumdkit execution failed: {process.stderr}")
            raise RuntimeError("gpumdkit failed")

    def run(self):
        """Stage the inputs and search for a threshold giving enough frames.

        Reads ``target_min`` (default 60), ``target_max`` (default 120),
        ``init_threshold`` (default 0.01) and ``max_attempts`` (default 10)
        from ``config_param['params']['select']``. After every pass the
        threshold is multiplied by 0.8 when too few frames were selected and
        by 1.2 when too many were selected.

        Raises:
            FileNotFoundError: If a required input file is missing.
            RuntimeError: If gpumdkit fails or no ``selected.xyz`` exists at
                the end.
        """
        self.logger.info(f"🔍 [Select] Starting Active Learning Selection Iter {self.iter_id}...")
        self.initialize()

        # ----------------------------------------
        # 1. Prepare the required files
        # ----------------------------------------
        self._stage_input_files()

        # ----------------------------------------
        # 2. Selection loop (threshold adjustment)
        # ----------------------------------------
        cfg = self.config_param['params']['select']
        target_min = cfg.get('target_min', 60)
        target_max = cfg.get('target_max', 120)
        threshold = cfg.get('init_threshold', 0.01)
        max_attempts = cfg.get('max_attempts', 10)

        # e.g. "gpumdkit.sh"
        gpumdkit_cmd = self.machine_config['tools']['gpumdkit']['command']
        selected_file = self.work_dir / "selected.xyz"

        for attempt in range(1, max_attempts + 1):
            self.logger.info(f" -> Attempt {attempt}: Threshold = {threshold:.5f}")

            # Drop the result of the previous pass so that a pass which writes
            # nothing (or appends) can never be counted with stale frames.
            if selected_file.exists():
                selected_file.unlink()

            self._run_gpumdkit(gpumdkit_cmd, threshold)

            count = self.get_frame_count(selected_file)
            self.logger.info(f" -> Selected {count} structures.")

            if target_min <= count <= target_max:
                self.logger.info(f"✅ Selection success! ({count} frames)")
                break
            if count < target_min:
                self.logger.info(" -> Too few, lowering threshold (x0.8)...")
                threshold *= 0.8
            else:
                self.logger.info(" -> Too many, raising threshold (x1.2)...")
                threshold *= 1.2
        else:
            # No pass hit the target window; whatever the last pass wrote to
            # selected.xyz is what the following steps will use.
            self.logger.warning("⚠️ Max attempts reached. Proceeding with current best.")

        self.check_done()

    def check_done(self):
        """Return True if selected.xyz exists in work_dir.

        Raises:
            RuntimeError: If ``selected.xyz`` is not present.
        """
        if (self.work_dir / "selected.xyz").exists():
            return True
        raise RuntimeError("Selection failed: selected.xyz not found")