nep框架重构 01.select修复

This commit is contained in:
2025-12-10 23:52:03 +08:00
parent 83cfe0c1b7
commit c1de59a7f2

View File

@@ -210,7 +210,7 @@ class Workflow:
self.logger.info(f"=== Step: 01.select ({method}) ===")
os.makedirs(step_dir, exist_ok=True)
# [新增] 提前建立 output 目录,用于存放 select.csv 和 select.png
# Output 目录
output_dir = os.path.join(iter_path, "05.output")
os.makedirs(output_dir, exist_ok=True)
@@ -249,22 +249,36 @@ class Workflow:
params = step_conf.get('params', [0.01, 60, 120])
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
# === 分支 A: 按距离筛选 (Loop) ===
# === 分支 A: 按距离筛选 (二分法 Binary Search) ===
if method == "distance":
threshold = params[0]
target_min, target_max = params[1], params[2]
step_size = 0.001
# 定义二分查找的初始区间
# lower_bound = 0.0 (极其宽松,选所有)
# upper_bound = 0.2 (通常足够大,甚至可以设大一点,视体系而定)
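# NOTE: params[0] is not used as the starting guess; every attempt probes the midpoint of [lower_bound, upper_bound], so the first threshold tried is 0.1.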
lower_bound = 0.0
upper_bound = 0.2
max_attempts = 15
success = False
# [新增] 用于记录筛选历史
selection_log = [] # List of dicts
selection_log = []
self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}")
# 用于记录“最佳失败结果” (如果最终没收敛,用这个)
best_result = None
min_dist_to_range = float('inf') # 距离目标区间的差距
self.logger.info(f"Targeting {target_min}-{target_max} structures using Binary Search.")
for attempt in range(max_attempts):
current_attempt = attempt + 1
self.logger.info(f"--- Attempt {current_attempt}: Threshold {threshold:.4f} ---")
# 二分取值
threshold = (lower_bound + upper_bound) / 2.0
self.logger.info(
f"--- Attempt {current_attempt}: Threshold {threshold:.6f} (Range: {lower_bound:.4f}-{upper_bound:.4f}) ---")
input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"
@@ -272,7 +286,7 @@ class Workflow:
self.logger.error("gpumdkit execution failed.")
break
# 检查 selected.xyz 数量
# 检查数量
count = 0
if os.path.exists(os.path.join(step_dir, "selected.xyz")):
try:
@@ -281,76 +295,101 @@ class Workflow:
count = int(count_out.decode().strip())
except:
count = 0
self.logger.info(f"Selected count: {count}")
# [新增] 记录本次结果
# 记录日志
status = "Fail"
if target_min <= count <= target_max:
status = "Success"
selection_log.append({
"attempt": current_attempt,
"threshold": threshold,
"count": count,
"result": "Success" if target_min <= count <= target_max else "Fail"
"result": status
})
# 判断逻辑
# 记录最佳结果 (距离 target_min 或 target_max 最近的)
dist = 0
if count < target_min:
dist = target_min - count
elif count > target_max:
dist = count - target_max
if dist < min_dist_to_range:
min_dist_to_range = dist
# 备份当前的 selected.xyz 为 best_selected.xyz
shutil.copy(os.path.join(step_dir, "selected.xyz"),
os.path.join(step_dir, "best_selected.xyz"))
best_result = {"threshold": threshold, "count": count}
# 逻辑判断
if target_min <= count <= target_max:
self.logger.info(f"Success! {count} structures selected.")
self.logger.info(f"Success! Count {count} is within range.")
success = True
break
elif count < target_min:
self.logger.info("Too few. Decreasing threshold (Loosening).")
threshold -= step_size
if threshold < 0: threshold = 0.0001
else:
self.logger.info("Too many. Increasing threshold (Tightening).")
threshold += step_size
# [新增] 循环结束后,写入 CSV 日志
# 二分调整逻辑:
# 假设: 阈值越小(Loose) -> 选的越多; 阈值越大(Tight) -> 选的越少
if count < target_min:
# 选少了 -> 需要更宽松 -> 降低阈值 -> 往 [lower, current] 搜
self.logger.info("Too few. Need Looser (Lower) threshold.")
upper_bound = threshold
else:
# 选多了 -> 需要更严格 -> 提高阈值 -> 往 [current, upper] 搜
self.logger.info("Too many. Need Tighter (Higher) threshold.")
lower_bound = threshold
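# Narrow-interval guard: stop once the bounds are within 1e-6 of each other instead of spending the remaining attempts (the loop is already capped by max_attempts).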
if (upper_bound - lower_bound) < 1e-6:
self.logger.warning("Search interval too small. Stopping.")
break
# 循环结束后的处理
if not success and best_result:
self.logger.warning(
f"Could not strictly satisfy range. Using best result: {best_result['count']} (Thr={best_result['threshold']:.6f})")
# 恢复最佳文件
shutil.move(os.path.join(step_dir, "best_selected.xyz"),
os.path.join(step_dir, "selected.xyz"))
# 清理临时文件
if os.path.exists(os.path.join(step_dir, "best_selected.xyz")):
os.remove(os.path.join(step_dir, "best_selected.xyz"))
# 写CSV日志
csv_path = os.path.join(output_dir, "select_log.csv")
try:
with open(csv_path, 'w') as f:
f.write("Attempt,Threshold,Count,Result\n")
for entry in selection_log:
f.write(
f"{entry['attempt']},{entry['threshold']:.4f},{entry['count']},{entry['result']}\n")
self.logger.info(f"Selection log saved to {csv_path}")
except Exception as e:
self.logger.error(f"Failed to write select_log.csv: {e}")
f"{entry['attempt']},{entry['threshold']:.6f},{entry['count']},{entry['result']}\n")
except:
pass
# [新增] 复制 select.png (如果生成了的话)
# 归档图片
src_png = os.path.join(step_dir, "select.png")
if os.path.exists(src_png):
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
self.logger.info("Archived select.png to output.")
if success:
self.tracker.mark_done(task_id_select)
else:
self.logger.warning("Failed to reach target count. Using last result.")
self.tracker.mark_done(task_id_select)
self.tracker.mark_done(task_id_select)
# === 分支 B: 按个数/随机筛选 ===
# === 分支 B: 随机筛选 (保持不变) ===
elif method == "random":
min_n, max_n = params[1], params[2]
input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"
self.logger.info(f"Random selection: {min_n}-{max_n}")
if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
# 随机筛选没有迭代过程,但也可以记录一下
with open(os.path.join(output_dir, "select_log.csv"), 'w') as f:
f.write("Method,Min,Max,Result\n")
f.write(f"Random,{min_n},{max_n},Executed\n")
# 同样尝试复制图片
src_png = os.path.join(step_dir, "select.png")
if os.path.exists(src_png):
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
if os.path.exists(src_png): shutil.copy(src_png, os.path.join(output_dir, "select.png"))
self.tracker.mark_done(task_id_select)
else:
self.logger.error("Random selection failed.")
return
else:
self.logger.info("Skipping Select (Already Done).")
# ==========================