nep框架重构
This commit is contained in:
129
src/machine.py
Normal file
129
src/machine.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# src/machine.py
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
|
||||
class MachineManager:
    """Run named executors from a machine configuration, locally or via Slurm.

    The configuration returned by ``load_yaml`` is read for these keys:

    * ``root_dir``   -- base directory; defaults to the current working directory.
    * ``script_dir`` -- script directory joined onto ``root_dir``; defaults to
      ``config/scripts``.
    * ``executors``  -- mapping of executor name to its settings (``type``,
      ``script``, ``cmd`` and the optional Slurm keys listed in
      ``_SBATCH_OPTIONS``).
    """

    # When True, the Slurm submission script is generated but ``sbatch`` is not
    # invoked, so the class can be exercised on machines without Slurm.
    # Set to False (on the class or on an instance) to really submit jobs.
    TEST_MODE = True

    # (config key, "#SBATCH" option template) pairs, written in this order
    # whenever the key is present in the executor settings.
    _SBATCH_OPTIONS = (
        ('partition', '--partition={}'),
        ('nodes', '--nodes={}'),
        ('ntasks', '--ntasks={}'),
        ('time', '--time={}'),
        ('gpus', '--gres=gpu:{}'),
    )

    def __init__(self, machine_config_path):
        """Load the machine configuration and resolve the script directory.

        :param machine_config_path: path handed to ``src.utils.load_yaml``.
        """
        # Project-local helper, imported here exactly as the original did.
        from src.utils import load_yaml

        # NOTE(review): presumably load_yaml returns a dict; an empty result
        # (e.g. None for an empty file) is treated as an empty config.
        self.config = load_yaml(machine_config_path) or {}
        self.root_dir = self.config.get('root_dir', os.getcwd())
        self.script_dir = os.path.join(
            self.root_dir, self.config.get('script_dir', 'config/scripts')
        )

        logging.info(f"MachineManager initialized. Script dir: {self.script_dir}")

    def execute(self, executor_name, work_dir):
        """Unified execution entry point.

        :param executor_name: key under ``executors`` in the machine config
            (e.g. ``gpumd``, ``vasp_cpu``).
        :param work_dir: working directory for the task; created if missing.
        :return: True on success, False if the executor is undefined, its type
            is unknown, or the run/submission failed.
        """
        # A config without an 'executors' section is reported the same way as
        # an unknown executor instead of raising KeyError.
        executors = self.config.get('executors') or {}
        if executor_name not in executors:
            logging.error(f"Executor '{executor_name}' not defined in machine.yaml")
            return False

        exec_conf = executors[executor_name] or {}
        exec_type = exec_conf.get('type', 'local')

        # Make sure the working directory exists.
        os.makedirs(work_dir, exist_ok=True)

        logging.info(f"--- Task: {executor_name} | Type: {exec_type} ---")
        logging.info(f"Working Dir: {work_dir}")

        if exec_type == 'local':
            return self._run_local(exec_conf, work_dir)
        if exec_type == 'slurm':
            return self._submit_slurm(exec_conf, work_dir, executor_name)

        logging.error(f"Unknown execution type: {exec_type}")
        return False

    def _run_local(self, conf, work_dir):
        """Run an executor directly on this machine, inside ``work_dir``.

        A ``script`` entry (a file under ``self.script_dir``, run with
        ``bash``) takes priority over a ``cmd`` entry (a shell command line).

        :param conf: executor settings.
        :param work_dir: directory the command is run in.
        :return: True if the command exited with status 0, otherwise False.
        """
        script_name = conf.get('script')
        if script_name:
            # Absolute path: the command runs with cwd=work_dir, so a relative
            # script path would otherwise be resolved against the wrong place.
            src_script = os.path.abspath(os.path.join(self.script_dir, script_name))
            if not os.path.exists(src_script):
                logging.error(f"Script not found: {src_script}")
                return False

            # argv list, no shell: paths with spaces/metacharacters stay intact.
            cmd = ["bash", src_script]
            shown = f"bash {src_script}"
            use_shell = False
        elif 'cmd' in conf:
            # A free-form command line from the config is run through the shell.
            cmd = conf['cmd']
            shown = cmd
            use_shell = True
        else:
            logging.error("No 'script' or 'cmd' defined for local executor.")
            return False

        logging.info(f"Executing Local Command: {shown}")
        try:
            subprocess.check_call(cmd, shell=use_shell, cwd=work_dir)
        except subprocess.CalledProcessError as e:
            logging.error(f"Execution failed with error code {e.returncode}")
            return False
        except OSError as e:
            # e.g. work_dir does not exist or the interpreter is not installed.
            logging.error(f"Execution failed: {e}")
            return False

        logging.info("Local execution success.")
        return True

    def _submit_slurm(self, conf, work_dir, job_name):
        """Generate ``submit.sub`` in ``work_dir`` and submit it with ``sbatch``.

        The file consists of ``#SBATCH`` headers built from ``conf`` followed
        by the content of the user script. While ``TEST_MODE`` is True the
        file is only generated and the submission is skipped.

        :param conf: executor settings; ``script`` is required.
        :param work_dir: directory that receives ``submit.sub`` and from which
            the job is submitted.
        :param job_name: value for ``#SBATCH --job-name``.
        :return: True if the script was generated (and, outside test mode,
            submitted), otherwise False.
        """
        script_name = conf.get('script')
        if not script_name:
            # Without this guard os.path.join(..., None) raises TypeError.
            logging.error("No 'script' defined for slurm executor.")
            return False

        src_script = os.path.join(self.script_dir, script_name)
        if not os.path.exists(src_script):
            logging.error(f"Script not found: {src_script}")
            return False

        # 1. Read the user-provided script body.
        with open(src_script, 'r', encoding='utf-8') as f:
            user_script_content = f.read()

        # 2. Build the submission script (.sub).
        sub_name = "submit.sub"
        sub_file = os.path.join(work_dir, sub_name)

        parts = ["#!/bin/bash\n", f"#SBATCH --job-name={job_name}\n"]
        # SBATCH parameters are filled in from the executor settings.
        for key, template in self._SBATCH_OPTIONS:
            if key in conf:
                parts.append(f"#SBATCH {template.format(conf[key])}\n")
        parts.extend([
            "\n",
            "cd $SLURM_SUBMIT_DIR\n",
            "\n",
            "# --- User Script Content ---\n",
            user_script_content,
        ])

        with open(sub_file, 'w', encoding='utf-8') as f:
            f.write("".join(parts))

        logging.info(f"Generated submission script: {sub_file}")

        # 3. Submit the job (skipped in test mode, see TEST_MODE).
        if self.TEST_MODE:
            logging.info("[TEST_MODE] Simulated 'sbatch submit.sub'. Check the .sub file.")
            return True

        try:
            # The bare file name is used because cwd is already work_dir;
            # passing work_dir/submit.sub would resolve twice when work_dir
            # is a relative path.
            res = subprocess.check_output(["sbatch", sub_name], cwd=work_dir)
        except (subprocess.CalledProcessError, OSError) as e:
            logging.error(f"Submission failed: {e}")
            return False

        # Output is usually "Submitted batch job 123456"; take the last token.
        tokens = res.decode().split()
        job_id = tokens[-1] if tokens else "unknown"
        logging.info(f"Job submitted. ID: {job_id}")

        # TODO: wait_for_job(job_id) logic is planned for the next stage.
        return True
|
||||
Reference in New Issue
Block a user