nep框架重构 01.select修复
This commit is contained in:
121
src/workflow.py
121
src/workflow.py
@@ -210,7 +210,7 @@ class Workflow:
|
|||||||
self.logger.info(f"=== Step: 01.select ({method}) ===")
|
self.logger.info(f"=== Step: 01.select ({method}) ===")
|
||||||
os.makedirs(step_dir, exist_ok=True)
|
os.makedirs(step_dir, exist_ok=True)
|
||||||
|
|
||||||
# [新增] 提前建立 output 目录,用于存放 select.csv 和 select.png
|
# Output 目录
|
||||||
output_dir = os.path.join(iter_path, "05.output")
|
output_dir = os.path.join(iter_path, "05.output")
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -249,22 +249,36 @@ class Workflow:
|
|||||||
params = step_conf.get('params', [0.01, 60, 120])
|
params = step_conf.get('params', [0.01, 60, 120])
|
||||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
|
||||||
# === 分支 A: 按距离筛选 (Loop) ===
|
# === 分支 A: 按距离筛选 (二分法 Binary Search) ===
|
||||||
if method == "distance":
|
if method == "distance":
|
||||||
threshold = params[0]
|
|
||||||
target_min, target_max = params[1], params[2]
|
target_min, target_max = params[1], params[2]
|
||||||
step_size = 0.001
|
|
||||||
|
# 定义二分查找的初始区间
|
||||||
|
# lower_bound = 0.0 (极其宽松,选所有)
|
||||||
|
# upper_bound = 0.2 (通常足够大,甚至可以设大一点,视体系而定)
|
||||||
|
# NOTE: params[0] is not used as an initial guess here; every attempt probes the midpoint of [lower_bound, upper_bound]
|
||||||
|
lower_bound = 0.0
|
||||||
|
upper_bound = 0.2
|
||||||
|
|
||||||
max_attempts = 15
|
max_attempts = 15
|
||||||
success = False
|
success = False
|
||||||
|
|
||||||
# [新增] 用于记录筛选历史
|
selection_log = []
|
||||||
selection_log = [] # List of dicts
|
|
||||||
|
|
||||||
self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}")
|
# 用于记录“最佳失败结果” (如果最终没收敛,用这个)
|
||||||
|
best_result = None
|
||||||
|
min_dist_to_range = float('inf') # 距离目标区间的差距
|
||||||
|
|
||||||
|
self.logger.info(f"Targeting {target_min}-{target_max} structures using Binary Search.")
|
||||||
|
|
||||||
for attempt in range(max_attempts):
|
for attempt in range(max_attempts):
|
||||||
current_attempt = attempt + 1
|
current_attempt = attempt + 1
|
||||||
self.logger.info(f"--- Attempt {current_attempt}: Threshold {threshold:.4f} ---")
|
|
||||||
|
# 二分取值
|
||||||
|
threshold = (lower_bound + upper_bound) / 2.0
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"--- Attempt {current_attempt}: Threshold {threshold:.6f} (Range: {lower_bound:.4f}-{upper_bound:.4f}) ---")
|
||||||
|
|
||||||
input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"
|
input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"
|
||||||
|
|
||||||
@@ -272,7 +286,7 @@ class Workflow:
|
|||||||
self.logger.error("gpumdkit execution failed.")
|
self.logger.error("gpumdkit execution failed.")
|
||||||
break
|
break
|
||||||
|
|
||||||
# 检查 selected.xyz 数量
|
# 检查数量
|
||||||
count = 0
|
count = 0
|
||||||
if os.path.exists(os.path.join(step_dir, "selected.xyz")):
|
if os.path.exists(os.path.join(step_dir, "selected.xyz")):
|
||||||
try:
|
try:
|
||||||
@@ -281,76 +295,101 @@ class Workflow:
|
|||||||
count = int(count_out.decode().strip())
|
count = int(count_out.decode().strip())
|
||||||
except:
|
except:
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
self.logger.info(f"Selected count: {count}")
|
self.logger.info(f"Selected count: {count}")
|
||||||
|
|
||||||
# [新增] 记录本次结果
|
# 记录日志
|
||||||
|
status = "Fail"
|
||||||
|
if target_min <= count <= target_max:
|
||||||
|
status = "Success"
|
||||||
|
|
||||||
selection_log.append({
|
selection_log.append({
|
||||||
"attempt": current_attempt,
|
"attempt": current_attempt,
|
||||||
"threshold": threshold,
|
"threshold": threshold,
|
||||||
"count": count,
|
"count": count,
|
||||||
"result": "Success" if target_min <= count <= target_max else "Fail"
|
"result": status
|
||||||
})
|
})
|
||||||
|
|
||||||
# 判断逻辑
|
# 记录最佳结果 (距离 target_min 或 target_max 最近的)
|
||||||
|
dist = 0
|
||||||
|
if count < target_min:
|
||||||
|
dist = target_min - count
|
||||||
|
elif count > target_max:
|
||||||
|
dist = count - target_max
|
||||||
|
|
||||||
|
if dist < min_dist_to_range:
|
||||||
|
min_dist_to_range = dist
|
||||||
|
# 备份当前的 selected.xyz 为 best_selected.xyz
|
||||||
|
shutil.copy(os.path.join(step_dir, "selected.xyz"),
|
||||||
|
os.path.join(step_dir, "best_selected.xyz"))
|
||||||
|
best_result = {"threshold": threshold, "count": count}
|
||||||
|
|
||||||
|
# 逻辑判断
|
||||||
if target_min <= count <= target_max:
|
if target_min <= count <= target_max:
|
||||||
self.logger.info(f"Success! {count} structures selected.")
|
self.logger.info(f"Success! Count {count} is within range.")
|
||||||
success = True
|
success = True
|
||||||
break
|
break
|
||||||
elif count < target_min:
|
|
||||||
self.logger.info("Too few. Decreasing threshold (Loosening).")
|
|
||||||
threshold -= step_size
|
|
||||||
if threshold < 0: threshold = 0.0001
|
|
||||||
else:
|
|
||||||
self.logger.info("Too many. Increasing threshold (Tightening).")
|
|
||||||
threshold += step_size
|
|
||||||
|
|
||||||
# [新增] 循环结束后,写入 CSV 日志
|
# 二分调整逻辑:
|
||||||
|
# 假设: 阈值越小(Loose) -> 选的越多; 阈值越大(Tight) -> 选的越少
|
||||||
|
if count < target_min:
|
||||||
|
# 选少了 -> 需要更宽松 -> 降低阈值 -> 往 [lower, current] 搜
|
||||||
|
self.logger.info("Too few. Need Looser (Lower) threshold.")
|
||||||
|
upper_bound = threshold
|
||||||
|
else:
|
||||||
|
# 选多了 -> 需要更严格 -> 提高阈值 -> 往 [current, upper] 搜
|
||||||
|
self.logger.info("Too many. Need Tighter (Higher) threshold.")
|
||||||
|
lower_bound = threshold
|
||||||
|
|
||||||
|
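# Tiny-interval guard: stop early once the bounds are closer than 1e-6 (the loop is already bounded by max_attempts; with a 0.2-wide start and 15 halvings the interval only shrinks to ~6e-6, so this rarely fires)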
|
||||||
|
if (upper_bound - lower_bound) < 1e-6:
|
||||||
|
self.logger.warning("Search interval too small. Stopping.")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 循环结束后的处理
|
||||||
|
if not success and best_result:
|
||||||
|
self.logger.warning(
|
||||||
|
f"Could not strictly satisfy range. Using best result: {best_result['count']} (Thr={best_result['threshold']:.6f})")
|
||||||
|
# 恢复最佳文件
|
||||||
|
shutil.move(os.path.join(step_dir, "best_selected.xyz"),
|
||||||
|
os.path.join(step_dir, "selected.xyz"))
|
||||||
|
|
||||||
|
# 清理临时文件
|
||||||
|
if os.path.exists(os.path.join(step_dir, "best_selected.xyz")):
|
||||||
|
os.remove(os.path.join(step_dir, "best_selected.xyz"))
|
||||||
|
|
||||||
|
# 写CSV日志
|
||||||
csv_path = os.path.join(output_dir, "select_log.csv")
|
csv_path = os.path.join(output_dir, "select_log.csv")
|
||||||
try:
|
try:
|
||||||
with open(csv_path, 'w') as f:
|
with open(csv_path, 'w') as f:
|
||||||
f.write("Attempt,Threshold,Count,Result\n")
|
f.write("Attempt,Threshold,Count,Result\n")
|
||||||
for entry in selection_log:
|
for entry in selection_log:
|
||||||
f.write(
|
f.write(
|
||||||
f"{entry['attempt']},{entry['threshold']:.4f},{entry['count']},{entry['result']}\n")
|
f"{entry['attempt']},{entry['threshold']:.6f},{entry['count']},{entry['result']}\n")
|
||||||
self.logger.info(f"Selection log saved to {csv_path}")
|
except:
|
||||||
except Exception as e:
|
pass
|
||||||
self.logger.error(f"Failed to write select_log.csv: {e}")
|
|
||||||
|
|
||||||
# [新增] 复制 select.png (如果生成了的话)
|
# 归档图片
|
||||||
src_png = os.path.join(step_dir, "select.png")
|
src_png = os.path.join(step_dir, "select.png")
|
||||||
if os.path.exists(src_png):
|
if os.path.exists(src_png):
|
||||||
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
|
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
|
||||||
self.logger.info("Archived select.png to output.")
|
|
||||||
|
|
||||||
if success:
|
|
||||||
self.tracker.mark_done(task_id_select)
|
|
||||||
else:
|
|
||||||
self.logger.warning("Failed to reach target count. Using last result.")
|
|
||||||
self.tracker.mark_done(task_id_select)
|
self.tracker.mark_done(task_id_select)
|
||||||
|
|
||||||
# === 分支 B: 按个数/随机筛选 ===
|
# === 分支 B: 随机筛选 (保持不变) ===
|
||||||
elif method == "random":
|
elif method == "random":
|
||||||
min_n, max_n = params[1], params[2]
|
min_n, max_n = params[1], params[2]
|
||||||
input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"
|
input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"
|
||||||
|
|
||||||
self.logger.info(f"Random selection: {min_n}-{max_n}")
|
self.logger.info(f"Random selection: {min_n}-{max_n}")
|
||||||
if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
|
if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
|
||||||
# 随机筛选没有迭代过程,但也可以记录一下
|
|
||||||
with open(os.path.join(output_dir, "select_log.csv"), 'w') as f:
|
with open(os.path.join(output_dir, "select_log.csv"), 'w') as f:
|
||||||
f.write("Method,Min,Max,Result\n")
|
f.write("Method,Min,Max,Result\n")
|
||||||
f.write(f"Random,{min_n},{max_n},Executed\n")
|
f.write(f"Random,{min_n},{max_n},Executed\n")
|
||||||
|
|
||||||
# 同样尝试复制图片
|
|
||||||
src_png = os.path.join(step_dir, "select.png")
|
src_png = os.path.join(step_dir, "select.png")
|
||||||
if os.path.exists(src_png):
|
if os.path.exists(src_png): shutil.copy(src_png, os.path.join(output_dir, "select.png"))
|
||||||
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
|
|
||||||
|
|
||||||
self.tracker.mark_done(task_id_select)
|
self.tracker.mark_done(task_id_select)
|
||||||
else:
|
else:
|
||||||
self.logger.error("Random selection failed.")
|
self.logger.error("Random selection failed.")
|
||||||
return
|
return
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.logger.info("Skipping Select (Already Done).")
|
self.logger.info("Skipping Select (Already Done).")
|
||||||
# ==========================
|
# ==========================
|
||||||
|
|||||||
Reference in New Issue
Block a user