NEP框架重构01select
This commit is contained in:
@@ -185,7 +185,58 @@ class Workflow:
|
|||||||
# Step: 01.select
|
# Step: 01.select
|
||||||
# ==========================
|
# ==========================
|
||||||
elif step_name == "01.select":
|
elif step_name == "01.select":
|
||||||
# 可以在这里也加上 StateTracker 逻辑
|
step_dir = os.path.join(iter_path, "01.select")
|
||||||
pass
|
task_id_select = f"{iter_name}.01.select"
|
||||||
|
|
||||||
# ... (后续步骤类似,暂时省略)
|
if not self.tracker.is_done(task_id_select):
|
||||||
|
self.logger.info(f"=== Step: 01.select ({step_conf.get('method', 'distance')}) ===")
|
||||||
|
os.makedirs(step_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. 准备 dump.xyz (软链接)
|
||||||
|
# 优先使用上一这步记录的路径,如果没有则尝试推断
|
||||||
|
dump_src = getattr(self, 'last_dump_path', None)
|
||||||
|
if not dump_src:
|
||||||
|
# 尝试推断:iter_XX/00.md/production/dump.xyz
|
||||||
|
dump_src = os.path.join(iter_path, "00.md", "production", "dump.xyz")
|
||||||
|
|
||||||
|
if os.path.exists(dump_src):
|
||||||
|
dst_dump = os.path.join(step_dir, "dump.xyz")
|
||||||
|
if os.path.exists(dst_dump): os.remove(dst_dump)
|
||||||
|
# 使用绝对路径建立软链
|
||||||
|
os.symlink(os.path.abspath(dump_src), dst_dump)
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Source dump.xyz not found: {dump_src}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 准备 nep.txt
|
||||||
|
shutil.copy(self.current_nep_pot, os.path.join(step_dir, "nep.txt"))
|
||||||
|
|
||||||
|
# 3. 准备 train.xyz
|
||||||
|
# 逻辑:第一轮用 model.xyz,之后用 self.current_train_set
|
||||||
|
if iter_id == 0:
|
||||||
|
# 第一轮:找同轮次 00.md 下的 model.xyz
|
||||||
|
model_xyz_src = os.path.join(iter_path, "00.md", "model.xyz")
|
||||||
|
if os.path.exists(model_xyz_src):
|
||||||
|
shutil.copy(model_xyz_src, os.path.join(step_dir, "train.xyz"))
|
||||||
|
else:
|
||||||
|
self.logger.error("model.xyz not found for initial train.xyz")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# 后续轮次
|
||||||
|
if os.path.exists(self.current_train_set):
|
||||||
|
shutil.copy(self.current_train_set, os.path.join(step_dir, "train.xyz"))
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Previous train set missing: {self.current_train_set}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. 执行筛选逻辑
|
||||||
|
method = step_conf.get('method', 'distance')
|
||||||
|
params = step_conf.get('params', [0.01, 60, 120]) # [threshold, min, max]
|
||||||
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
|
||||||
|
# === 分支 A: 按距离筛选 (Loop) ===
|
||||||
|
if method == "distance":
|
||||||
|
threshold = params[0]
|
||||||
|
target_min, target_max = params[1], params[2]
|
||||||
|
step_size = 0.001
|
||||||
|
self.logger.info("Skipping Select (Already Done).")
|
||||||
Reference in New Issue
Block a user