154 lines
6.8 KiB
Python
154 lines
6.8 KiB
Python
# src/workflow.py
|
||
import os
|
||
import shutil
|
||
import logging
|
||
from src.utils import load_yaml
|
||
from src.machine import MachineManager
|
||
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
|
||
|
||
|
||
class Workflow:
    """Drive the iterative workflow described by ``config/param.yaml``.

    Each entry of ``param['iterations']`` gets its own ``iter_XX`` directory
    under ``<root_dir>/workspace``; the steps listed for that iteration
    (``00.md``, ``01.select``, ``02.scf``, ``03.train``) are executed in the
    order they appear in the configuration.
    """

    def __init__(self, root_dir):
        """Load the configuration and derive the working paths.

        Args:
            root_dir: Project root. ``config/param.yaml`` and
                ``config/machine.yaml`` are read from it, and the
                ``workspace``, ``data`` and ``template`` paths are derived
                from it.
        """
        self.root_dir = root_dir

        # 1. Load the workflow parameters.
        self.param = load_yaml(os.path.join(root_dir, "config/param.yaml"))

        # 2. Initialise the machine manager.
        self.machine = MachineManager(os.path.join(root_dir, "config/machine.yaml"))

        # 3. Initialise the path variables.
        self.workspace = os.path.join(root_dir, "workspace")
        self.data_dir = os.path.join(root_dir, "data")
        self.template_dir = os.path.join(root_dir, "template")
        self.logger = logging.getLogger()

        # State tracked across steps and iterations.
        self.current_nep_pot = os.path.join(self.data_dir, self.param['files']['initial_pot'])
        # The training set before the first iteration is assumed to be empty
        # or user-provided; point at a base file for now.
        self.current_train_set = os.path.join(self.workspace, "accumulated_train.xyz")
        # Declared up front so later steps can read them without getattr().
        self.last_dump_path = None   # consumed by the 01.select step
        self.new_data_chunk = None   # set by 02.scf, consumed by 03.train

    def run(self):
        """Execute every configured iteration and its steps in order.

        Returns:
            None. When the first-iteration model initialisation fails, the
            error is logged and the method returns early without logging
            the final success message.
        """
        self.logger.info(f"Workflow Started: {self.param['project']}")

        # Step name -> handler. Each handler returns False to abort the run.
        handlers = {
            "00.md": self._run_md_step,
            "01.select": self._run_select_step,
            "02.scf": self._run_scf_step,
            "03.train": self._run_train_step,
        }

        for iteration in self.param['iterations']:
            iter_id = iteration['id']
            iter_path = os.path.join(self.workspace, f"iter_{iter_id:02d}")

            self.logger.info(f"\n >>> Starting Iteration: {iter_id} <<<")
            os.makedirs(iter_path, exist_ok=True)

            # Run each step defined for this iteration.
            for step_conf in iteration['steps']:
                step_name = step_conf['name']
                handler = handlers.get(step_name)
                if handler is None:
                    # Previously ignored silently; make the typo visible.
                    self.logger.warning(
                        "Skipping unknown step %r in iteration %s", step_name, iter_id
                    )
                    continue
                if not handler(iter_id, iter_path, step_conf):
                    # The handler has already logged why it failed.
                    return

        self.logger.info("Workflow Finished Successfully.")

    def _run_md_step(self, iter_id, iter_path, step_conf):
        """Handle step ``00.md``; return False if the workflow must abort."""
        step_dir = os.path.join(iter_path, "00.md")

        # First iteration only: convert the initial POSCAR into model.xyz.
        if iter_id == 0 and not self._init_model_xyz(step_dir):
            return False

        # Walk the sub-tasks (preheat, production, ...). ``or []`` tolerates
        # a ``sub_tasks:`` key that is present but empty in the YAML.
        for sub in step_conf.get('sub_tasks') or []:
            template_sub_name = sub['template_sub']
            sub_work_dir = os.path.join(step_dir, template_sub_name)
            template_path = os.path.join(self.template_dir, "00.md", template_sub_name)
            # NOTE(review): only the paths are computed here; no MD task is
            # launched and ``last_dump_path`` is never updated, so 01.select
            # receives ``dump_path=None``. Wire in MDStep once its interface
            # is confirmed.
            self.logger.debug(
                "MD sub-task %s: work_dir=%s, template=%s",
                template_sub_name, sub_work_dir, template_path,
            )
        return True

    def _init_model_xyz(self, step_dir):
        """Copy the initial POSCAR into *step_dir* and convert it to model.xyz.

        Returns:
            True when ``model.xyz`` exists in *step_dir* afterwards; False
            (after logging the reason) otherwise.
        """
        # Local imports keep this block self-contained: the module header
        # does not import these names.
        import shlex
        import subprocess

        os.makedirs(step_dir, exist_ok=True)

        # Resolve the file name and source path, then copy the original POSCAR.
        poscar_name = self.param['files']['poscar']
        poscar_src = os.path.join(self.data_dir, poscar_name)
        if not os.path.exists(poscar_src):
            self.logger.error(f"Initial POSCAR not found: {poscar_src}")
            return False
        shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))

        # Atom labels passed to the conversion tool.
        atom_labels = self.param['files'].get('label', '')
        if not atom_labels:
            self.logger.error("Missing 'label' in param.yaml files section.")
            return False
        if isinstance(atom_labels, (list, tuple)):
            label_args = [str(label) for label in atom_labels]
        else:
            label_args = shlex.split(str(atom_labels))

        # Conversion tool; may be a bare name on PATH, a path, or a command
        # with leading arguments (split the way a shell would).
        kit_path = self.machine.config['paths'].get('gpumdkit') or 'gpumdkit.sh'
        kit_argv = shlex.split(str(kit_path))
        if not kit_argv:
            self.logger.error("Empty 'gpumdkit' path in machine configuration.")
            return False

        executable = os.path.expandvars(os.path.expanduser(kit_argv[0]))
        if not os.path.isabs(executable):
            # The command runs with cwd=step_dir, so a path relative to the
            # project root must be made absolute; otherwise rely on PATH.
            candidate = os.path.join(self.root_dir, executable)
            if os.path.isfile(candidate):
                executable = os.path.abspath(candidate)
        kit_argv[0] = executable

        # An argv list (no shell) keeps labels and file names from being
        # interpreted as shell syntax.
        cmd = [*kit_argv, "-addlabel", poscar_name, *label_args]
        printable_cmd = " ".join(shlex.quote(arg) for arg in cmd)

        self.logger.info(f"Initializing model.xyz with command: {printable_cmd}")
        try:
            subprocess.check_call(cmd, cwd=step_dir)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing or non-executable tool, which the
            # shell used to report as a non-zero exit status.
            self.logger.error(f"Failed to convert POSCAR: {e}")
            return False

        if not os.path.exists(os.path.join(step_dir, "model.xyz")):
            self.logger.error("model.xyz was not generated.")
            return False

        self.logger.info("Successfully generated model.xyz")
        return True

    def _run_select_step(self, iter_id, iter_path, step_conf):
        """Handle step ``01.select``; always returns True."""
        step_dir = os.path.join(iter_path, "01.select")
        select_task = SelectStep("Select", step_dir, self.machine, self.config)

        # Use the dump from the previous step together with the current
        # training set and potential.
        select_task.run(
            dump_path=self.last_dump_path,
            train_path=self.current_train_set,
            nep_path=self.current_nep_pot,
            method=step_conf.get('method'),
            params=step_conf.get('params'),
        )
        return True

    def _run_scf_step(self, iter_id, iter_path, step_conf):
        """Handle step ``02.scf``; always returns True."""
        step_dir = os.path.join(iter_path, "02.scf")
        scf_task = SCFStep("SCF", step_dir, self.machine, self.config)

        template_path = os.path.join(self.template_dir, "02.scf")
        potcar_path = os.path.join(self.data_dir, self.param['files']['potcar'])

        scf_task.run(template_path, potcar_path)

        # Placeholder: pretend the step produced some new data.
        self.new_data_chunk = os.path.join(step_dir, "scf_results.xyz")
        return True

    def _run_train_step(self, iter_id, iter_path, step_conf):
        """Handle step ``03.train``; always returns True."""
        step_dir = os.path.join(iter_path, "03.train")
        train_task = TrainStep("Train", step_dir, self.machine, self.config)

        template_path = os.path.join(self.template_dir, "03.train")

        # The real logic should merge self.new_data_chunk into the
        # accumulated training set; for now it is passed straight through.
        train_task.run(template_path, self.new_data_chunk)

        # Update the current potential path for the next iteration.
        self.current_nep_pot = os.path.join(step_dir, "nep.txt")
        return True

    @property
    def config(self):
        """Return the loaded parameter dictionary (simple pass-through)."""
        return self.param