NEP 框架重构:00.md 增加自定义 POSCAR 支持;03.train 增加并入之前训练结果的功能
This commit is contained in:
@@ -48,13 +48,41 @@ class Workflow:
|
|||||||
# ==========================
|
# ==========================
|
||||||
if step_name == "00.md":
|
if step_name == "00.md":
|
||||||
step_dir = os.path.join(iter_path, "00.md")
|
step_dir = os.path.join(iter_path, "00.md")
|
||||||
|
|
||||||
# 1. 初始化 model.xyz (仅做一次)
|
|
||||||
task_id_init = f"{iter_name}.00.md.init"
|
task_id_init = f"{iter_name}.00.md.init"
|
||||||
|
|
||||||
if iter_id == 0:
|
if not os.path.exists(step_dir):
|
||||||
if not self.tracker.is_done(task_id_init):
|
|
||||||
os.makedirs(step_dir, exist_ok=True)
|
os.makedirs(step_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 获取本轮是否定义了自定义 POSCAR
|
||||||
|
custom_poscar = iteration.get('custom_poscar')
|
||||||
|
|
||||||
|
if not self.tracker.is_done(task_id_init):
|
||||||
|
# === 情况 A: 用户指定了自定义 POSCAR (优先级最高) ===
|
||||||
|
if custom_poscar:
|
||||||
|
self.logger.info(f"Using Custom POSCAR for this iteration: {custom_poscar}")
|
||||||
|
poscar_src = os.path.join(self.data_dir, custom_poscar)
|
||||||
|
|
||||||
|
if os.path.exists(poscar_src):
|
||||||
|
# 复制并重命名为 config 中定义的标准名 (为了方便 gpumdkit 处理)
|
||||||
|
std_poscar_name = self.param['files']['poscar']
|
||||||
|
shutil.copy(poscar_src, os.path.join(step_dir, std_poscar_name))
|
||||||
|
|
||||||
|
# 调用 gpumdkit 转化
|
||||||
|
atom_labels = self.param['files'].get('label', '')
|
||||||
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
cmd = f"{kit_path} -addlabel {std_poscar_name} {atom_labels}"
|
||||||
|
|
||||||
|
if run_cmd_with_log(cmd, step_dir, "init.log"):
|
||||||
|
self.tracker.mark_done(task_id_init)
|
||||||
|
else:
|
||||||
|
self.logger.error("Custom POSCAR initialization failed.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.logger.error(f"Custom POSCAR not found in data dir: {poscar_src}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# === 情况 B: 第一轮且无自定义 (使用默认 POSCAR) ===
|
||||||
|
elif iter_id == 0:
|
||||||
poscar_name = self.param['files']['poscar']
|
poscar_name = self.param['files']['poscar']
|
||||||
poscar_src = os.path.join(self.data_dir, poscar_name)
|
poscar_src = os.path.join(self.data_dir, poscar_name)
|
||||||
|
|
||||||
@@ -62,28 +90,20 @@ class Workflow:
|
|||||||
shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))
|
shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))
|
||||||
atom_labels = self.param['files'].get('label', '')
|
atom_labels = self.param['files'].get('label', '')
|
||||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
|
||||||
cmd = f"{kit_path} -addlabel {poscar_name} {atom_labels}"
|
cmd = f"{kit_path} -addlabel {poscar_name} {atom_labels}"
|
||||||
self.logger.info(f"Initializing model.xyz...")
|
|
||||||
|
|
||||||
if run_cmd_with_log(cmd, step_dir, "init.log"):
|
if run_cmd_with_log(cmd, step_dir, "init.log"):
|
||||||
self.tracker.mark_done(task_id_init)
|
self.tracker.mark_done(task_id_init)
|
||||||
else:
|
else:
|
||||||
self.logger.error("Initialization failed. Check iter_00/00.md/init.log")
|
self.logger.error("Initialization failed.")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.logger.error("POSCAR missing.")
|
self.logger.error("Default POSCAR missing.")
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
self.logger.info("Skipping Init (Already Done).")
|
|
||||||
else:
|
|
||||||
# --- [新增逻辑] 后续轮次:从上一轮复制 model.xyz ---
|
|
||||||
|
|
||||||
# 1. 【核心修复】必须先确保目标文件夹存在!
|
# === 情况 C: 后续轮次且无自定义 (继承上一轮) ===
|
||||||
if not os.path.exists(step_dir):
|
else:
|
||||||
os.makedirs(step_dir, exist_ok=True)
|
# 只有当 model.xyz 不存在时才去复制
|
||||||
|
|
||||||
# 2. 检查当前目录是否已有 model.xyz
|
|
||||||
if not os.path.exists(os.path.join(step_dir, "model.xyz")):
|
if not os.path.exists(os.path.join(step_dir, "model.xyz")):
|
||||||
prev_iter_name = f"iter_{iter_id - 1:02d}"
|
prev_iter_name = f"iter_{iter_id - 1:02d}"
|
||||||
prev_model_src = os.path.join(self.workspace, prev_iter_name, "00.md", "model.xyz")
|
prev_model_src = os.path.join(self.workspace, prev_iter_name, "00.md", "model.xyz")
|
||||||
@@ -91,9 +111,12 @@ class Workflow:
|
|||||||
if os.path.exists(prev_model_src):
|
if os.path.exists(prev_model_src):
|
||||||
self.logger.info(f"Copying model.xyz from {prev_iter_name}...")
|
self.logger.info(f"Copying model.xyz from {prev_iter_name}...")
|
||||||
shutil.copy(prev_model_src, os.path.join(step_dir, "model.xyz"))
|
shutil.copy(prev_model_src, os.path.join(step_dir, "model.xyz"))
|
||||||
|
self.tracker.mark_done(task_id_init) # 标记完成
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"Previous model.xyz not found: {prev_model_src}")
|
self.logger.error(f"Previous model.xyz not found: {prev_model_src}")
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
self.tracker.mark_done(task_id_init)
|
||||||
# 确保 gpumdkit 路径可用
|
# 确保 gpumdkit 路径可用
|
||||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
|
||||||
@@ -548,11 +571,23 @@ class Workflow:
|
|||||||
with open(self.current_train_set, 'r') as infile:
|
with open(self.current_train_set, 'r') as infile:
|
||||||
shutil.copyfileobj(infile, outfile)
|
shutil.copyfileobj(infile, outfile)
|
||||||
|
|
||||||
# B. 写入新数据 (来自本轮 SCF)
|
# B. [新增] 注入 data 中的额外数据 (仅在第一轮注入,防止重复)
|
||||||
# 尝试获取变量,如果变量丢失则尝试从路径推断
|
# 如果 extra_train_data 存在,且当前是第一轮 (或者 current_train_set 还没建立)
|
||||||
|
extra_files = self.param['files'].get('extra_train_data', [])
|
||||||
|
if iter_id == 0 and extra_files:
|
||||||
|
self.logger.info(f"Injecting extra training data: {extra_files}")
|
||||||
|
for xyz_file in extra_files:
|
||||||
|
src_xyz = os.path.join(self.data_dir, xyz_file)
|
||||||
|
if os.path.exists(src_xyz):
|
||||||
|
self.logger.info(f" -> Appending {xyz_file}")
|
||||||
|
with open(src_xyz, 'r') as infile:
|
||||||
|
shutil.copyfileobj(infile, outfile)
|
||||||
|
else:
|
||||||
|
self.logger.warning(f"Extra data file not found: {src_xyz}")
|
||||||
|
|
||||||
|
# C. 写入新数据 (来自本轮 SCF)
|
||||||
new_data = getattr(self, 'new_data_chunk', None)
|
new_data = getattr(self, 'new_data_chunk', None)
|
||||||
if not new_data:
|
if not new_data:
|
||||||
# 推断路径: iter_XX/02.scf/NEPdataset-multiple_frames/NEP-dataset.xyz
|
|
||||||
new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames",
|
new_data = os.path.join(iter_path, "02.scf", "NEPdataset-multiple_frames",
|
||||||
"NEP-dataset.xyz")
|
"NEP-dataset.xyz")
|
||||||
|
|
||||||
@@ -561,12 +596,14 @@ class Workflow:
|
|||||||
with open(new_data, 'r') as infile:
|
with open(new_data, 'r') as infile:
|
||||||
shutil.copyfileobj(infile, outfile)
|
shutil.copyfileobj(infile, outfile)
|
||||||
else:
|
else:
|
||||||
if iter_id == 0 and not os.path.exists(self.current_train_set):
|
# 只有在既没有旧数据,又没有额外数据,也没有新数据时才报错
|
||||||
self.logger.error("No training data available (neither previous nor new).")
|
has_data = (os.path.exists(self.current_train_set) or
|
||||||
|
(iter_id == 0 and extra_files))
|
||||||
|
if not has_data:
|
||||||
|
self.logger.error("No training data available at all.")
|
||||||
return
|
return
|
||||||
else:
|
elif not new_data or not os.path.exists(new_data):
|
||||||
self.logger.warning("No new data found from SCF step. Training on old data only.")
|
self.logger.warning("No new SCF data found. Training on existing/extra data only.")
|
||||||
|
|
||||||
# 更新全局变量指向最新的 train.xyz,供下一轮使用
|
# 更新全局变量指向最新的 train.xyz,供下一轮使用
|
||||||
self.current_train_set = train_xyz_path
|
self.current_train_set = train_xyz_path
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user