Files
NEP-auto/src/workflow.py
2025-12-09 09:39:32 +08:00

154 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# src/workflow.py
import logging
import os
import shlex
import shutil
import subprocess

from src.machine import MachineManager
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
from src.utils import load_yaml
class Workflow:
    """Drive the iterative NEP workflow configured in ``config/param.yaml``.

    Each entry of ``param['iterations']`` runs inside
    ``<root_dir>/workspace/iter_XX`` and executes the steps listed under its
    ``steps`` key (``00.md``, ``01.select``, ``02.scf``, ``03.train``).
    State passed between steps and iterations (current NEP potential, newest
    SCF data chunk, last MD dump) is stored on the instance.
    """

    def __init__(self, root_dir):
        """Load configuration and set up the directory layout.

        Args:
            root_dir: Project root containing ``config/param.yaml``,
                ``config/machine.yaml``, ``data/`` and ``template/``;
                ``workspace/`` is used for per-iteration output.
        """
        self.root_dir = root_dir
        # 1. Workflow parameters.
        self.param = load_yaml(os.path.join(root_dir, "config/param.yaml"))
        # 2. Machine manager (its ``config['paths']`` also supplies tool paths).
        self.machine = MachineManager(os.path.join(root_dir, "config/machine.yaml"))
        # 3. Directory layout.
        self.workspace = os.path.join(root_dir, "workspace")
        self.data_dir = os.path.join(root_dir, "data")
        self.template_dir = os.path.join(root_dir, "template")
        self.logger = logging.getLogger()
        # Potential used by the next selection; starts as the user-supplied one
        # and is replaced by each 03.train step.
        self.current_nep_pot = os.path.join(self.data_dir, self.param['files']['initial_pot'])
        # Training set before the first iteration is assumed empty or provided
        # by the user. NOTE(review): run() never updates this path yet.
        self.current_train_set = os.path.join(self.workspace, "accumulated_train.xyz")

    def run(self):
        """Execute every configured iteration and its steps in order.

        Returns ``None``. If initializing the first MD step fails, an error is
        logged and the workflow stops early without the "finished" message.
        Unknown step names are logged as warnings and skipped.
        """
        self.logger.info(f"Workflow Started: {self.param['project']}")
        for iteration in self.param['iterations']:
            iter_id = iteration['id']
            iter_name = f"iter_{iter_id:02d}"
            iter_path = os.path.join(self.workspace, iter_name)
            self.logger.info(f"\n >>> Starting Iteration: {iter_id} <<<")
            os.makedirs(iter_path, exist_ok=True)

            for step_conf in iteration['steps']:
                step_name = step_conf['name']
                if step_name == "00.md":
                    if not self._run_md(iter_id, iter_path, step_conf):
                        return
                elif step_name == "01.select":
                    self._run_select(iter_path, step_conf)
                elif step_name == "02.scf":
                    self._run_scf(iter_path)
                elif step_name == "03.train":
                    self._run_train(iter_path)
                else:
                    self.logger.warning(
                        f"Unknown step '{step_name}' in iteration {iter_id}; skipping."
                    )
        self.logger.info("Workflow Finished Successfully.")

    def _run_md(self, iter_id, iter_path, step_conf):
        """Prepare the ``00.md`` step; return ``False`` on a fatal error."""
        step_dir = os.path.join(iter_path, "00.md")
        # Create the step directory for every iteration, not only the first.
        os.makedirs(step_dir, exist_ok=True)

        # First iteration: build model.xyz from the user's POSCAR.
        if iter_id == 0 and not self._init_model_xyz(step_dir):
            return False
        # NOTE(review): for iter_id != 0 nothing places a model.xyz in
        # step_dir here — TODO confirm where later iterations get a structure.

        # Sub-tasks (e.g. preheat, production).
        for sub in step_conf.get('sub_tasks', []):
            template_sub_name = sub['template_sub']
            sub_work_dir = os.path.join(step_dir, template_sub_name)
            template_path = os.path.join(self.template_dir, "00.md", template_sub_name)
            # TODO: run MDStep in sub_work_dir from template_path and record
            # self.last_dump_path for 01.select; not implemented yet.
        return True

    def _init_model_xyz(self, step_dir):
        """Copy the initial POSCAR into *step_dir* and convert it to ``model.xyz``.

        Runs ``<gpumdkit> -addlabel <poscar> <labels...>`` through the shell
        with ``cwd=step_dir``. Returns ``True`` if ``model.xyz`` exists
        afterwards; otherwise logs an error and returns ``False``.
        """
        poscar_name = self.param['files']['poscar']
        poscar_src = os.path.join(self.data_dir, poscar_name)
        if not os.path.exists(poscar_src):
            self.logger.error(f"Initial POSCAR not found: {poscar_src}")
            return False
        shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))

        # Accept either a whitespace-separated string or a YAML list of labels.
        raw_labels = self.param['files'].get('label', '')
        if isinstance(raw_labels, (list, tuple)):
            labels = [str(lbl) for lbl in raw_labels]
        elif raw_labels:
            labels = str(raw_labels).split()
        else:
            labels = []
        if not labels:
            self.logger.error("Missing 'label' in param.yaml files section.")
            return False

        kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
        # The command runs with cwd=step_dir, so resolve a project-relative
        # tool path against root_dir; anything else is left to the shell/PATH.
        if not os.path.isabs(kit_path):
            candidate = os.path.join(self.root_dir, kit_path)
            if os.path.isfile(candidate):
                kit_path = os.path.abspath(candidate)

        # kit_path is intentionally unquoted (it may be a command with
        # arguments); the file name and labels are quoted as single arguments.
        cmd = " ".join(
            [kit_path, "-addlabel", shlex.quote(poscar_name)]
            + [shlex.quote(lbl) for lbl in labels]
        )
        self.logger.info(f"Initializing model.xyz with command: {cmd}")
        try:
            subprocess.check_call(cmd, shell=True, cwd=step_dir)
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Failed to convert POSCAR: {e}")
            return False

        if not os.path.exists(os.path.join(step_dir, "model.xyz")):
            self.logger.error("model.xyz was not generated.")
            return False
        self.logger.info("Successfully generated model.xyz")
        return True

    def _run_select(self, iter_path, step_conf):
        """Run ``01.select`` on the last MD dump with the current train set/potential."""
        step_dir = os.path.join(iter_path, "01.select")
        select_task = SelectStep("Select", step_dir, self.machine, self.config)
        dump_path = getattr(self, 'last_dump_path', None)
        if dump_path is None:
            self.logger.warning("No MD dump recorded before 01.select; passing dump_path=None.")
        select_task.run(
            dump_path=dump_path,
            train_path=self.current_train_set,
            nep_path=self.current_nep_pot,
            method=step_conf.get('method'),
            params=step_conf.get('params'),
        )

    def _run_scf(self, iter_path):
        """Run ``02.scf`` and remember where its new data chunk is expected."""
        step_dir = os.path.join(iter_path, "02.scf")
        scf_task = SCFStep("SCF", step_dir, self.machine, self.config)
        template_path = os.path.join(self.template_dir, "02.scf")
        potcar_path = os.path.join(self.data_dir, self.param['files']['potcar'])
        scf_task.run(template_path, potcar_path)
        # Placeholder: assumes SCFStep writes scf_results.xyz — TODO confirm.
        self.new_data_chunk = os.path.join(step_dir, "scf_results.xyz")

    def _run_train(self, iter_path):
        """Run ``03.train`` and point the next iteration at the new potential."""
        step_dir = os.path.join(iter_path, "03.train")
        train_task = TrainStep("Train", step_dir, self.machine, self.config)
        template_path = os.path.join(self.template_dir, "03.train")
        # TODO: merge self.new_data_chunk into the accumulated training set;
        # for now the chunk is passed directly.
        train_task.run(template_path, getattr(self, 'new_data_chunk', None))
        # Potential used by subsequent steps/iterations.
        self.current_nep_pot = os.path.join(step_dir, "nep.txt")

    @property
    def config(self):
        """Alias for the parsed ``param.yaml`` dictionary (simple pass-through)."""
        return self.param