From c1de59a7f297ce37e88400f22617332f9b2555ca Mon Sep 17 00:00:00 2001 From: koko <1429659362@qq.com> Date: Wed, 10 Dec 2025 23:52:03 +0800 Subject: [PATCH] =?UTF-8?q?nep=E6=A1=86=E6=9E=B6=E9=87=8D=E6=9E=84=2001.se?= =?UTF-8?q?lect=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/workflow.py | 123 +++++++++++++++++++++++++++++++----------------- 1 file changed, 81 insertions(+), 42 deletions(-) diff --git a/src/workflow.py b/src/workflow.py index d384b92..0fbbbeb 100644 --- a/src/workflow.py +++ b/src/workflow.py @@ -210,7 +210,7 @@ class Workflow: self.logger.info(f"=== Step: 01.select ({method}) ===") os.makedirs(step_dir, exist_ok=True) - # [新增] 提前建立 output 目录,用于存放 select.csv 和 select.png + # Output 目录 output_dir = os.path.join(iter_path, "05.output") os.makedirs(output_dir, exist_ok=True) @@ -249,22 +249,36 @@ class Workflow: params = step_conf.get('params', [0.01, 60, 120]) kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh') - # === 分支 A: 按距离筛选 (Loop) === + # === 分支 A: 按距离筛选 (二分法 Binary Search) === if method == "distance": - threshold = params[0] target_min, target_max = params[1], params[2] - step_size = 0.001 + + # 定义二分查找的初始区间 + # lower_bound = 0.0 (极其宽松,选所有) + # upper_bound = 0.2 (通常足够大,甚至可以设大一点,视体系而定) + # 我们取 param[0] 作为初始猜测,但搜索范围设宽一点 + lower_bound = 0.0 + upper_bound = 0.2 + max_attempts = 15 success = False - # [新增] 用于记录筛选历史 - selection_log = [] # List of dicts + selection_log = [] - self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}") + # 用于记录“最佳失败结果” (如果最终没收敛,用这个) + best_result = None + min_dist_to_range = float('inf') # 距离目标区间的差距 + + self.logger.info(f"Targeting {target_min}-{target_max} structures using Binary Search.") for attempt in range(max_attempts): current_attempt = attempt + 1 - self.logger.info(f"--- Attempt {current_attempt}: Threshold {threshold:.4f} ---") + + # 二分取值 + threshold = (lower_bound + upper_bound) / 2.0 + + self.logger.info( + f"--- Attempt {current_attempt}: Threshold {threshold:.6f} (Range: {lower_bound:.4f}-{upper_bound:.4f}) ---") input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}" @@ -272,7 +286,7 @@ class Workflow: self.logger.error("gpumdkit execution failed.") break - # 检查 selected.xyz 数量 + # 检查数量 count = 0 if os.path.exists(os.path.join(step_dir, "selected.xyz")): try: @@ -281,76 +295,101 @@ class Workflow: count = int(count_out.decode().strip()) except: count = 0 - self.logger.info(f"Selected count: {count}") - # [新增] 记录本次结果 + # 记录日志 + status = "Fail" + if target_min <= count <= target_max: + status = "Success" + selection_log.append({ "attempt": current_attempt, "threshold": threshold, "count": count, - "result": "Success" if target_min <= count <= target_max else "Fail" + "result": status }) - # 判断逻辑 + # 记录最佳结果 (距离 target_min 或 target_max 最近的) + dist = 0 + if count < target_min: + dist = target_min - count + elif count > target_max: + dist = count - target_max + + if dist < min_dist_to_range: + min_dist_to_range = dist + # 备份当前的 selected.xyz 为 best_selected.xyz + shutil.copy(os.path.join(step_dir, "selected.xyz"), + os.path.join(step_dir, "best_selected.xyz")) + best_result = {"threshold": threshold, "count": count} + + # 逻辑判断 if target_min <= count <= target_max: - self.logger.info(f"Success! {count} structures selected.") + self.logger.info(f"Success! Count {count} is within range.") success = True break - elif count < target_min: - self.logger.info("Too few. Decreasing threshold (Loosening).") - threshold -= step_size - if threshold < 0: threshold = 0.0001 - else: - self.logger.info("Too many. Increasing threshold (Tightening).") - threshold += step_size - # [新增] 循环结束后,写入 CSV 日志 + # 二分调整逻辑: + # 假设: 阈值越小(Loose) -> 选的越多; 阈值越大(Tight) -> 选的越少 + if count < target_min: + # 选少了 -> 需要更宽松 -> 降低阈值 -> 往 [lower, current] 搜 + self.logger.info("Too few. Need Looser (Lower) threshold.") + upper_bound = threshold + else: + # 选多了 -> 需要更严格 -> 提高阈值 -> 往 [current, upper] 搜 + self.logger.info("Too many. Need Tighter (Higher) threshold.") + lower_bound = threshold + + # 极小区间保护:如果上下界太接近,直接退出,避免死循环 + if (upper_bound - lower_bound) < 1e-6: + self.logger.warning("Search interval too small. Stopping.") + break + + # 循环结束后的处理 + if not success and best_result: + self.logger.warning( + f"Could not strictly satisfy range. Using best result: {best_result['count']} (Thr={best_result['threshold']:.6f})") + # 恢复最佳文件 + shutil.move(os.path.join(step_dir, "best_selected.xyz"), + os.path.join(step_dir, "selected.xyz")) + + # 清理临时文件 + if os.path.exists(os.path.join(step_dir, "best_selected.xyz")): + os.remove(os.path.join(step_dir, "best_selected.xyz")) + + # 写CSV日志 csv_path = os.path.join(output_dir, "select_log.csv") try: with open(csv_path, 'w') as f: f.write("Attempt,Threshold,Count,Result\n") for entry in selection_log: f.write( - f"{entry['attempt']},{entry['threshold']:.4f},{entry['count']},{entry['result']}\n") - self.logger.info(f"Selection log saved to {csv_path}") - except Exception as e: - self.logger.error(f"Failed to write select_log.csv: {e}") + f"{entry['attempt']},{entry['threshold']:.6f},{entry['count']},{entry['result']}\n") + except: + pass - # [新增] 复制 select.png (如果生成了的话) + # 归档图片 src_png = os.path.join(step_dir, "select.png") if os.path.exists(src_png): shutil.copy(src_png, os.path.join(output_dir, "select.png")) - self.logger.info("Archived select.png to output.") - if success: - self.tracker.mark_done(task_id_select) - else: - self.logger.warning("Failed to reach target count. Using last result.") - self.tracker.mark_done(task_id_select) + self.tracker.mark_done(task_id_select) - # === 分支 B: 按个数/随机筛选 === + # === 分支 B: 随机筛选 (保持不变) === elif method == "random": min_n, max_n = params[1], params[2] input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}" - self.logger.info(f"Random selection: {min_n}-{max_n}") if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str): - # 随机筛选没有迭代过程,但也可以记录一下 with open(os.path.join(output_dir, "select_log.csv"), 'w') as f: f.write("Method,Min,Max,Result\n") f.write(f"Random,{min_n},{max_n},Executed\n") - - # 同样尝试复制图片 src_png = os.path.join(step_dir, "select.png") - if os.path.exists(src_png): - shutil.copy(src_png, os.path.join(output_dir, "select.png")) - + if os.path.exists(src_png): shutil.copy(src_png, os.path.join(output_dir, "select.png")) self.tracker.mark_done(task_id_select) else: self.logger.error("Random selection failed.") return - else: self.logger.info("Skipping Select (Already Done).") # ==========================