Files
NEP-auto/nep_auto/modules/base_module.py
2025-12-08 17:48:03 +08:00

77 lines
2.6 KiB
Python

import os
import shutil
import logging
from pathlib import Path
from nep_auto.utils.runner import CommandRunner
class BaseModule:
    """Common base class for the per-iteration workflow modules.

    Subclasses must implement ``get_work_dir``, ``run`` and ``check_done``.
    They may also set ``self.template_subdir`` to choose which sub-directory
    of ``template/`` ``copy_template`` reads from (default: ``"common"``).
    """

    def __init__(self, driver, iter_id):
        """
        :param driver: NEPDriver instance holding all configuration; this
            constructor reads ``config_sys``, ``config_param``,
            ``config_machine`` and ``root`` from it.
        :param iter_id: index of the current iteration (int).
        """
        self.driver = driver
        self.config_sys = driver.config_sys
        self.config_param = driver.config_param
        # Machine settings of the currently selected system.
        self.machine_config = driver.config_machine['systems'][driver.config_machine['current_system']]
        self.iter_id = iter_id
        # Zero-padded directory name, e.g. iter_007.
        self.iter_name = f"iter_{iter_id:03d}"
        self.logger = logging.getLogger("NEP_Auto")
        # Runner used to execute commands with the selected machine config.
        self.runner = CommandRunner(self.machine_config)
        # Directory layout: <root>/workspace/iter_XXX/...
        self.root = Path(driver.root) / "workspace"
        self.iter_dir = self.root / self.iter_name
        self.output_dir = self.iter_dir / "05.output"  # shared output area

    def get_work_dir(self):
        """Return this module's working directory (a Path); subclasses must implement."""
        raise NotImplementedError

    def initialize(self):
        """Create the module's working directory and the shared output directory."""
        work_dir = self.get_work_dir()
        if not work_dir.exists():
            work_dir.mkdir(parents=True, exist_ok=True)
            self.logger.debug(f"📁 Created dir: {work_dir}")
        # exist_ok=True already makes this a no-op when the directory exists.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def run(self):
        """Main entry point of the module; subclasses must implement."""
        raise NotImplementedError

    def check_done(self):
        """Report whether the module's task has finished; subclasses must implement."""
        raise NotImplementedError

    # --- Shared helpers ---
    def copy_template(self, template_name, target_name=None):
        """Copy ``template/<template_subdir>/<template_name>`` into the work dir.

        :param template_name: file name inside the template sub-directory.
        :param target_name: destination file name; defaults to ``template_name``.

        A missing template is only logged as a warning; nothing is raised.
        """
        if target_name is None:
            target_name = template_name
        # NOTE(review): "template" is resolved relative to the current working
        # directory, not driver.root — confirm this is intended.
        src = Path("template") / getattr(self, "template_subdir", "common") / template_name
        dst = self.get_work_dir() / target_name
        if src.exists():
            shutil.copy(src, dst)
        else:
            self.logger.warning(f"⚠️ Template not found: {src}")

    def link_file(self, src_path, dst_name):
        """Create (or replace) a symlink ``<work_dir>/<dst_name>`` -> ``src_path``.

        :param src_path: link target; it is resolved to an absolute path first.
        :param dst_name: name of the link inside the work dir.
        """
        src = Path(src_path).resolve()
        dst = self.get_work_dir() / dst_name
        # Path.exists() follows symlinks and is False for a dangling link, so
        # is_symlink() is also checked. Otherwise a stale link would make
        # os.symlink fail with FileExistsError.
        if dst.is_symlink() or dst.exists():
            dst.unlink()
        os.symlink(src, dst)