nep框架重构
This commit is contained in:
177
src/steps.py
Normal file
177
src/steps.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# src/steps.py
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
|
||||
class BaseStep:
|
||||
def __init__(self, name, work_dir, machine_manager, config):
|
||||
self.name = name
|
||||
self.work_dir = work_dir
|
||||
self.machine = machine_manager
|
||||
self.config = config
|
||||
os.makedirs(self.work_dir, exist_ok=True)
|
||||
self.logger = logging.getLogger()
|
||||
|
||||
def copy_file(self, src, dst_name=None):
|
||||
"""辅助函数:安全复制文件"""
|
||||
if not os.path.exists(src):
|
||||
self.logger.error(f"[{self.name}] Source file missing: {src}")
|
||||
return False
|
||||
|
||||
dst_name = dst_name if dst_name else os.path.basename(src)
|
||||
dst_path = os.path.join(self.work_dir, dst_name)
|
||||
shutil.copy(src, dst_path)
|
||||
return dst_path
|
||||
|
||||
|
||||
class MDStep(BaseStep):
|
||||
"""
|
||||
对应 00.md: 负责预热/采样
|
||||
"""
|
||||
|
||||
def run(self, prev_nep_path, template_path):
|
||||
self.logger.info(f"=== Running Step: {self.name} (MD) ===")
|
||||
|
||||
# 1. 准备 nep.txt (来自上一轮或初始数据)
|
||||
if not prev_nep_path:
|
||||
self.logger.error("No nep.txt provided for MD.")
|
||||
return False
|
||||
self.copy_file(prev_nep_path, "nep.txt")
|
||||
|
||||
# 2. 准备 model.xyz (如果是第一轮,这里假设外部已经放好了,或者由init步生成)
|
||||
# 为了简化,我们假设上一级流程已经把 model.xyz 准备在 work_dir 或者由上一轮传递
|
||||
# 这里我们假设 model.xyz 必须存在于 work_dir (可以通过 init 步骤拷入)
|
||||
if not os.path.exists(os.path.join(self.work_dir, "model.xyz")):
|
||||
self.logger.warning(f"[{self.name}] model.xyz not found in {self.work_dir}. Make sure Init step ran.")
|
||||
|
||||
# 3. 准备 run.in (从 template 复制)
|
||||
run_in_src = os.path.join(template_path, "run.in")
|
||||
self.copy_file(run_in_src, "run.in")
|
||||
|
||||
# 4. 调用 Machine 执行 GPUMD
|
||||
# 注意:这里我们调用 machine.yaml 里定义的 'gpumd' 执行器
|
||||
success = self.machine.execute("gpumd", self.work_dir)
|
||||
|
||||
if success and os.path.exists(os.path.join(self.work_dir, "dump.xyz")):
|
||||
self.logger.info(f"[{self.name}] MD finished. dump.xyz generated.")
|
||||
return True
|
||||
else:
|
||||
self.logger.error(f"[{self.name}] MD failed or dump.xyz missing.")
|
||||
return False
|
||||
|
||||
|
||||
class SelectStep(BaseStep):
|
||||
"""
|
||||
对应 01.select: 智能筛选
|
||||
"""
|
||||
|
||||
def run(self, dump_path, train_path, nep_path, method="distance", params=[0.01, 60, 120]):
|
||||
self.logger.info(f"=== Running Step: {self.name} (Smart Selection) ===")
|
||||
|
||||
# 准备文件
|
||||
self.copy_file(dump_path, "dump.xyz")
|
||||
self.copy_file(train_path, "train.xyz")
|
||||
self.copy_file(nep_path, "nep.txt")
|
||||
|
||||
target_min, target_max = params[1], params[2]
|
||||
threshold = params[0]
|
||||
step_size = 0.001 # 每次调整的步长
|
||||
|
||||
# 你的流程里是用 gpumdkit.sh 做筛选 (option 203)
|
||||
# 这里的命令构造需要非常小心,模拟你的 echo输入
|
||||
# 假设 gpumdkit.sh 在 PATH 中,或者通过 machine config 获取路径
|
||||
# 由于我们现在是 local 调试,假设你依然依赖 gpumdkit.sh
|
||||
# 但既然我们写 Python,建议未来把筛选逻辑(计算距离)直接写成 Python 代码。
|
||||
# 这里暂时模拟调用逻辑:
|
||||
|
||||
for i in range(10): # 最多尝试10次
|
||||
self.logger.info(f"Selection attempt {i + 1}: Threshold={threshold:.4f}")
|
||||
|
||||
# 构造输入字符串: 203 -> file names -> 1 (distance) -> threshold
|
||||
# 注意:这里假设 gpumdkit.sh 能接受这种输入
|
||||
# 为了调试方便,这里我们暂时只打日志,不真的调 gpumdkit (因为它需要真实的数据文件)
|
||||
# 在真实运行中,这里应该调用:
|
||||
# input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}\n"
|
||||
# subprocess.run("gpumdkit.sh", input=input_str, cwd=self.work_dir...)
|
||||
|
||||
# --- 模拟代码 Start ---
|
||||
# 假设生成了一个假的 selected.xyz
|
||||
with open(os.path.join(self.work_dir, "selected.xyz"), 'w') as f:
|
||||
# 模拟根据阈值,阈值越小选的越多
|
||||
mock_count = int(100 / (threshold * 100))
|
||||
f.write(f"Mock selected {mock_count} frames")
|
||||
|
||||
selected_count = mock_count
|
||||
self.logger.info(f"Found {selected_count} structures (Mock).")
|
||||
# --- 模拟代码 End ---
|
||||
|
||||
if target_min <= selected_count <= target_max:
|
||||
self.logger.info(f"Selection Success! Final count: {selected_count}")
|
||||
return True
|
||||
elif selected_count < target_min:
|
||||
self.logger.info("Too few. Decreasing threshold.")
|
||||
threshold -= step_size
|
||||
if threshold < 0: threshold = 0.001
|
||||
else:
|
||||
self.logger.info("Too many. Increasing threshold.")
|
||||
threshold += step_size
|
||||
|
||||
self.logger.warning("Selection failed to converge. Using last result.")
|
||||
return True # 暂时允许继续
|
||||
|
||||
|
||||
class SCFStep(BaseStep):
|
||||
"""
|
||||
对应 02.scf: VASP 计算
|
||||
"""
|
||||
|
||||
def run(self, template_path, potcar_path):
|
||||
self.logger.info(f"=== Running Step: {self.name} (SCF/VASP) ===")
|
||||
|
||||
# 1. 复制 POTCAR
|
||||
self.copy_file(potcar_path, "POTCAR")
|
||||
|
||||
# 2. 复制 INCAR
|
||||
incar_src = os.path.join(template_path, "INCAR")
|
||||
if not self.copy_file(incar_src, "INCAR"):
|
||||
return False # INCAR 必须有
|
||||
|
||||
# 3. 复制 KPOINTS (可选)
|
||||
kpoints_src = os.path.join(template_path, "KPOINTS")
|
||||
if os.path.exists(kpoints_src):
|
||||
self.copy_file(kpoints_src, "KPOINTS")
|
||||
|
||||
# 4. 执行 VASP
|
||||
# 注意:这里通常需要把 selected.xyz 拆分成多个文件夹
|
||||
# 在 Local 简单测试中,我们假设 selected.xyz 已经被拆分成了 POSCAR
|
||||
# 或者我们只跑一个单点能测试。
|
||||
# 既然是框架开发,这里我们调用 machine.yaml 里的 'vasp_cpu'
|
||||
|
||||
success = self.machine.execute("vasp_cpu", self.work_dir)
|
||||
return success
|
||||
|
||||
|
||||
class TrainStep(BaseStep):
|
||||
"""
|
||||
对应 03.train: NEP 训练
|
||||
"""
|
||||
|
||||
def run(self, template_path, new_train_data_path):
|
||||
self.logger.info(f"=== Running Step: {self.name} (Train) ===")
|
||||
|
||||
# 1. 准备 nep.in
|
||||
self.copy_file(os.path.join(template_path, "nep.in"), "nep.in")
|
||||
|
||||
# 2. 准备 train.xyz (这里假设我们把所有数据 cat 到了这里)
|
||||
if new_train_data_path and os.path.exists(new_train_data_path):
|
||||
self.copy_file(new_train_data_path, "train.xyz")
|
||||
else:
|
||||
# 如果没有新数据,只是测试,创建一个空的
|
||||
with open(os.path.join(self.work_dir, "train.xyz"), 'w') as f:
|
||||
f.write("Mock training data")
|
||||
|
||||
# 3. 运行 NEP
|
||||
return self.machine.execute("nep_local", self.work_dir)
|
||||
Reference in New Issue
Block a user