NEP framework refactor: 03.train
src/utils.py (33 changed lines)
@@ -82,4 +82,35 @@ def run_cmd_with_log(cmd, cwd, log_file="exec.log", input_str=None):
        return process.returncode == 0
    except Exception as e:
        f.write(f"\n>>> Exception: {str(e)}\n")
        return False
    return False


def parse_time_to_steps(time_str, time_step_fs=1.0):
    """
    Parse a time string (e.g., '5ns', '10ps', '50000') into a number of simulation steps.
    """
    import re

    # If the input is a plain number, return it directly as an integer
    if str(time_str).isdigit():
        return int(time_str)

    # Match the numeric value and the unit with a regex
    match = re.match(r"([\d\.]+)\s*([a-zA-Z]+)", str(time_str))
    if not match:
        raise ValueError(f"Unknown time format: {time_str}")

    value = float(match.group(1))
    unit = match.group(2).lower()

    # The base unit is fs
    if unit == 'fs':
        total_fs = value
    elif unit == 'ps':
        total_fs = value * 1000.0
    elif unit == 'ns':
        total_fs = value * 1000.0 * 1000.0
    else:
        raise ValueError(f"Unsupported time unit: {unit}")

    return int(total_fs / time_step_fs)
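
A quick usage sketch of the new helper; the values follow directly from the conversion rules above (1 fs is the function's own default timestep):

from src.utils import parse_time_to_steps

# '5ns' -> 5,000,000 fs -> 5,000,000 steps at the default 1 fs timestep
assert parse_time_to_steps("5ns") == 5_000_000
# '10ps' with a 2 fs timestep -> 10,000 fs / 2 fs = 5,000 steps
assert parse_time_to_steps("10ps", time_step_fs=2.0) == 5_000
# plain numbers are passed through unchanged
assert parse_time_to_steps("50000") == 50_000
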
src/workflow.py (268 changed lines)
@@ -3,8 +3,8 @@ import os
import shutil
import logging
import subprocess

from src.utils import load_yaml, run_cmd_with_log
import re
from src.utils import load_yaml, run_cmd_with_log, parse_time_to_steps
from src.machine import MachineManager
from src.state import StateTracker  # newly added
from src.steps import MDStep, SelectStep, SCFStep, TrainStep
@@ -206,20 +206,22 @@ class Workflow:
            task_id_select = f"{iter_name}.01.select"

            if not self.tracker.is_done(task_id_select):
                self.logger.info(f"=== Step: 01.select ({step_conf.get('method', 'distance')}) ===")
                method = step_conf.get('method', 'distance')
                self.logger.info(f"=== Step: 01.select ({method}) ===")
                os.makedirs(step_dir, exist_ok=True)

                # [new] Create the output directory up front to hold select.csv and select.png
                output_dir = os.path.join(iter_path, "05.output")
                os.makedirs(output_dir, exist_ok=True)

                # 1. Prepare dump.xyz (symlink)
                # Prefer the path recorded by the previous step; otherwise try to infer it
                dump_src = getattr(self, 'last_dump_path', None)
                if not dump_src:
                    # Try to infer: iter_XX/00.md/production/dump.xyz
                    dump_src = os.path.join(iter_path, "00.md", "production", "dump.xyz")

                if os.path.exists(dump_src):
                    dst_dump = os.path.join(step_dir, "dump.xyz")
                    if os.path.exists(dst_dump): os.remove(dst_dump)
                    # Create the symlink with an absolute path
                    os.symlink(os.path.abspath(dump_src), dst_dump)
                else:
                    self.logger.error(f"Source dump.xyz not found: {dump_src}")
@@ -229,9 +231,7 @@ class Workflow:
                shutil.copy(self.current_nep_pot, os.path.join(step_dir, "nep.txt"))

                # 3. Prepare train.xyz
                # Logic: the first iteration uses model.xyz, later iterations use self.current_train_set
                if iter_id == 0:
                    # First iteration: use model.xyz from this iteration's 00.md
                    model_xyz_src = os.path.join(iter_path, "00.md", "model.xyz")
                    if os.path.exists(model_xyz_src):
                        shutil.copy(model_xyz_src, os.path.join(step_dir, "train.xyz"))
@@ -239,7 +239,6 @@ class Workflow:
                        self.logger.error("model.xyz not found for initial train.xyz")
                        return
                else:
                    # Later iterations
                    if os.path.exists(self.current_train_set):
                        shutil.copy(self.current_train_set, os.path.join(step_dir, "train.xyz"))
                    else:
@@ -247,8 +246,7 @@ class Workflow:
                        return

                # 4. Run the selection logic
                method = step_conf.get('method', 'distance')
                params = step_conf.get('params', [0.01, 60, 120])  # [threshold, min, max]
                params = step_conf.get('params', [0.01, 60, 120])
                kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')

                # === Branch A: distance-based selection (loop) ===
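
For orientation, a hypothetical step_conf entry consistent with the defaults read above; the key names come from the step_conf.get calls, while the exact configuration layout is an assumption:

# Hypothetical 01.select configuration (illustrative only; mirrors the defaults above)
step_conf = {
    "method": "distance",        # or "random" for Branch B below
    "params": [0.01, 60, 120],   # [initial threshold, target_min, target_max]
}
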
@@ -259,68 +257,95 @@ class Workflow:
                    max_attempts = 15
                    success = False

                    # [new] Keep a record of the selection history
                    selection_log = []  # List of dicts

                    self.logger.info(f"Targeting {target_min}-{target_max} structures. Initial thr={threshold}")

                    for attempt in range(max_attempts):
                        self.logger.info(f"--- Attempt {attempt + 1}: Threshold {threshold:.4f} ---")
                        current_attempt = attempt + 1
                        self.logger.info(f"--- Attempt {current_attempt}: Threshold {threshold:.4f} ---")

                        # Build the input: 203 -> filenames -> 1 (distance) -> threshold
                        # Format: "203\ndump.xyz train.xyz nep.txt\n1\n0.01"
                        input_str = f"203\ndump.xyz train.xyz nep.txt\n1\n{threshold}"

                        # Run gpumdkit (output goes to a log opened in append mode)
                        if not run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
                            self.logger.error("gpumdkit execution failed.")
                            break  # fatal error, abort immediately
                            break

                        # Check how many structures are in selected.xyz
                        if not os.path.exists(os.path.join(step_dir, "selected.xyz")):
                            self.logger.warning("selected.xyz not found. Retrying...")
                            continue

                        try:
                            # grep -c "Lat" selected.xyz
                            count_out = subprocess.check_output('grep -c "Lat" selected.xyz', shell=True,
                                                                cwd=step_dir)
                            count = int(count_out.decode().strip())
                        except:
                            count = 0
                        # Check how many structures are in selected.xyz
                        count = 0
                        if os.path.exists(os.path.join(step_dir, "selected.xyz")):
                            try:
                                count_out = subprocess.check_output('grep -c "Lat" selected.xyz', shell=True,
                                                                    cwd=step_dir)
                                count = int(count_out.decode().strip())
                            except:
                                count = 0

                        self.logger.info(f"Selected count: {count}")

                        # [new] Record this attempt's result
                        selection_log.append({
                            "attempt": current_attempt,
                            "threshold": threshold,
                            "count": count,
                            "result": "Success" if target_min <= count <= target_max else "Fail"
                        })

                        # Decision logic
                        if target_min <= count <= target_max:
                            self.logger.info(f"Success! {count} structures selected.")
                            success = True
                            break
                        elif count < target_min:
                            # Too few selected -> threshold too strict -> lower it? (usually the lower the distance threshold, the more structures are selected?)
                            # Assumption: the threshold is the "minimum allowed distance"; a structure is selected only if its distance > thr.
                            # So the smaller the threshold, the looser the criterion and the more structures get selected.
                            self.logger.info("Too few. Decreasing threshold (Loosening).")
                            threshold -= step_size
                            if threshold < 0: threshold = 0.0001
                        else:  # count > target_max
                            # Too many selected -> raise the threshold
                        else:
                            self.logger.info("Too many. Increasing threshold (Tightening).")
                            threshold += step_size

                    # [new] After the loop, write the CSV log
                    csv_path = os.path.join(output_dir, "select_log.csv")
                    try:
                        with open(csv_path, 'w') as f:
                            f.write("Attempt,Threshold,Count,Result\n")
                            for entry in selection_log:
                                f.write(
                                    f"{entry['attempt']},{entry['threshold']:.4f},{entry['count']},{entry['result']}\n")
                        self.logger.info(f"Selection log saved to {csv_path}")
                    except Exception as e:
                        self.logger.error(f"Failed to write select_log.csv: {e}")

                    # [new] Copy select.png (if it was generated)
                    src_png = os.path.join(step_dir, "select.png")
                    if os.path.exists(src_png):
                        shutil.copy(src_png, os.path.join(output_dir, "select.png"))
                        self.logger.info("Archived select.png to output.")

                    if success:
                        self.tracker.mark_done(task_id_select)
                    else:
                        self.logger.warning(
                            "Failed to reach target count within max attempts. Using last result.")
                        # We could mark this as done or raise and pause; mark it done for now so the workflow can continue.
                        self.logger.warning("Failed to reach target count. Using last result.")
                        self.tracker.mark_done(task_id_select)

                # === Branch B: selection by count / random (one-shot) ===
                # === Branch B: selection by count / random ===
                elif method == "random":
                    min_n, max_n = params[1], params[2]
                    # Build the input: 203 -> filenames -> 2 (number) -> min max
                    input_str = f"203\ndump.xyz train.xyz nep.txt\n2\n{min_n} {max_n}"

                    self.logger.info(f"Random selection: {min_n}-{max_n}")
                    if run_cmd_with_log(kit_path, step_dir, "select_exec.log", input_str=input_str):
                        # Random selection has no iterative loop, but record it anyway
                        with open(os.path.join(output_dir, "select_log.csv"), 'w') as f:
                            f.write("Method,Min,Max,Result\n")
                            f.write(f"Random,{min_n},{max_n},Executed\n")

                        # Likewise try to copy the plot
                        src_png = os.path.join(step_dir, "select.png")
                        if os.path.exists(src_png):
                            shutil.copy(src_png, os.path.join(output_dir, "select.png"))

                        self.tracker.mark_done(task_id_select)
                    else:
                        self.logger.error("Random selection failed.")
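
The distance branch above is essentially a bounded search over the gpumdkit distance threshold. A minimal standalone sketch of that policy, with a hypothetical count_at(threshold) standing in for one gpumdkit selection run (the step size, the 0.0001 floor, and the 15-attempt cap follow the code above):

def tune_threshold(count_at, threshold, step_size, target_min, target_max, max_attempts=15):
    """Adjust threshold until count_at(threshold) falls inside [target_min, target_max]."""
    for _ in range(max_attempts):
        count = count_at(threshold)          # hypothetical stand-in for one gpumdkit run
        if target_min <= count <= target_max:
            return threshold, count, True    # success
        if count < target_min:
            threshold -= step_size           # too few -> loosen (smaller minimum distance)
            if threshold < 0:
                threshold = 0.0001
        else:
            threshold += step_size           # too many -> tighten
    return threshold, count, False           # give up, keep the last result

# Toy usage: 0.02 selects too few with this stand-in, so the loop loosens the threshold and succeeds
thr, n, ok = tune_threshold(lambda t: int(1.0 / t), threshold=0.02, step_size=0.01,
                            target_min=60, target_max=120)
print(thr, n, ok)
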
@@ -561,3 +586,170 @@ class Workflow:
                # Restore variable state
                self.current_nep_pot = os.path.join(step_dir, "nep.txt")
                self.current_train_set = os.path.join(step_dir, "train.xyz")




        # ==========================
        # Step: 04.predict (Conductivity)
        # ==========================
        elif step_name == "04.predict":
            step_dir = os.path.join(iter_path, "04.predict")
            task_id_predict = f"{iter_name}.04.predict"

            if not self.tracker.is_done(task_id_predict):
                self.logger.info("=== Step: 04.predict (Arrhenius) ===")
                os.makedirs(step_dir, exist_ok=True)

                # Prepare the output directory
                output_dir = os.path.join(iter_path, "05.output")
                os.makedirs(output_dir, exist_ok=True)

                # 1. Prepare the base files (model.xyz, nep.txt)
                # Note: prediction normally uses the latest nep.txt together with the original/latest model.xyz
                # Here we use current_nep_pot and the initial model.xyz (for consistency)
                model_src = os.path.join(iter_path, "00.md", "model.xyz")  # assume 00.md provides it
                if not os.path.exists(model_src):
                    # If not, try to take it from the data directory
                    model_src = os.path.join(self.data_dir, self.param['files']['poscar'])
                    # Simplified handling; ideally make sure 00.md actually produced model.xyz

                # 2. Loop over the temperature points and run the simulations
                conditions = step_conf.get('conditions', [])
                if not conditions:
                    self.logger.error("No conditions defined for 04.predict")
                    continue

                # Send a start notification
                self.notifier.send("Predict Start", f"Starting {len(conditions)} tasks in {iter_name}", 5)

                for cond in conditions:
                    temp = cond['T']
                    time_str = cond['time']
                    steps = parse_time_to_steps(time_str)

                    # [new] Compute the MSD window
                    # Formula: 10 * window * 20 = steps  =>  window = steps / 200
                    msd_window = int(steps / 200)
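                    # e.g. an illustrative 1 ns run at the default 1 fs timestep:
                    # steps = 1,000,000, so msd_window = 1,000,000 / 200 = 5,000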

                    sub_dir_name = f"{temp}K"
                    sub_work_dir = os.path.join(step_dir, sub_dir_name)
                    os.makedirs(sub_work_dir, exist_ok=True)

                    self.logger.info(
                        f"-> Running Prediction: {temp}K, {time_str} ({steps} steps, MSD window={msd_window})")

                    # Distribute the input files
                    if os.path.exists(model_src):
                        shutil.copy(model_src, os.path.join(sub_work_dir, "model.xyz"))
                    shutil.copy(self.current_nep_pot, os.path.join(sub_work_dir, "nep.txt"))

                    # Generate run.in
                    template_path = os.path.join(self.template_dir, "04.predict", "run.in")
                    if os.path.exists(template_path):
                        with open(template_path, 'r') as f:
                            content = f.read()

                        # [changed] Substitute the parameters, now including MSD_WINDOW
                        content = content.replace("{T}", str(temp))
                        content = content.replace("{STEPS}", str(steps))
                        content = content.replace("{MSD_WINDOW}", str(msd_window))

                        with open(os.path.join(sub_work_dir, "run.in"), 'w') as f:
                            f.write(content)
                    else:
                        self.logger.error(f"Template not found: {template_path}")
                        continue


                    # Run GPUMD
                    # No nohup mode here: the runs are serial, so just log directly
                    if not run_cmd_with_log("gpumd", sub_work_dir, "predict.log"):
                        self.logger.error(f"Prediction failed at {temp}K")
                        # Could also continue or return here

                # 3. Post-processing analysis (gpumdkit.sh -plt sigma)
                self.logger.info("Running Analysis (gpumdkit -plt sigma)...")
                kit_path = self.machine.config['paths'].get('gpumdkit', 'gpumdkit.sh')

                # Capture the output for parsing
                analysis_log = os.path.join(step_dir, "sigma_analysis.log")
                cmd_analyze = f"{kit_path} -plt sigma"

                # We need to capture the stdout content
                process = subprocess.Popen(
                    cmd_analyze, shell=True, cwd=step_dir,
                    stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
                )
                stdout, _ = process.communicate()

                # Write the output to a log file as a backup
                with open(analysis_log, 'w') as f:
                    f.write(stdout)

                # 4. Parse the output and generate the report
                # Target format: CSV
                csv_data = []
                ea_val = "N/A"
                sigma_300k = "N/A"

                lines = stdout.split('\n')
                in_table = False

                for line in lines:
                    line = line.strip()
                    # Parse Ea
                    if "Ea:" in line:
                        # Format: 04.predict, Ea: 0.305 eV
                        match = re.search(r"Ea:\s+([\d\.]+)\s+eV", line)
                        if match: ea_val = match.group(1)

                    # Parse the 300 K sigma
                    if "at 300K" in line:
                        # Format: at 300K, 04.predict: Sigma = 1.816e-03 S/cm
                        match = re.search(r"Sigma\s*=\s*([\d\.eE\+\-]+)", line)
                        if match: sigma_300k = match.group(1)

                    # Parse the table
                    if "----------------" in line:
                        in_table = not in_table  # toggle state
                        continue

                    if in_table:
                        # T (K)    Sigma (S/cm)    Sigma·T (K·S/cm)
                        # 425      3.656e-02       1.554e+01
                        parts = line.split()
                        if len(parts) >= 3 and parts[0].isdigit():
                            csv_data.append({
                                "T(K)": parts[0],
                                "Sigma(S/cm)": parts[1],
                                "Sigma*T": parts[2]
                            })

                # Write the CSV to the output directory
                report_path = os.path.join(output_dir, "conductivity_report.csv")
                with open(report_path, 'w') as f:
                    # Write the header summary
                    f.write(f"# Extracted from gpumdkit output\n")
                    f.write(f"# Ea (eV):,{ea_val}\n")
                    f.write(f"# Extrapolated Sigma@300K (S/cm):,{sigma_300k}\n")
                    f.write("\n")
                    # Write the table rows
                    f.write("Temperature(K),Sigma(S/cm),Sigma*T(K*S/cm)\n")
                    for row in csv_data:
                        f.write(f"{row['T(K)']},{row['Sigma(S/cm)']},{row['Sigma*T']}\n")

                self.logger.info(f"Report saved to {report_path}")
                self.notifier.send("Predict Done", f"Ea: {ea_val} eV, Sigma@300K: {sigma_300k}", 5)

                # 5. Archive Arrhenius.png
                src_png = os.path.join(step_dir, "Arrhenius.png")
                if os.path.exists(src_png):
                    shutil.copy(src_png, os.path.join(output_dir, "Arrhenius.png"))
                    self.logger.info("Archived Arrhenius.png")
                else:
                    self.logger.warning("Arrhenius.png not found.")

                self.tracker.mark_done(task_id_predict)
            else:
                self.logger.info("Skipping Predict (Already Done).")
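
To make the parsing step above concrete, a small self-contained check of the two regexes against the sample lines quoted in the comments (the sample strings are taken from those comments; the actual gpumdkit output format is assumed to match them):

import re

# Sample lines as quoted in the comments above
line_ea = "04.predict, Ea: 0.305 eV"
line_sigma = "at 300K, 04.predict: Sigma = 1.816e-03 S/cm"

ea_match = re.search(r"Ea:\s+([\d\.]+)\s+eV", line_ea)
sigma_match = re.search(r"Sigma\s*=\s*([\d\.eE\+\-]+)", line_sigma)

print(ea_match.group(1))     # -> 0.305
print(sigma_match.group(1))  # -> 1.816e-03
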