# src/workflow.py
import os
import shutil
import logging
import subprocess

from src.utils import load_yaml, run_cmd_with_log
from src.machine import MachineManager
from src.state import StateTracker  # new: resumable state tracking
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
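
# NOTE: the helpers imported above are defined elsewhere in src/; the
# signatures assumed throughout this module (a sketch inferred from the call
# sites below, not the actual implementations) are:
#
#   run_cmd_with_log(cmd, cwd, log_name, input_str=None) -> bool
#       Run `cmd` in directory `cwd`, append stdout/stderr to `log_name`,
#       optionally feed `input_str` to stdin, and return True iff the
#       command exited with status 0.
#
#   StateTracker(workspace)
#       .is_done(task_id) -> bool   # was this task already completed?
#       .mark_done(task_id)         # persist completion (e.g. to a state
#                                   # file under workspace/) for resumption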


class Workflow:
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.param = load_yaml(os.path.join(root_dir, "config/param.yaml"))
        self.machine = MachineManager(os.path.join(root_dir, "config/machine.yaml"))

        self.workspace = os.path.join(root_dir, "workspace")
        self.data_dir = os.path.join(root_dir, "data")
        self.template_dir = os.path.join(root_dir, "template")

        self.logger = logging.getLogger()

        # Initialize state tracking
        os.makedirs(self.workspace, exist_ok=True)
        self.tracker = StateTracker(self.workspace)

        # Running state: the current potential and the accumulated training set
        self.current_nep_pot = os.path.join(self.data_dir, self.param['files']['initial_pot'])
        self.current_train_set = os.path.join(self.workspace, "accumulated_train.xyz")
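
    # For reference, the config files are expected to look roughly like the
    # sketch below (inferred from how they are read in this module; values
    # are illustrative):
    #
    # config/param.yaml
    #   project: my_project
    #   files: {initial_pot: nep.txt, poscar: POSCAR, label: "...", potcar: POTCAR}
    #   iterations:
    #     - id: 0
    #       steps:
    #         - {name: 00.md}
    #         - {name: 01.select, method: distance, params: [0.01, 60, 120]}
    #         - {name: 02.scf, executor: vasp_gpu}
    #         - {name: 03.train, executor: nep_local}
    #
    # config/machine.yaml
    #   paths: {gpumdkit: /path/to/gpumdkit.sh}
    #   executors:
    #     vasp_gpu: {cmd: "mpirun -np 32 vasp_std"}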

    def run(self):
        self.logger.info(f"Workflow Started: {self.param['project']}")

        for iteration in self.param['iterations']:
            iter_id = iteration['id']
            iter_name = f"iter_{iter_id:02d}"
            iter_path = os.path.join(self.workspace, iter_name)
            self.logger.info(f"\n >>> Processing Iteration: {iter_id} <<<")
            os.makedirs(iter_path, exist_ok=True)

            for step_conf in iteration['steps']:
                step_name = step_conf['name']

                # ==========================
                # Step: 00.md
                # ==========================
                if step_name == "00.md":
                    step_dir = os.path.join(iter_path, "00.md")

                    # 1. Initialize model.xyz (done only once, in iteration 0)
                    task_id_init = f"{iter_name}.00.md.init"

                    if iter_id == 0:
                        if not self.tracker.is_done(task_id_init):
                            os.makedirs(step_dir, exist_ok=True)
                            poscar_name = self.param['files']['poscar']
                            poscar_src = os.path.join(self.data_dir, poscar_name)

                            if os.path.exists(poscar_src):
                                shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))
                                atom_labels = self.param['files'].get('label', '')
                                kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')

                                cmd = f"{kit_path} -addlabel {poscar_name} {atom_labels}"
                                self.logger.info("Initializing model.xyz...")

                                if run_cmd_with_log(cmd, step_dir, "init.log"):
                                    self.tracker.mark_done(task_id_init)
                                else:
                                    self.logger.error("Initialization failed. Check iter_00/00.md/init.log")
                                    return
                            else:
                                self.logger.error("POSCAR missing.")
                                return
                        else:
                            self.logger.info("Skipping Init (Already Done).")
                    else:
                        # --- [new logic] Later iterations: copy model.xyz from the previous one ---

                        # 1. [Key fix] The target directory must exist before copying!
                        os.makedirs(step_dir, exist_ok=True)

                        # 2. Only copy if model.xyz is not already present
                        if not os.path.exists(os.path.join(step_dir, "model.xyz")):
                            prev_iter_name = f"iter_{iter_id - 1:02d}"
                            prev_model_src = os.path.join(self.workspace, prev_iter_name, "00.md", "model.xyz")

                            if os.path.exists(prev_model_src):
                                self.logger.info(f"Copying model.xyz from {prev_iter_name}...")
                                shutil.copy(prev_model_src, os.path.join(step_dir, "model.xyz"))
                            else:
                                self.logger.error(f"Previous model.xyz not found: {prev_model_src}")
                                return

                    # Make sure the gpumdkit path is available
                    kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')

                    # === Sub-task 1: Preheat ===
                    task_id_preheat = f"{iter_name}.00.md.preheat"
                    preheat_dir = os.path.join(step_dir, "preheat")

                    if not self.tracker.is_done(task_id_preheat):
                        self.logger.info(">>> Starting Preheat...")
                        os.makedirs(preheat_dir, exist_ok=True)

                        # Stage the input files
                        shutil.copy(os.path.join(step_dir, "model.xyz"), os.path.join(preheat_dir, "model.xyz"))
                        shutil.copy(self.current_nep_pot, os.path.join(preheat_dir, "nep.txt"))
                        shutil.copy(os.path.join(self.template_dir, "00.md", "preheat", "run.in"),
                                    os.path.join(preheat_dir, "run.in"))

                        # A. Run GPUMD (the gpumd binary runs directly, no stdin input)
                        if not run_cmd_with_log("gpumd", preheat_dir, "step_exec.log"):
                            self.logger.error("Preheat GPUMD failed.")
                            return

                        # B. Run sampling (gpumdkit function 201)
                        # [Fix] Exactly as required: "201\ndump.xyz uniform 4" (no extra newline in between)
                        input_str_201 = "201\ndump.xyz uniform 4"
                        self.logger.info(">>> Running Sampling (201)...")

                        if run_cmd_with_log(kit_path, preheat_dir, "step_exec.log", input_str=input_str_201):
                            if os.path.exists(os.path.join(preheat_dir, "sampled_structures.xyz")):
                                self.tracker.mark_done(task_id_preheat)
                            else:
                                self.logger.error("sampled_structures.xyz not generated.")
                                return
                        else:
                            self.logger.error("Sampling command failed.")
                            return
                    else:
                        self.logger.info("Skipping Preheat (Already Done).")
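
                    # Feeding input_str to gpumdkit via stdin is equivalent to the
                    # interactive shell idiom, e.g.:
                    #   echo -e "201\ndump.xyz uniform 4" | gpumdkit.sh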

                    # === Sub-task 2: Production ===
                    task_id_prod = f"{iter_name}.00.md.production"
                    prod_dir = os.path.join(step_dir, "production")

                    if not self.tracker.is_done(task_id_prod):
                        self.logger.info(">>> Starting Production...")
                        os.makedirs(prod_dir, exist_ok=True)

                        # 1. Stage the base files into the production root
                        src_sample = os.path.abspath(os.path.join(preheat_dir, "sampled_structures.xyz"))
                        dst_sample = os.path.join(prod_dir, "sampled_structures.xyz")
                        if os.path.lexists(dst_sample):  # lexists also catches stale symlinks
                            os.remove(dst_sample)
                        os.symlink(src_sample, dst_sample)

                        shutil.copy(self.current_nep_pot, os.path.join(prod_dir, "nep.txt"))
                        shutil.copy(os.path.join(self.template_dir, "00.md", "production", "run.in"),
                                    os.path.join(prod_dir, "run.in"))

                        # 2. Run function 302 (creates the md/ folder, the sample_* folders, and presub.sh)
                        # 302 normally reads run.in from the current directory and writes the
                        # split run_x.in files under md/
                        input_str_302 = "302"
                        if not run_cmd_with_log(kit_path, prod_dir, "step_exec.log", input_str=input_str_302):
                            self.logger.error("302 command failed.")
                            return

                        if not os.path.exists(os.path.join(prod_dir, "presub.sh")):
                            self.logger.error("presub.sh not found.")
                            return

                        # ---------------------------------------------------------
                        # [new] 3. Complete the file set: copy nep.txt and run.in into md/
                        # ---------------------------------------------------------
                        md_subdir = os.path.join(prod_dir, "md")
                        if os.path.exists(md_subdir):
                            self.logger.info("Copying nep.txt and run.in to 'md' folder...")
                            shutil.copy(os.path.join(prod_dir, "nep.txt"), os.path.join(md_subdir, "nep.txt"))
                            # 302 may already have generated run_1.in etc., but copy run.in
                            # over them anyway, to be safe and to match user habits
                            shutil.copy(os.path.join(prod_dir, "run.in"), os.path.join(md_subdir, "run.in"))
                            for i in range(1, 5):
                                shutil.copy(os.path.join(prod_dir, "run.in"),
                                            os.path.join(md_subdir, f"run_{i}.in"))
                        else:
                            self.logger.error("'md' folder was not created by 302 command.")
                            return

                        # 4. Run presub.sh
                        os.chmod(os.path.join(prod_dir, "presub.sh"), 0o755)
                        self.logger.info(">>> Executing presub.sh...")

                        if not run_cmd_with_log("./presub.sh", prod_dir, "step_exec.log"):
                            self.logger.error("presub.sh execution failed.")
                            return

                        # 5. Merge the per-sample dumps
                        self.logger.info("Merging dump files...")
                        run_cmd_with_log("cat sample_*/dump.xyz > dump.xyz", prod_dir, "step_exec.log")

                        self.last_dump_path = os.path.join(prod_dir, "dump.xyz")
                        self.tracker.mark_done(task_id_prod)
                    else:
                        self.logger.info("Skipping Production (Already Done).")
                        self.last_dump_path = os.path.join(prod_dir, "dump.xyz")
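
                    # Resulting layout of this step (for orientation):
                    #   iter_XX/00.md/model.xyz
                    #   iter_XX/00.md/preheat/    (model.xyz, nep.txt, run.in, sampled_structures.xyz)
                    #   iter_XX/00.md/production/ (nep.txt, run.in, md/, sample_*/, presub.sh, dump.xyz)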

                # ==========================
                # Step: 01.select
                # ==========================
                elif step_name == "01.select":
                    step_dir = os.path.join(iter_path, "01.select")
                    task_id_select = f"{iter_name}.01.select"

                    if not self.tracker.is_done(task_id_select):
                        self.logger.info(f"=== Step: 01.select ({step_conf.get('method', 'distance')}) ===")
                        os.makedirs(step_dir, exist_ok=True)

                        # 1. Stage dump.xyz (symlink)
                        # Prefer the path recorded by the previous step; otherwise infer it
                        dump_src = getattr(self, 'last_dump_path', None)
                        if not dump_src:
                            # Inferred path: iter_XX/00.md/production/dump.xyz
                            dump_src = os.path.join(iter_path, "00.md", "production", "dump.xyz")

                        if os.path.exists(dump_src):
                            dst_dump = os.path.join(step_dir, "dump.xyz")
                            if os.path.lexists(dst_dump):
                                os.remove(dst_dump)
                            # Create the symlink from an absolute path
                            os.symlink(os.path.abspath(dump_src), dst_dump)
                        else:
                            self.logger.error(f"Source dump.xyz not found: {dump_src}")
                            return

                        # 2. Stage nep.txt
                        shutil.copy(self.current_nep_pot, os.path.join(step_dir, "nep.txt"))

                        # 3. Stage train.xyz
                        # Logic: iteration 0 uses model.xyz; later iterations use self.current_train_set
                        if iter_id == 0:
                            # Iteration 0: use model.xyz from this iteration's 00.md
                            model_xyz_src = os.path.join(iter_path, "00.md", "model.xyz")
                            if os.path.exists(model_xyz_src):
                                shutil.copy(model_xyz_src, os.path.join(step_dir, "train.xyz"))
                            else:
                                self.logger.error("model.xyz not found for initial train.xyz")
                                return
                        else:
                            # Later iterations
                            if os.path.exists(self.current_train_set):
                                shutil.copy(self.current_train_set, os.path.join(step_dir, "train.xyz"))
                            else:
                                self.logger.error(f"Previous train set missing: {self.current_train_set}")
                                return

                        # 4. Run the selection logic
                        method = step_conf.get('method', 'distance')
                        params = step_conf.get('params', [0.01, 60, 120])  # [threshold, min, max]
                        kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')

                        # === Branch A: distance-based selection (iterative loop) ===
                        if method == "distance":
                            threshold = params[0]
                            target_min, target_max = params[1], params[2]
                            step_size = 0.001
                            max_attempts = 15
                            success = False
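
                            # Example with the default params [0.01, 60, 120]: start at
                            # threshold 0.01; if a pass selects fewer than 60 structures,
                            # loosen to 0.009; if it selects more than 120, tighten to
                            # 0.011; stop once the count lands in [60, 120] or after
                            # 15 attempts.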
self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}")
|
||
|
||
for attempt in range(max_attempts):
|
||
self.logger.info(f"--- Attempt {attempt + 1}: Threshold {threshold:.4f} ---")
|
||
|
||
# 构造输入: 203 -> filenames -> 1 (distance) -> threshold
|
||
# 格式: "203\ndump.xyz train.xyz nep.txt\n1\n0.01"
|
||
input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"
|
||
|
||
# 运行 gpumdkit (输出重定向到 append 模式的日志)
|
||
if not run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
|
||
self.logger.error("gpumdkit execution failed.")
|
||
break # 致命错误直接退出
|
||
|
||
# 检查 selected.xyz 数量
|
||
if not os.path.exists(os.path.join(step_dir, "selected.xyz")):
|
||
self.logger.warning("selected.xyz not found. Retrying...")
|
||
continue
|
||
|
||
try:
|
||
# grep -c "Lat" selected.xyz
|
||
count_out = subprocess.check_output('grep -c "Lat" selected.xyz', shell=True,
|
||
cwd=step_dir)
|
||
count = int(count_out.decode().strip())
|
||
except:
|
||
count = 0
|
||
|
||
self.logger.info(f"Selected count: {count}")
|
||
|
||
# 判断逻辑
|
||
if target_min <= count <= target_max:
|
||
self.logger.info(f"Success! {count} structures selected.")
|
||
success = True
|
||
break
|
||
elif count < target_min:
|
||
# 选少了 -> 阈值太严 -> 降低阈值? (通常距离阈值越低,选的越多?)
|
||
# 假设:Threshold 代表“最小允许距离”。距离 > Thr 才选。
|
||
# 那么 Thr 越小,条件越松,选的越多。
|
||
self.logger.info("Too few. Decreasing threshold (Loosening).")
|
||
threshold -= step_size
|
||
if threshold < 0: threshold = 0.0001
|
||
else: # count > target_max
|
||
# 选多了 -> 提高阈值
|
||
self.logger.info("Too many. Increasing threshold (Tightening).")
|
||
threshold += step_size
|
||
|
||
if success:
|
||
self.tracker.mark_done(task_id_select)
|
||
else:
|
||
self.logger.warning(
|
||
"Failed to reach target count within max attempts. Using last result.")
|
||
# 这里可以选择是否标记完成,或者报错暂停。目前标记完成以便继续。
|
||
self.tracker.mark_done(task_id_select)
|
||
|
||

                        # === Branch B: selection by count / random (one-shot) ===
                        elif method == "random":
                            min_n, max_n = params[1], params[2]
                            # Build the stdin: 203 -> filenames -> 2 (number) -> min max
                            input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"

                            self.logger.info(f"Random selection: {min_n}-{max_n}")
                            if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
                                self.tracker.mark_done(task_id_select)
                            else:
                                self.logger.error("Random selection failed.")
                                return

                    else:
                        self.logger.info("Skipping Select (Already Done).")

                # ==========================
                # Step: 02.scf (VASP Calculation)
                # ==========================
                elif step_name == "02.scf":
                    step_dir = os.path.join(iter_path, "02.scf")
                    task_id_scf = f"{iter_name}.02.scf"

                    if not self.tracker.is_done(task_id_scf):
                        self.logger.info("=== Step: 02.scf (VASP) ===")
                        os.makedirs(step_dir, exist_ok=True)

                        # 1. Stage selected.xyz from this iteration's 01.select
                        select_step_dir = os.path.join(iter_path, "01.select")
                        src_selected = os.path.join(select_step_dir, "selected.xyz")
                        if not os.path.exists(src_selected):
                            self.logger.error(f"selected.xyz not found in {select_step_dir}")
                            return

                        dst_selected = os.path.join(step_dir, "selected.xyz")
                        if os.path.lexists(dst_selected):
                            os.remove(dst_selected)
                        os.symlink(os.path.abspath(src_selected), dst_selected)

                        # 2. Run function 301 to split the structures
                        # Equivalent to: echo -e "301\niter" | gpumdkit.sh
                        # This creates iterX_1, iterX_2, ... plus an fp folder
                        kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
                        input_str_301 = "301\niter"  # "iter" is the folder-name prefix; gpumdkit appends the numbers

                        self.logger.info("Splitting structures (301)...")
                        if not run_cmd_with_log(kit_path, step_dir, "scf_setup.log", input_str=input_str_301):
                            self.logger.error("301 command failed.")
                            return

                        # 3. Stage the VASP input files into the 'fp' folder
                        # gpumdkit's fp folder holds the shared input files; the per-structure
                        # subfolders are symlinked to them
                        fp_dir = os.path.join(step_dir, "fp")
                        if not os.path.exists(fp_dir):
                            self.logger.error("'fp' directory was not created by 301.")
                            return

                        self.logger.info("Distributing VASP inputs to 'fp' folder...")

                        # A. POTCAR (from data/)
                        potcar_src = os.path.join(self.data_dir, self.param['files']['potcar'])
                        if os.path.exists(potcar_src):
                            shutil.copy(potcar_src, os.path.join(fp_dir, "POTCAR"))
                        else:
                            self.logger.error(f"POTCAR missing: {potcar_src}")
                            return

                        # B. INCAR (from template/02.scf/INCAR)
                        incar_src = os.path.join(self.template_dir, "02.scf", "INCAR")
                        if os.path.exists(incar_src):
                            shutil.copy(incar_src, os.path.join(fp_dir, "INCAR"))
                        else:
                            self.logger.error(f"INCAR missing in template: {incar_src}")
                            return

                        # C. KPOINTS (from template/, optional)
                        kpoints_src = os.path.join(self.template_dir, "02.scf", "KPOINTS")
                        if os.path.exists(kpoints_src):
                            shutil.copy(kpoints_src, os.path.join(fp_dir, "KPOINTS"))
                        else:
                            self.logger.info("KPOINTS not found in template, assuming KSPACING in INCAR.")
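
                        # For orientation, a minimal single-point SCF INCAR might look
                        # like the sketch below (illustrative only; the real tags come
                        # from the project's template/02.scf/INCAR):
                        #   ISMEAR = 0
                        #   SIGMA = 0.05
                        #   ENCUT = 500
                        #   NSW = 0            # single-point, no ionic steps
                        #   KSPACING = 0.25    # only if no KPOINTS file is staged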

                        # 4. Generate and submit the calculation jobs
                        # Ignore the presub.sh generated by gpumdkit here; build our own
                        # batch script from the machine.yaml configuration instead
                        executor_name = step_conf.get('executor', 'vasp_gpu')  # fallback executor

                        # Ideally this would call a batch-submission feature of the machine
                        # module (command e.g. "mpirun -np 32 vasp_std"); to keep things
                        # simple, local mode just generates a loop script
                        self.logger.info(f"Generating batch submission script for {executor_name}...")

                        # Read the command from the machine config
                        exec_conf = self.machine.config['executors'].get(executor_name, {})
                        vasp_cmd = exec_conf.get('cmd', 'mpirun -np 1 vasp_std')  # default

                        # Write run_vasp.sh
                        run_script_path = os.path.join(step_dir, "run_vasp.sh")
                        with open(run_script_path, 'w') as f:
                            f.write("#!/bin/bash\n")
                            # Loop over the iter*_* structure directories
                            f.write("for dir in iter*_*; do\n")
                            f.write("  if [ -d \"$dir\" ]; then\n")
                            f.write("    echo \"Running VASP in $dir ...\"\n")
                            f.write("    cd \"$dir\"\n")
                            # The actual VASP command, with its output redirected to a log
                            f.write(f"    {vasp_cmd} > vasp.log 2>&1\n")
                            f.write("    cd ..\n")
                            f.write("  fi\n")
                            f.write("done\n")

                        os.chmod(run_script_path, 0o755)
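
                        # The generated run_vasp.sh then reads:
                        #   #!/bin/bash
                        #   for dir in iter*_*; do
                        #     if [ -d "$dir" ]; then
                        #       echo "Running VASP in $dir ..."
                        #       cd "$dir"
                        #       mpirun -np 1 vasp_std > vasp.log 2>&1   # or the executor's cmd
                        #       cd ..
                        #     fi
                        #   done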

                        # Run the VASP calculations
                        # Note: on a Slurm cluster this should submit run_vasp.sh and wait
                        # on the job ID; in local mode we simply run it in place
                        self.logger.info(">>> Executing VASP batch calculations (this may take time)...")
                        if not run_cmd_with_log("./run_vasp.sh", step_dir, "scf_exec.log"):
                            self.logger.error("VASP batch execution failed.")
                            return

                        # 5. Collect the results (out2xyz)
                        self.logger.info("Collecting results (out2xyz)...")
                        cmd_collect = f"{kit_path} -out2xyz ."
                        if run_cmd_with_log(cmd_collect, step_dir, "scf_collect.log"):
                            # Verify the output
                            res_dir = os.path.join(step_dir, "NEPdataset-multiple_frames")
                            res_file = os.path.join(res_dir, "NEP-dataset.xyz")

                            if os.path.exists(res_file):
                                self.logger.info(f"VASP data collected: {res_file}")
                                # Remember this path for the train step
                                self.new_data_chunk = res_file
                                self.tracker.mark_done(task_id_scf)
                            else:
                                self.logger.error("NEP-dataset.xyz not found after collection.")
                        else:
                            self.logger.error("out2xyz failed.")

                    else:
                        self.logger.info("Skipping SCF (Already Done).")
                        # Even when skipping, restore self.new_data_chunk so the train step
                        # can still find the data (simple path inference)
                        res_file = os.path.join(step_dir, "NEPdataset-multiple_frames", "NEP-dataset.xyz")
                        if os.path.exists(res_file):
                            self.new_data_chunk = res_file

                # ==========================
                # Step: 03.train (Training)
                # ==========================
                elif step_name == "03.train":
                    step_dir = os.path.join(iter_path, "03.train")
                    task_id_train = f"{iter_name}.03.train"

                    if not self.tracker.is_done(task_id_train):
                        self.logger.info("=== Step: 03.train (NEP) ===")
                        os.makedirs(step_dir, exist_ok=True)

                        # 1. Build train.xyz (merge the data)
                        # Logic: current total train set = previous total train set + new data chunk
                        self.logger.info("Merging training data...")

                        train_xyz_path = os.path.join(step_dir, "train.xyz")

                        # Open the target file for writing
                        with open(train_xyz_path, 'w') as outfile:
                            # A. Append the previous data (if any)
                            if os.path.exists(self.current_train_set):
                                self.logger.info(f"Appending previous data: {self.current_train_set}")
                                with open(self.current_train_set, 'r') as infile:
                                    shutil.copyfileobj(infile, outfile)

                            # B. Append the new data (from this iteration's SCF)
                            # Use the attribute if set; otherwise infer the path
                            new_data = getattr(self, 'new_data_chunk', None)
                            if not new_data:
                                # Inferred path: iter_XX/02.scf/NEPdataset-multiple_frames/NEP-dataset.xyz
                                new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames",
                                                        "NEP-dataset.xyz")

                            if new_data and os.path.exists(new_data):
                                self.logger.info(f"Appending new data: {new_data}")
                                with open(new_data, 'r') as infile:
                                    shutil.copyfileobj(infile, outfile)
                            else:
                                if iter_id == 0 and not os.path.exists(self.current_train_set):
                                    self.logger.error("No training data available (neither previous nor new).")
                                    return
                                else:
                                    self.logger.warning("No new data found from SCF step. Training on old data only.")

                        # Point the accumulated-train-set variable at the new train.xyz
                        # so the next iteration picks it up
                        self.current_train_set = train_xyz_path

                        # 2. Stage nep.in
                        template_nep_in = os.path.join(self.template_dir, "03.train", "nep.in")
                        if os.path.exists(template_nep_in):
                            shutil.copy(template_nep_in, os.path.join(step_dir, "nep.in"))
                        else:
                            self.logger.error(f"nep.in template missing: {template_nep_in}")
                            return

                        # 3. Run the training
                        executor_name = step_conf.get('executor', 'nep_local')
                        # The nep binary normally takes no arguments and reads nep.in on its own
                        self.logger.info(f"Starting NEP training using {executor_name}...")
                        # Assumes the command is plain "nep"; adjust if machine.yaml defines
                        # a different path or launcher for this executor
                        if not run_cmd_with_log("nep", step_dir, "train_exec.log"):
                            self.logger.error("NEP training failed.")
                            return

                        # Check that nep.txt was produced
                        if os.path.exists(os.path.join(step_dir, "nep.txt")):
                            self.logger.info("Training finished. nep.txt generated.")
                            # Update the current potential path for the next iteration's MD
                            self.current_nep_pot = os.path.join(step_dir, "nep.txt")
                        else:
                            self.logger.error("nep.txt not found after training.")
                            return

                        # 4. Post-processing: plots and archiving
                        self.logger.info("Generating plots (gpumdkit.sh -plt train)...")
                        kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
                        cmd_plt = f"{kit_path} -plt train"

                        run_cmd_with_log(cmd_plt, step_dir, "plot.log")

                        # [changed] Create the output directory inside the current iteration
                        output_dir = os.path.join(iter_path, "05.output")
                        os.makedirs(output_dir, exist_ok=True)

                        # Archive the png plots (copied, not moved)
                        for file in os.listdir(step_dir):
                            if file.endswith(".png"):
                                src_png = os.path.join(step_dir, file)
                                dst_png = os.path.join(output_dir, file)
                                shutil.copy(src_png, dst_png)
                                self.logger.info(f"Archived plot: {file}")

                        self.tracker.mark_done(task_id_train)

                    else:
                        self.logger.info("Skipping Train (Already Done).")
                        # Restore the running variables so later iterations stay consistent
                        self.current_nep_pot = os.path.join(step_dir, "nep.txt")
                        self.current_train_set = os.path.join(step_dir, "train.xyz")