nep框架重构 01.select
This commit is contained in:
@@ -239,4 +239,75 @@ class Workflow:
|
||||
threshold = params[0]
|
||||
target_min, target_max = params[1], params[2]
|
||||
step_size = 0.001
|
||||
max_attempts = 15
|
||||
success = False
|
||||
|
||||
self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}")
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
self.logger.info(f"--- Attempt {attempt + 1}: Threshold {threshold:.4f} ---")
|
||||
|
||||
# 构造输入: 203 -> filenames -> 1 (distance) -> threshold
|
||||
# 格式: "203\ndump.xyz train.xyz nep.txt\n1\n0.01"
|
||||
input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"
|
||||
|
||||
# 运行 gpumdkit (输出重定向到 append 模式的日志)
|
||||
if not run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
|
||||
self.logger.error("gpumdkit execution failed.")
|
||||
break # 致命错误直接退出
|
||||
|
||||
# 检查 selected.xyz 数量
|
||||
if not os.path.exists(os.path.join(step_dir, "selected.xyz")):
|
||||
self.logger.warning("selected.xyz not found. Retrying...")
|
||||
continue
|
||||
|
||||
try:
|
||||
# grep -c "Lat" selected.xyz
|
||||
count_out = subprocess.check_output('grep -c "Lat" selected.xyz', shell=True,
|
||||
cwd=step_dir)
|
||||
count = int(count_out.decode().strip())
|
||||
except:
|
||||
count = 0
|
||||
|
||||
self.logger.info(f"Selected count: {count}")
|
||||
|
||||
# 判断逻辑
|
||||
if target_min <= count <= target_max:
|
||||
self.logger.info(f"Success! {count} structures selected.")
|
||||
success = True
|
||||
break
|
||||
elif count < target_min:
|
||||
# 选少了 -> 阈值太严 -> 降低阈值? (通常距离阈值越低,选的越多?)
|
||||
# 假设:Threshold 代表“最小允许距离”。距离 > Thr 才选。
|
||||
# 那么 Thr 越小,条件越松,选的越多。
|
||||
self.logger.info("Too few. Decreasing threshold (Loosening).")
|
||||
threshold -= step_size
|
||||
if threshold < 0: threshold = 0.0001
|
||||
else: # count > target_max
|
||||
# 选多了 -> 提高阈值
|
||||
self.logger.info("Too many. Increasing threshold (Tightening).")
|
||||
threshold += step_size
|
||||
|
||||
if success:
|
||||
self.tracker.mark_done(task_id_select)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Failed to reach target count within max attempts. Using last result.")
|
||||
# 这里可以选择是否标记完成,或者报错暂停。目前标记完成以便继续。
|
||||
self.tracker.mark_done(task_id_select)
|
||||
|
||||
# === 分支 B: 按个数/随机筛选 (One-shot) ===
|
||||
elif method == "random":
|
||||
min_n, max_n = params[1], params[2]
|
||||
# 构造输入: 203 -> filenames -> 2 (number) -> min max
|
||||
input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"
|
||||
|
||||
self.logger.info(f"Random selection: {min_n}-{max_n}")
|
||||
if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
|
||||
self.tracker.mark_done(task_id_select)
|
||||
else:
|
||||
self.logger.error("Random selection failed.")
|
||||
return
|
||||
|
||||
else:
|
||||
self.logger.info("Skipping Select (Already Done).")
|
||||
Reference in New Issue
Block a user