NEP framework refactor: 03.train

2025-12-10 19:31:05 +08:00
parent 6ab7aecbe4
commit 518264eb60
4 changed files with 288 additions and 40 deletions

@@ -55,4 +55,10 @@ iterations:
executor: "vasp_std"
- name: "03.train"
executor: "nep_local"
executor: "nep_local"
- name: "04.predict"
# Define the list of temperatures and run times
conditions:
- {T: 700, time: "1ns"}
- {T: 800, time: "1ns"} # 支持不同温度不同时长
- {T: 900, time: "1ns"}

@@ -82,4 +82,35 @@ def run_cmd_with_log(cmd, cwd, log_file="exec.log", input_str=None):
return process.returncode == 0
except Exception as e:
f.write(f"\n>>> Exception: {str(e)}\n")
return False
def parse_time_to_steps(time_str, time_step_fs=1.0):
"""
Parse a time string (e.g. '5ns', '10ps', '50000') into a number of MD steps.
"""
import re
# Pure digits: already a step count, return as int
if str(time_str).isdigit():
return int(time_str)
# Regex: numeric value followed by a unit
match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str(time_str))
if not match:
raise ValueError(f"Unknown time format: {time_str}")
value = float(match.group(1))
unit = match.group(2).lower()
# Base unit is fs
if unit == 'fs':
total_fs = value
elif unit == 'ps':
total_fs = value * 1000.0
elif unit == 'ns':
total_fs = value * 1000.0 * 1000.0
else:
raise ValueError(f"Unsupported time unit: {unit}")
return int(total_fs / time_step_fs)
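A quick sanity check of the conversion (illustrative, not part of the commit):

assert parse_time_to_steps("1ns") == 1_000_000    # 1 ns = 1e6 fs at 1 fs/step
assert parse_time_to_steps("10ps") == 10_000
assert parse_time_to_steps("50000") == 50000      # bare numbers pass through as step counts
assert parse_time_to_steps("5ns", time_step_fs=2.0) == 2_500_000  # larger timestep, fewer steps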

@@ -3,8 +3,8 @@ import os
import shutil
import logging
import subprocess
import re
from src.utils import load_yaml, run_cmd_with_log, parse_time_to_steps
from src.machine import MachineManager
from src.state import StateTracker # new
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
@@ -206,20 +206,22 @@ class Workflow:
task_id_select = f"{iter_name}.01.select"
if not self.tracker.is_done(task_id_select):
self.logger.info(f"=== Step: 01.select ({step_conf.get('method', 'distance')}) ===")
method = step_conf.get('method', 'distance')
self.logger.info(f"=== Step: 01.select ({method}) ===")
os.makedirs(step_dir, exist_ok=True)
# [New] Create the output directory up front for select.csv and select.png
output_dir = os.path.join(iter_path, "05.output")
os.makedirs(output_dir, exist_ok=True)
# 1. Prepare dump.xyz (symlink)
# Prefer the path recorded by the previous step; if absent, try to infer it
dump_src = getattr(self, 'last_dump_path', None)
if not dump_src:
# Infer the default location: iter_XX/00.md/production/dump.xyz
dump_src = os.path.join(iter_path, "00.md", "production", "dump.xyz")
if os.path.exists(dump_src):
dst_dump = os.path.join(step_dir, "dump.xyz")
if os.path.lexists(dst_dump): os.remove(dst_dump)  # lexists also catches broken symlinks
# Create the symlink with an absolute path
os.symlink(os.path.abspath(dump_src), dst_dump)
else:
self.logger.error(f"Source dump.xyz not found: {dump_src}")
@@ -229,9 +231,7 @@ class Workflow:
shutil.copy(self.current_nep_pot, os.path.join(step_dir, "nep.txt"))
# 3. Prepare train.xyz
# Logic: iteration 0 uses model.xyz; later iterations use self.current_train_set
if iter_id == 0:
# First iteration: use model.xyz from this iteration's 00.md
model_xyz_src = os.path.join(iter_path, "00.md", "model.xyz")
if os.path.exists(model_xyz_src):
shutil.copy(model_xyz_src, os.path.join(step_dir, "train.xyz"))
@@ -239,7 +239,6 @@ class Workflow:
self.logger.error("model.xyz not found for initial train.xyz")
return
else:
# Later iterations
if os.path.exists(self.current_train_set):
shutil.copy(self.current_train_set, os.path.join(step_dir, "train.xyz"))
else:
@@ -247,8 +246,7 @@ class Workflow:
return
# 4. Run the selection logic
params = step_conf.get('params', [0.01, 60, 120])  # [threshold, min, max]
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
# === Branch A: distance-based selection (iterative loop) ===
@@ -259,68 +257,95 @@ class Workflow:
max_attempts = 15
success = False
# [New] Record the selection history
selection_log = [] # List of dicts
self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}")
for attempt in range(max_attempts):
self.logger.info(f"--- Attempt {attempt + 1}: Threshold {threshold:.4f} ---")
current_attempt = attempt + 1
self.logger.info(f"--- Attempt {current_attempt}: Threshold {threshold:.4f} ---")
# Build the stdin input: 203 -> filenames -> 1 (distance) -> threshold
# Format: "203\ndump.xyz train.xyz nep.txt\n1\n0.01"
input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"
# Run gpumdkit (output appended to the log file)
if not run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
self.logger.error("gpumdkit execution failed.")
break  # fatal error, stop the selection loop
# Count structures in selected.xyz (each extxyz frame header contains "Lat")
count = 0
if os.path.exists(os.path.join(step_dir, "selected.xyz")):
try:
count_out = subprocess.check_output('grep -c "Lat" selected.xyz', shell=True, cwd=step_dir)
count = int(count_out.decode().strip())
except Exception:
count = 0
self.logger.info(f"Selected count: {count}")
# [New] Record this attempt's result
selection_log.append({
"attempt": current_attempt,
"threshold": threshold,
"count": count,
"result": "Success" if target_min <= count <= target_max else "Fail"
})
# Decision logic
if target_min <= count <= target_max:
self.logger.info(f"Success! {count} structures selected.")
success = True
break
elif count < target_min:
# Too few selected -> threshold too strict -> lower it.
# Assumption: the threshold is the minimum allowed distance; a structure is
# selected only if its distance > thr, so a smaller thr is looser and selects more.
self.logger.info("Too few. Decreasing threshold (Loosening).")
threshold -= step_size
if threshold < 0: threshold = 0.0001
else:  # count > target_max
# Too many selected -> raise the threshold
self.logger.info("Too many. Increasing threshold (Tightening).")
threshold += step_size
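The adjustment rule in isolation, as a minimal runnable sketch; the per-attempt counts and the step_size value below are hypothetical stand-ins for real gpumdkit results:

def adjust(threshold, count, target_min=60, target_max=120, step_size=0.002):
    # Mirrors the branch above: too few -> loosen (lower thr), too many -> tighten (raise thr)
    if count < target_min:
        return max(threshold - step_size, 0.0001)
    if count > target_max:
        return threshold + step_size
    return threshold

thr = 0.01
for count in (450, 180, 95):      # pretend counts returned per attempt
    if 60 <= count <= 120:
        break                     # third attempt lands in the target window
    thr = adjust(thr, count)      # walks 0.010 -> 0.012 -> 0.014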
# [New] After the loop, dump the selection history to CSV
csv_path = os.path.join(output_dir, "select_log.csv")
try:
with open(csv_path, 'w') as f:
f.write("Attempt,Threshold,Count,Result\n")
for entry in selection_log:
f.write(
f"{entry['attempt']},{entry['threshold']:.4f},{entry['count']},{entry['result']}\n")
self.logger.info(f"Selection log saved to {csv_path}")
except Exception as e:
self.logger.error(f"Failed to write select_log.csv: {e}")
# [New] Copy select.png into output (if it was generated)
src_png = os.path.join(step_dir, "select.png")
if os.path.exists(src_png):
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
self.logger.info("Archived select.png to output.")
if success:
self.tracker.mark_done(task_id_select)
else:
# Could also raise or pause here; for now mark the task done so the workflow continues.
self.logger.warning("Failed to reach target count. Using last result.")
self.tracker.mark_done(task_id_select)
# === Branch B: selection by count / random (one-shot) ===
elif method == "random":
min_n, max_n = params[1], params[2]
# Build the stdin input: 203 -> filenames -> 2 (count) -> min max
input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"
self.logger.info(f"Random selection: {min_n}-{max_n}")
if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
# Random selection has no iterative loop, but record it anyway
with open(os.path.join(output_dir, "select_log.csv"), 'w') as f:
f.write("Method,Min,Max,Result\n")
f.write(f"Random,{min_n},{max_n},Executed\n")
# Also try to archive the plot
src_png = os.path.join(step_dir, "select.png")
if os.path.exists(src_png):
shutil.copy(src_png, os.path.join(output_dir, "select.png"))
self.tracker.mark_done(task_id_select)
else:
self.logger.error("Random selection failed.")
@@ -561,3 +586,170 @@ class Workflow:
# Restore variable state
self.current_nep_pot = os.path.join(step_dir, "nep.txt")
self.current_train_set = os.path.join(step_dir, "train.xyz")
# ==========================
# Step: 04.predict (Conductivity)
# ==========================
elif step_name == "04.predict":
step_dir = os.path.join(iter_path, "04.predict")
task_id_predict = f"{iter_name}.04.predict"
if not self.tracker.is_done(task_id_predict):
self.logger.info("=== Step: 04.predict (Arrhenius) ===")
os.makedirs(step_dir, exist_ok=True)
# Prepare the output directory
output_dir = os.path.join(iter_path, "05.output")
os.makedirs(output_dir, exist_ok=True)
# 1. Prepare the base files (model.xyz, nep.txt)
# Note: prediction normally uses the latest nep.txt and the original/latest model.xyz;
# here we use current_nep_pot plus the initial model.xyz for consistency.
model_src = os.path.join(iter_path, "00.md", "model.xyz") # assumed to exist under 00.md
if not os.path.exists(model_src):
# If missing, fall back to the data directory
model_src = os.path.join(self.data_dir, self.param['files']['poscar'])
# Simplified handling; ideally make sure 00.md produced model.xyz
# 2. Loop over the temperature points and run the simulations
conditions = step_conf.get('conditions', [])
if not conditions:
self.logger.error("No conditions defined for 04.predict")
continue
# Send a start notification
self.notifier.send("Predict Start", f"Starting {len(conditions)} tasks in {iter_name}", 5)
for cond in conditions:
temp = cond['T']
time_str = cond['time']
steps = parse_time_to_steps(time_str)
# [New] Compute the MSD window
# Formula: 10 * window * 20 = steps => window = steps / 200
msd_window = int(steps / 200)
sub_dir_name = f"{temp}K"
sub_work_dir = os.path.join(step_dir, sub_dir_name)
os.makedirs(sub_work_dir, exist_ok=True)
self.logger.info(
f"-> Running Prediction: {temp}K, {time_str} ({steps} steps, MSD window={msd_window})")
# Distribute the input files
if os.path.exists(model_src):
shutil.copy(model_src, os.path.join(sub_work_dir, "model.xyz"))
shutil.copy(self.current_nep_pot, os.path.join(sub_work_dir, "nep.txt"))
# Generate run.in
template_path = os.path.join(self.template_dir, "04.predict", "run.in")
if os.path.exists(template_path):
with open(template_path, 'r') as f:
content = f.read()
# [Changed] Substitute parameters; MSD_WINDOW is newly added
content = content.replace("{T}", str(temp))
content = content.replace("{STEPS}", str(steps))
content = content.replace("{MSD_WINDOW}", str(msd_window))
with open(os.path.join(sub_work_dir, "run.in"), 'w') as f:
f.write(content)
else:
self.logger.error(f"Template not found: {template_path}")
continue
# Run GPUMD
# No nohup mode here: the runs are serial, so logging directly is enough
if not run_cmd_with_log("gpumd", sub_work_dir, "predict.log"):
self.logger.error(f"Prediction failed at {temp}K")
# Could also continue or return here
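Worked example of the window arithmetic: a "1ns" condition yields steps = 1,000,000, so msd_window = 1,000,000 / 200 = 5,000. With compute_msd's sampling interval of 10 (see the run.in template below), the correlation window covers 10 × 5,000 = 50,000 steps = 50 ps, i.e. one twentieth of the production run, which is exactly the 10 * window * 20 = steps rule encoded above.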
# 3. Post-processing analysis (gpumdkit.sh -plt sigma)
self.logger.info("Running Analysis (gpumdkit -plt sigma)...")
kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')
# Capture the output for parsing
analysis_log = os.path.join(step_dir, "sigma_analysis.log")
cmd_analyze = f"{kit_path} -plt sigma"
# We need to capture the stdout content
process = subprocess.Popen(
cmd_analyze, shell=True, cwd=step_dir,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
stdout, _ = process.communicate()
# Write the output to a log file as a backup
with open(analysis_log, 'w') as f:
f.write(stdout)
# 4. Parse the output and generate the report
# Target format: CSV
csv_data = []
ea_val = "N/A"
sigma_300k = "N/A"
lines = stdout.split('\n')
in_table = False
for line in lines:
line = line.strip()
# Parse Ea
if "Ea:" in line:
# Format: 04.predict, Ea: 0.305 eV
match = re.search(r"Ea:\s+([\d\.]+)\s+eV", line)
if match: ea_val = match.group(1)
# Parse the extrapolated 300 K sigma
if "at 300K" in line:
# Format: at 300K, 04.predict: Sigma = 1.816e-03 S/cm
match = re.search(r"Sigma\s*=\s*([\d\.eE\+\-]+)", line)
if match: sigma_300k = match.group(1)
# Parse the table
if "----------------" in line:
in_table = not in_table # toggle table state
continue
if in_table:
# T (K) Sigma (S/cm) Sigma·T (K·S/cm)
# 425 3.656e-02 1.554e+01
parts = line.split()
if len(parts) >= 3 and parts[0].isdigit():
csv_data.append({
"T(K)": parts[0],
"Sigma(S/cm)": parts[1],
"Sigma*T": parts[2]
})
# Write the CSV report into output
report_path = os.path.join(output_dir, "conductivity_report.csv")
with open(report_path, 'w') as f:
# Header: summary info
f.write(f"# Extracted from gpumdkit output\n")
f.write(f"# Ea (eV):,{ea_val}\n")
f.write(f"# Extrapolated Sigma@300K (S/cm):,{sigma_300k}\n")
f.write("\n")
# Table rows
f.write("Temperature(K),Sigma(S/cm),Sigma*T(K*S/cm)\n")
for row in csv_data:
f.write(f"{row['T(K)']},{row['Sigma(S/cm)']},{row['Sigma*T']}\n")
self.logger.info(f"Report saved to {report_path}")
self.notifier.send("Predict Done", f"Ea: {ea_val} eV, Sigma@300K: {sigma_300k}", 5)
# 5. Archive Arrhenius.png
src_png = os.path.join(step_dir, "Arrhenius.png")
if os.path.exists(src_png):
shutil.copy(src_png, os.path.join(output_dir, "Arrhenius.png"))
self.logger.info("Archived Arrhenius.png")
else:
self.logger.warning("Arrhenius.png not found.")
self.tracker.mark_done(task_id_predict)
else:
self.logger.info("Skipping Predict (Already Done).")

@@ -0,0 +1,19 @@
potential ./nep.txt
time_step 1
# Stage 1: Heating (NPT) - 30 ps
velocity 300
ensemble npt_scr 300 {T} 100 0 0 0 0 0 0 50 50 50 5 5 5 1000
run 30000
# Stage 2: Equilibration (NPT) - 60 ps
ensemble npt_scr {T} {T} 100 0 0 0 0 0 0 50 50 50 5 5 5 1000
run 60000
# Stage 3: Production (NVT)
ensemble nvt_nhc {T} {T} 100
# MSD settings: 10 * window * 20 = steps
compute_msd 10 {MSD_WINDOW} group 0 0
dump_thermo 1000
dump_exyz 5000
run {STEPS}
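Rendered for the first YAML condition (T = 700, "1ns" at a 1 fs timestep, so STEPS = 1000000 and MSD_WINDOW = 5000), the templated lines become:

ensemble npt_scr 300 700 100 0 0 0 0 0 0 50 50 50 5 5 5 1000
...
ensemble nvt_nhc 700 700 100
compute_msd 10 5000 group 0 0
run 1000000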