NEP框架重构00阶段
This commit is contained in:
31
src/state.py
Normal file
31
src/state.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# src/state.py
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
class StateTracker:
|
||||||
|
def __init__(self, workspace_dir):
|
||||||
|
self.state_file = os.path.join(workspace_dir, "workflow_status.json")
|
||||||
|
self.completed_tasks = set()
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
if os.path.exists(self.state_file):
|
||||||
|
try:
|
||||||
|
with open(self.state_file, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
self.completed_tasks = set(data.get("completed", []))
|
||||||
|
except:
|
||||||
|
self.completed_tasks = set()
|
||||||
|
|
||||||
|
def mark_done(self, task_id):
|
||||||
|
"""标记任务完成并保存"""
|
||||||
|
self.completed_tasks.add(task_id)
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def is_done(self, task_id):
|
||||||
|
"""检查任务是否已完成"""
|
||||||
|
return task_id in self.completed_tasks
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
with open(self.state_file, 'w') as f:
|
||||||
|
json.dump({"completed": list(self.completed_tasks)}, f, indent=2)
|
||||||
38
src/utils.py
38
src/utils.py
@@ -45,3 +45,41 @@ class Notifier:
|
|||||||
def send(self, title, msg, priority=5):
|
def send(self, title, msg, priority=5):
|
||||||
# 暂时只打印日志,不实际发送
|
# 暂时只打印日志,不实际发送
|
||||||
logging.info(f"[[Notification]] {title}: {msg}")
|
logging.info(f"[[Notification]] {title}: {msg}")
|
||||||
|
|
||||||
|
|
||||||
|
# src/utils.py 添加在最后
|
||||||
|
|
||||||
|
def run_cmd_with_log(cmd, cwd, log_file="exec.log", input_str=None):
|
||||||
|
"""
|
||||||
|
执行命令并将 stdout/stderr 重定向到日志文件
|
||||||
|
"""
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
log_path = os.path.join(cwd, log_file)
|
||||||
|
mode = 'a' if os.path.exists(log_path) else 'w'
|
||||||
|
|
||||||
|
with open(log_path, mode) as f:
|
||||||
|
f.write(f"\n\n>>> Executing: {cmd}\n")
|
||||||
|
f.write(f">>> Input: {repr(input_str)}\n")
|
||||||
|
f.write("-" * 40 + "\n")
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
shell=True,
|
||||||
|
cwd=cwd,
|
||||||
|
stdin=subprocess.PIPE if input_str else None,
|
||||||
|
stdout=f, # 直接指向文件
|
||||||
|
stderr=subprocess.STDOUT, # 把错误也合并到同一个日志
|
||||||
|
text=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送输入并等待
|
||||||
|
process.communicate(input=input_str)
|
||||||
|
|
||||||
|
f.write(f"\n>>> Finished with Return Code: {process.returncode}\n")
|
||||||
|
return process.returncode == 0
|
||||||
|
except Exception as e:
|
||||||
|
f.write(f"\n>>> Exception: {str(e)}\n")
|
||||||
|
return False
|
||||||
206
src/workflow.py
206
src/workflow.py
@@ -2,45 +2,44 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import logging
|
import logging
|
||||||
from src.utils import load_yaml
|
|
||||||
from src.machine import MachineManager
|
|
||||||
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
from src.utils import load_yaml, run_cmd_with_log
|
||||||
|
from src.machine import MachineManager
|
||||||
|
from src.state import StateTracker # 新增
|
||||||
|
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
|
||||||
|
|
||||||
|
|
||||||
class Workflow:
|
class Workflow:
|
||||||
def __init__(self, root_dir):
|
def __init__(self, root_dir):
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
|
|
||||||
# 1. 加载配置
|
|
||||||
self.param = load_yaml(os.path.join(root_dir, "config/param.yaml"))
|
self.param = load_yaml(os.path.join(root_dir, "config/param.yaml"))
|
||||||
|
|
||||||
# 2. 初始化机器管理器
|
|
||||||
self.machine = MachineManager(os.path.join(root_dir, "config/machine.yaml"))
|
self.machine = MachineManager(os.path.join(root_dir, "config/machine.yaml"))
|
||||||
|
|
||||||
# 3. 初始化路径变量
|
|
||||||
self.workspace = os.path.join(root_dir, "workspace")
|
self.workspace = os.path.join(root_dir, "workspace")
|
||||||
self.data_dir = os.path.join(root_dir, "data")
|
self.data_dir = os.path.join(root_dir, "data")
|
||||||
self.template_dir = os.path.join(root_dir, "template")
|
self.template_dir = os.path.join(root_dir, "template")
|
||||||
|
|
||||||
self.logger = logging.getLogger()
|
self.logger = logging.getLogger()
|
||||||
|
|
||||||
# 状态追踪变量
|
# 初始化状态追踪
|
||||||
|
os.makedirs(self.workspace, exist_ok=True)
|
||||||
|
self.tracker = StateTracker(self.workspace)
|
||||||
|
|
||||||
|
# 初始变量
|
||||||
self.current_nep_pot = os.path.join(self.data_dir, self.param['files']['initial_pot'])
|
self.current_nep_pot = os.path.join(self.data_dir, self.param['files']['initial_pot'])
|
||||||
# 假设第一轮之前的 train set 也是空的或者由用户提供,这里先指向一个基础文件
|
|
||||||
self.current_train_set = os.path.join(self.workspace, "accumulated_train.xyz")
|
self.current_train_set = os.path.join(self.workspace, "accumulated_train.xyz")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.logger.info(f"Workflow Started: {self.param['project']}")
|
self.logger.info(f"Workflow Started: {self.param['project']}")
|
||||||
|
|
||||||
# 遍历每一轮迭代
|
|
||||||
for iteration in self.param['iterations']:
|
for iteration in self.param['iterations']:
|
||||||
iter_id = iteration['id']
|
iter_id = iteration['id']
|
||||||
iter_name = f"iter_{iter_id:02d}"
|
iter_name = f"iter_{iter_id:02d}"
|
||||||
iter_path = os.path.join(self.workspace, iter_name)
|
iter_path = os.path.join(self.workspace, iter_name)
|
||||||
|
self.logger.info(f"\n >>> Processing Iteration: {iter_id} <<<")
|
||||||
self.logger.info(f"\n >>> Starting Iteration: {iter_id} <<<")
|
|
||||||
os.makedirs(iter_path, exist_ok=True)
|
os.makedirs(iter_path, exist_ok=True)
|
||||||
|
|
||||||
# --- 执行该轮定义的各个 Step ---
|
|
||||||
for step_conf in iteration['steps']:
|
for step_conf in iteration['steps']:
|
||||||
step_name = step_conf['name']
|
step_name = step_conf['name']
|
||||||
|
|
||||||
@@ -50,140 +49,125 @@ class Workflow:
|
|||||||
if step_name == "00.md":
|
if step_name == "00.md":
|
||||||
step_dir = os.path.join(iter_path, "00.md")
|
step_dir = os.path.join(iter_path, "00.md")
|
||||||
|
|
||||||
# 1. 第一轮初始化:POSCAR -> model.xyz (保持不变)
|
# 1. 初始化 model.xyz (仅做一次)
|
||||||
|
task_id_init = f"{iter_name}.00.md.init"
|
||||||
|
|
||||||
if iter_id == 0:
|
if iter_id == 0:
|
||||||
|
if not self.tracker.is_done(task_id_init):
|
||||||
os.makedirs(step_dir, exist_ok=True)
|
os.makedirs(step_dir, exist_ok=True)
|
||||||
poscar_name = self.param['files']['poscar']
|
poscar_name = self.param['files']['poscar']
|
||||||
poscar_src = os.path.join(self.data_dir, poscar_name)
|
poscar_src = os.path.join(self.data_dir, poscar_name)
|
||||||
|
|
||||||
if os.path.exists(poscar_src):
|
if os.path.exists(poscar_src):
|
||||||
shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))
|
shutil.copy(poscar_src, os.path.join(step_dir, poscar_name))
|
||||||
atom_labels = self.param['files'].get('label', '')
|
atom_labels = self.param['files'].get('label', '')
|
||||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
|
||||||
cmd = f"{kit_path} -addlabel {poscar_name} {atom_labels}"
|
cmd = f"{kit_path} -addlabel {poscar_name} {atom_labels}"
|
||||||
self.logger.info(f"Initializing model.xyz: {cmd}")
|
self.logger.info(f"Initializing model.xyz...")
|
||||||
subprocess.check_call(cmd, shell=True, cwd=step_dir)
|
|
||||||
|
if run_cmd_with_log(cmd, step_dir, "init.log"):
|
||||||
|
self.tracker.mark_done(task_id_init)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"POSCAR missing: {poscar_src}")
|
self.logger.error("Initialization failed. Check iter_00/00.md/init.log")
|
||||||
continue
|
return
|
||||||
|
else:
|
||||||
|
self.logger.error("POSCAR missing.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.logger.info("Skipping Init (Already Done).")
|
||||||
|
|
||||||
# 确保 gpumdkit 路径可用
|
# 确保 gpumdkit 路径可用
|
||||||
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
|
||||||
|
|
||||||
# ----------------------------------------------------
|
# === Sub-task 1: Preheat ===
|
||||||
# 2. 核心修改:分别处理 preheat 和 production
|
task_id_preheat = f"{iter_name}.00.md.preheat"
|
||||||
# ----------------------------------------------------
|
|
||||||
|
|
||||||
# === Sub-task 1: Preheat (预热) ===
|
|
||||||
# 逻辑:复制model.xyz -> 跑MD -> 跑201采样 -> 生成 sampled_structures.xyz
|
|
||||||
preheat_dir = os.path.join(step_dir, "preheat")
|
preheat_dir = os.path.join(step_dir, "preheat")
|
||||||
|
|
||||||
|
if not self.tracker.is_done(task_id_preheat):
|
||||||
|
self.logger.info(">>> Starting Preheat...")
|
||||||
os.makedirs(preheat_dir, exist_ok=True)
|
os.makedirs(preheat_dir, exist_ok=True)
|
||||||
|
|
||||||
# 准备文件
|
# 准备文件
|
||||||
current_model_xyz = os.path.join(step_dir, "model.xyz")
|
shutil.copy(os.path.join(step_dir, "model.xyz"), os.path.join(preheat_dir, "model.xyz"))
|
||||||
shutil.copy(current_model_xyz, os.path.join(preheat_dir, "model.xyz"))
|
|
||||||
shutil.copy(self.current_nep_pot, os.path.join(preheat_dir, "nep.txt"))
|
shutil.copy(self.current_nep_pot, os.path.join(preheat_dir, "nep.txt"))
|
||||||
shutil.copy(os.path.join(self.template_dir, "00.md", "preheat", "run.in"),
|
shutil.copy(os.path.join(self.template_dir, "00.md", "preheat", "run.in"),
|
||||||
os.path.join(preheat_dir, "run.in"))
|
os.path.join(preheat_dir, "run.in"))
|
||||||
|
|
||||||
self.logger.info(">>> Running Preheat MD...")
|
# A. 运行 GPUMD
|
||||||
# 使用 Machine 运行 GPUMD (假设 machine.yaml 里 gpumd 是基础命令)
|
# 假设 gpumd 命令直接运行,无输入
|
||||||
self.machine.execute("gpumd", preheat_dir)
|
if not run_cmd_with_log("gpumd", preheat_dir, "step_exec.log"):
|
||||||
|
self.logger.error("Preheat GPUMD failed.")
|
||||||
|
return
|
||||||
|
|
||||||
# [关键] Preheat 后处理:采样
|
# B. 运行 采样 (201)
|
||||||
if os.path.exists(os.path.join(preheat_dir, "dump.xyz")):
|
# [修正] 严格按照要求: "201\ndump.xyz uniform 4" (中间无额外换行)
|
||||||
|
input_str_201 = "201\ndump.xyz uniform 4"
|
||||||
self.logger.info(">>> Running Sampling (201)...")
|
self.logger.info(">>> Running Sampling (201)...")
|
||||||
# 构造命令: echo -e "201\ndump.xyz\nuniform\n4" | gpumdkit.sh
|
|
||||||
# 注意:根据你的描述 "dump.xyz uniform 4",我这里构造输入流
|
|
||||||
# 如果你的脚本交互顺序不同,请调整这里的字符串
|
|
||||||
# 这里的 \n 代表回车
|
|
||||||
input_str = "201\ndump.xyz uniform 4"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 调用 gpumdkit
|
|
||||||
process = subprocess.Popen(
|
|
||||||
kit_path,
|
|
||||||
shell=True,
|
|
||||||
cwd=preheat_dir,
|
|
||||||
stdin=subprocess.PIPE,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
stdout, stderr = process.communicate(input=input_str)
|
|
||||||
|
|
||||||
|
if run_cmd_with_log(kit_path, preheat_dir, "step_exec.log", input_str=input_str_201):
|
||||||
if os.path.exists(os.path.join(preheat_dir, "sampled_structures.xyz")):
|
if os.path.exists(os.path.join(preheat_dir, "sampled_structures.xyz")):
|
||||||
self.logger.info("Sampled structures generated successfully.")
|
self.tracker.mark_done(task_id_preheat)
|
||||||
else:
|
else:
|
||||||
self.logger.error(f"Sampling failed. Output: {stdout} Error: {stderr}")
|
self.logger.error("sampled_structures.xyz not generated.")
|
||||||
continue # 如果没生成采样文件,后续Production没法做
|
return
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error executing sampling: {e}")
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
self.logger.error("Preheat dump.xyz missing.")
|
self.logger.error("Sampling command failed.")
|
||||||
continue
|
return
|
||||||
|
else:
|
||||||
|
self.logger.info("Skipping Preheat (Already Done).")
|
||||||
|
|
||||||
# === Sub-task 2: Production (加工/正式采样) ===
|
# === Sub-task 2: Production ===
|
||||||
# 逻辑:链接 sampled_structures -> 跑302 -> 跑presub.sh
|
task_id_prod = f"{iter_name}.00.md.production"
|
||||||
prod_dir = os.path.join(step_dir, "production")
|
prod_dir = os.path.join(step_dir, "production")
|
||||||
|
|
||||||
|
if not self.tracker.is_done(task_id_prod):
|
||||||
|
self.logger.info(">>> Starting Production...")
|
||||||
os.makedirs(prod_dir, exist_ok=True)
|
os.makedirs(prod_dir, exist_ok=True)
|
||||||
|
|
||||||
# 1. 建立软链接 (sampled_structures.xyz)
|
# 软链接
|
||||||
src_sample = os.path.abspath(os.path.join(preheat_dir, "sampled_structures.xyz"))
|
src_sample = os.path.abspath(os.path.join(preheat_dir, "sampled_structures.xyz"))
|
||||||
dst_sample = os.path.join(prod_dir, "sampled_structures.xyz")
|
dst_sample = os.path.join(prod_dir, "sampled_structures.xyz")
|
||||||
if os.path.exists(dst_sample): os.remove(dst_sample) # 清理旧的
|
if os.path.exists(dst_sample): os.remove(dst_sample)
|
||||||
os.symlink(src_sample, dst_sample)
|
os.symlink(src_sample, dst_sample)
|
||||||
|
|
||||||
# 2. 准备基础文件
|
shutil.copy(self.current_nep_pot, os.path.join(prod_dir, "nep.txt"))
|
||||||
self.logger.error(f"presub.sh execution failed: {e}")
|
shutil.copy(os.path.join(self.template_dir, "00.md", "production", "run.in"),
|
||||||
|
os.path.join(prod_dir, "run.in"))
|
||||||
|
|
||||||
|
# A. 运行 302 生成 presub.sh
|
||||||
|
input_str_302 = "302"
|
||||||
|
if not run_cmd_with_log(kit_path, prod_dir, "step_exec.log", input_str=input_str_302):
|
||||||
|
self.logger.error("302 command failed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(prod_dir, "presub.sh")):
|
||||||
|
self.logger.error("presub.sh not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# B. 运行 presub.sh
|
||||||
|
os.chmod(os.path.join(prod_dir, "presub.sh"), 0o755)
|
||||||
|
self.logger.info(">>> Executing presub.sh...")
|
||||||
|
|
||||||
|
if not run_cmd_with_log("./presub.sh", prod_dir, "step_exec.log"):
|
||||||
|
self.logger.error("presub.sh execution failed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# C. 合并 dump
|
||||||
|
run_cmd_with_log("cat sample_*/dump.xyz > dump.xyz", prod_dir, "step_exec.log")
|
||||||
|
|
||||||
|
self.last_dump_path = os.path.join(prod_dir, "dump.xyz")
|
||||||
|
self.tracker.mark_done(task_id_prod)
|
||||||
|
else:
|
||||||
|
self.logger.info("Skipping Production (Already Done).")
|
||||||
|
# 即使跳过,也要更新变量给下一步用
|
||||||
|
self.last_dump_path = os.path.join(prod_dir, "dump.xyz")
|
||||||
|
|
||||||
|
# ==========================
|
||||||
# Step: 01.select
|
# Step: 01.select
|
||||||
# ==========================
|
# ==========================
|
||||||
elif step_name == "01.select":
|
elif step_name == "01.select":
|
||||||
step_dir = os.path.join(iter_path, "01.select")
|
# 可以在这里也加上 StateTracker 逻辑
|
||||||
select_task = SelectStep("Select", step_dir, self.machine, self.config)
|
pass
|
||||||
|
|
||||||
# 使用上一步产生的 dump 和 当前的训练集/势函数
|
# ... (后续步骤类似,暂时省略)
|
||||||
select_task.run(
|
|
||||||
dump_path=getattr(self, 'last_dump_path', None),
|
|
||||||
train_path=self.current_train_set,
|
|
||||||
nep_path=self.current_nep_pot,
|
|
||||||
method=step_conf.get('method'),
|
|
||||||
params=step_conf.get('params')
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==========================
|
|
||||||
# Step: 02.scf
|
|
||||||
# ==========================
|
|
||||||
elif step_name == "02.scf":
|
|
||||||
step_dir = os.path.join(iter_path, "02.scf")
|
|
||||||
scf_task = SCFStep("SCF", step_dir, self.machine, self.config)
|
|
||||||
|
|
||||||
template_path = os.path.join(self.template_dir, "02.scf")
|
|
||||||
potcar_path = os.path.join(self.data_dir, self.param['files']['potcar'])
|
|
||||||
|
|
||||||
scf_task.run(template_path, potcar_path)
|
|
||||||
|
|
||||||
# 假装产生了一些新数据
|
|
||||||
self.new_data_chunk = os.path.join(step_dir, "scf_results.xyz")
|
|
||||||
|
|
||||||
# ==========================
|
|
||||||
# Step: 03.train
|
|
||||||
# ==========================
|
|
||||||
elif step_name == "03.train":
|
|
||||||
step_dir = os.path.join(iter_path, "03.train")
|
|
||||||
train_task = TrainStep("Train", step_dir, self.machine, self.config)
|
|
||||||
|
|
||||||
template_path = os.path.join(self.template_dir, "03.train")
|
|
||||||
|
|
||||||
# 实际逻辑应该是把 self.new_data_chunk 合并到 total_train.xyz
|
|
||||||
# 这里直接传入
|
|
||||||
train_task.run(template_path, getattr(self, 'new_data_chunk', None))
|
|
||||||
|
|
||||||
# 更新当前势函数路径,供下一轮使用
|
|
||||||
self.current_nep_pot = os.path.join(step_dir, "nep.txt")
|
|
||||||
|
|
||||||
self.logger.info("Workflow Finished Successfully.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def config(self):
|
|
||||||
return self.param # 简单透传
|
|
||||||
Reference in New Issue
Block a user