NEP框架重构01select
This commit is contained in:
@@ -185,7 +185,58 @@ class Workflow:
|
||||
# Step: 01.select
|
||||
# ==========================
|
||||
elif step_name == "01.select":
|
||||
# 可以在这里也加上 StateTracker 逻辑
|
||||
pass
|
||||
step_dir = os.path.join(iter_path, "01.select")
|
||||
task_id_select = f"{iter_name}.01.select"
|
||||
|
||||
# ... (后续步骤类似,暂时省略)
|
||||
if not self.tracker.is_done(task_id_select):
|
||||
self.logger.info(f"=== Step: 01.select ({step_conf.get('method', 'distance')}) ===")
|
||||
os.makedirs(step_dir, exist_ok=True)
|
||||
|
||||
# 1. 准备 dump.xyz (软链接)
|
||||
# 优先使用上一这步记录的路径,如果没有则尝试推断
|
||||
dump_src = getattr(self, 'last_dump_path', None)
|
||||
if not dump_src:
|
||||
# 尝试推断:iter_XX/00.md/production/dump.xyz
|
||||
dump_src = os.path.join(iter_path, "00.md", "production", "dump.xyz")
|
||||
|
||||
if os.path.exists(dump_src):
|
||||
dst_dump = os.path.join(step_dir, "dump.xyz")
|
||||
if os.path.exists(dst_dump): os.remove(dst_dump)
|
||||
# 使用绝对路径建立软链
|
||||
os.symlink(os.path.abspath(dump_src), dst_dump)
|
||||
else:
|
||||
self.logger.error(f"Source dump.xyz not found: {dump_src}")
|
||||
return
|
||||
|
||||
# 2. 准备 nep.txt
|
||||
shutil.copy(self.current_nep_pot, os.path.join(step_dir, "nep.txt"))
|
||||
|
||||
# 3. 准备 train.xyz
|
||||
# 逻辑:第一轮用 model.xyz,之后用 self.current_train_set
|
||||
if iter_id == 0:
|
||||
# 第一轮:找同轮次 00.md 下的 model.xyz
|
||||
model_xyz_src = os.path.join(iter_path, "00.md", "model.xyz")
|
||||
if os.path.exists(model_xyz_src):
|
||||
shutil.copy(model_xyz_src, os.path.join(step_dir, "train.xyz"))
|
||||
else:
|
||||
self.logger.error("model.xyz not found for initial train.xyz")
|
||||
return
|
||||
else:
|
||||
# 后续轮次
|
||||
if os.path.exists(self.current_train_set):
|
||||
shutil.copy(self.current_train_set, os.path.join(step_dir, "train.xyz"))
|
||||
else:
|
||||
self.logger.error(f"Previous train set missing: {self.current_train_set}")
|
||||
return
|
||||
|
||||
# 4. 执行筛选逻辑
|
||||
method = step_conf.get('method', 'distance')
|
||||
params = step_conf.get('params', [0.01, 60, 120]) # [threshold, min, max]
|
||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||
|
||||
# === 分支 A: 按距离筛选 (Loop) ===
|
||||
if method == "distance":
|
||||
threshold = params[0]
|
||||
target_min, target_max = params[1], params[2]
|
||||
step_size = 0.001
|
||||
self.logger.info("Skipping Select (Already Done).")
|
||||
Reference in New Issue
Block a user