nep框架重构 00.md增加自定义POSCAR,03.train增加了并入之前的训练结果
This commit is contained in:
105
src/workflow.py
105
src/workflow.py
@@ -48,13 +48,41 @@ class Workflow:
|
||||
# ==========================
|
||||
if step_name == "00.md":
|
||||
step_dir = os.path.join(iter_path, "00.md")
|
||||
|
||||
# 1. 初始化 model.xyz (仅做一次)
|
||||
task_id_init = f"{iter_name}.00.md.init"
|
||||
|
||||
if iter_id == 0:
|
||||
if not self.tracker.is_done(task_id_init):
|
||||
os.makedirs(step_dir, exist_ok=True)
|
||||
if not os.path.exists(step_dir):
|
||||
os.makedirs(step_dir, exist_ok=True)
|
||||
|
||||
# 获取本轮是否定义了自定义 POSCAR
|
||||
custom_poscar = iteration.get('custom_poscar')
|
||||
|
||||
if not self.tracker.is_done(task_id_init):
|
||||
# === 情况 A: 用户指定了自定义 POSCAR (优先级最高) ===
|
||||
if custom_poscar:
|
||||
self.logger.info(f"Using Custom POSCAR for this iteration: {custom_poscar}")
|
||||
poscar_src = os.path.join(self.data_dir, custom_poscar)
|
||||
|
||||
if os.path.exists(poscar_src):
|
||||
# 复制并重命名为 config 中定义的标准名 (为了方便 gpumdkit 处理)
|
||||
std_poscar_name = self.param['files']['poscar']
|
||||
shutil.copy(poscar_src, os.path.join(step_dir, std_poscar_name))
|
||||
|
||||
# 调用 gpumdkit 转化
|
||||
atom_labels = self.param['files'].get('label', '')
|
||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||
cmd = f"{kit_path} -addlabel {std_poscar_name} {atom_labels}"
|
||||
|
||||
if run_cmd_with_log(cmd, step_dir, "init.log"):
|
||||
self.tracker.mark_done(task_id_init)
|
||||
else:
|
||||
self.logger.error("Custom POSCAR initialization failed.")
|
||||
return
|
||||
else:
|
||||
self.logger.error(f"Custom POSCAR not found in data dir: {poscar_src}")
|
||||
return
|
||||
|
||||
# === 情况 B: 第一轮且无自定义 (使用默认 POSCAR) ===
|
||||
elif iter_id == 0:
|
||||
poscar_name = self.param['files']['poscar']
|
||||
poscar_src = os.path.join(self.data_dir, poscar_name)
|
||||
|
||||
@@ -62,38 +90,33 @@ class Workflow:
|
||||
shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))
|
||||
atom_labels = self.param['files'].get('label', '')
|
||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||
|
||||
cmd = f"{kit_path} -addlabel {poscar_name} {atom_labels}"
|
||||
self.logger.info(f"Initializing model.xyz...")
|
||||
|
||||
if run_cmd_with_log(cmd, step_dir, "init.log"):
|
||||
self.tracker.mark_done(task_id_init)
|
||||
else:
|
||||
self.logger.error("Initialization failed. Check iter_00/00.md/init.log")
|
||||
self.logger.error("Initialization failed.")
|
||||
return
|
||||
else:
|
||||
self.logger.error("POSCAR missing.")
|
||||
self.logger.error("Default POSCAR missing.")
|
||||
return
|
||||
|
||||
# === 情况 C: 后续轮次且无自定义 (继承上一轮) ===
|
||||
else:
|
||||
self.logger.info("Skipping Init (Already Done).")
|
||||
else:
|
||||
# --- [新增逻辑] 后续轮次:从上一轮复制 model.xyz ---
|
||||
# 只有当 model.xyz 不存在时才去复制
|
||||
if not os.path.exists(os.path.join(step_dir, "model.xyz")):
|
||||
prev_iter_name = f"iter_{iter_id - 1:02d}"
|
||||
prev_model_src = os.path.join(self.workspace, prev_iter_name, "00.md", "model.xyz")
|
||||
|
||||
# 1. 【核心修复】必须先确保目标文件夹存在!
|
||||
if not os.path.exists(step_dir):
|
||||
os.makedirs(step_dir, exist_ok=True)
|
||||
|
||||
# 2. 检查当前目录是否已有 model.xyz
|
||||
if not os.path.exists(os.path.join(step_dir, "model.xyz")):
|
||||
prev_iter_name = f"iter_{iter_id - 1:02d}"
|
||||
prev_model_src = os.path.join(self.workspace, prev_iter_name, "00.md", "model.xyz")
|
||||
|
||||
if os.path.exists(prev_model_src):
|
||||
self.logger.info(f"Copying model.xyz from {prev_iter_name}...")
|
||||
shutil.copy(prev_model_src, os.path.join(step_dir, "model.xyz"))
|
||||
if os.path.exists(prev_model_src):
|
||||
self.logger.info(f"Copying model.xyz from {prev_iter_name}...")
|
||||
shutil.copy(prev_model_src, os.path.join(step_dir, "model.xyz"))
|
||||
self.tracker.mark_done(task_id_init) # 标记完成
|
||||
else:
|
||||
self.logger.error(f"Previous model.xyz not found: {prev_model_src}")
|
||||
return
|
||||
else:
|
||||
self.logger.error(f"Previous model.xyz not found: {prev_model_src}")
|
||||
return
|
||||
self.tracker.mark_done(task_id_init)
|
||||
# 确保 gpumdkit 路径可用
|
||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||
|
||||
@@ -548,11 +571,23 @@ class Workflow:
|
||||
with open(self.current_train_set, 'r') as infile:
|
||||
shutil.copyfileobj(infile, outfile)
|
||||
|
||||
# B. 写入新数据 (来自本轮 SCF)
|
||||
# 尝试获取变量,如果变量丢失则尝试从路径推断
|
||||
# B. [新增] 注入 data 中的额外数据 (仅在第一轮注入,防止重复)
|
||||
# 如果 extra_train_data 存在,且当前是第一轮 (或者 current_train_set 还没建立)
|
||||
extra_files = self.param['files'].get('extra_train_data', [])
|
||||
if iter_id == 0 and extra_files:
|
||||
self.logger.info(f"Injecting extra training data: {extra_files}")
|
||||
for xyz_file in extra_files:
|
||||
src_xyz = os.path.join(self.data_dir, xyz_file)
|
||||
if os.path.exists(src_xyz):
|
||||
self.logger.info(f" -> Appending {xyz_file}")
|
||||
with open(src_xyz, 'r') as infile:
|
||||
shutil.copyfileobj(infile, outfile)
|
||||
else:
|
||||
self.logger.warning(f"Extra data file not found: {src_xyz}")
|
||||
|
||||
# C. 写入新数据 (来自本轮 SCF)
|
||||
new_data = getattr(self, 'new_data_chunk', None)
|
||||
if not new_data:
|
||||
# 推断路径: iter_XX/02.scf/NEPdataset-multiple_frames/NEP-dataset.xyz
|
||||
new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames",
|
||||
"NEP-dataset.xyz")
|
||||
|
||||
@@ -561,12 +596,14 @@ class Workflow:
|
||||
with open(new_data, 'r') as infile:
|
||||
shutil.copyfileobj(infile, outfile)
|
||||
else:
|
||||
if iter_id == 0 and not os.path.exists(self.current_train_set):
|
||||
self.logger.error("No training data available (neither previous nor new).")
|
||||
# 只有在既没有旧数据,又没有额外数据,也没有新数据时才报错
|
||||
has_data = (os.path.exists(self.current_train_set) or
|
||||
(iter_id == 0 and extra_files))
|
||||
if not has_data:
|
||||
self.logger.error("No training data available at all.")
|
||||
return
|
||||
else:
|
||||
self.logger.warning("No new data found from SCF step. Training on old data only.")
|
||||
|
||||
elif not new_data or not os.path.exists(new_data):
|
||||
self.logger.warning("No new SCF data found. Training on existing/extra data only.")
|
||||
# 更新全局变量指向最新的 train.xyz,供下一轮使用
|
||||
self.current_train_set = train_xyz_path
|
||||
|
||||
|
||||
Reference in New Issue
Block a user