nep框架重构 03.train
This commit is contained in:
46
src/state.py
46
src/state.py
@@ -1,11 +1,15 @@
|
|||||||
# src/state.py
|
# src/state.py
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class StateTracker:
|
class StateTracker:
|
||||||
def __init__(self, workspace_dir):
|
def __init__(self, workspace_dir):
|
||||||
self.state_file = os.path.join(workspace_dir, "workflow_status.json")
|
self.state_file = os.path.join(workspace_dir, "workflow_status.json")
|
||||||
self.completed_tasks = set()
|
self.history = [] # 用于存储有序的记录 [{'task': id, 'time': time}, ...]
|
||||||
|
self.completed_set = set() # 用于快速查找
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
@@ -13,19 +17,45 @@ class StateTracker:
|
|||||||
try:
|
try:
|
||||||
with open(self.state_file, 'r') as f:
|
with open(self.state_file, 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
self.completed_tasks = set(data.get("completed", []))
|
# 兼容旧版本:如果 data["completed"] 是纯字符串列表
|
||||||
except:
|
raw_list = data.get("completed", [])
|
||||||
self.completed_tasks = set()
|
if raw_list and isinstance(raw_list[0], str):
|
||||||
|
# 旧版本转换:给个假时间
|
||||||
|
self.history = [{"task": task, "time": "Unknown"} for task in raw_list]
|
||||||
|
else:
|
||||||
|
# 新版本直接读取
|
||||||
|
self.history = raw_list
|
||||||
|
|
||||||
|
# 重建集合用于快速判断 is_done
|
||||||
|
self.completed_set = {item['task'] for item in self.history}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to load state file: {e}")
|
||||||
|
self.history = []
|
||||||
|
self.completed_set = set()
|
||||||
|
|
||||||
def mark_done(self, task_id):
|
def mark_done(self, task_id):
|
||||||
"""标记任务完成并保存"""
|
"""标记任务完成并保存(带时间戳)"""
|
||||||
self.completed_tasks.add(task_id)
|
if task_id in self.completed_set:
|
||||||
|
return # 避免重复记录
|
||||||
|
|
||||||
|
# 获取当前时间
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
# 添加记录
|
||||||
|
record = {
|
||||||
|
"task": task_id,
|
||||||
|
"time": timestamp
|
||||||
|
}
|
||||||
|
self.history.append(record)
|
||||||
|
self.completed_set.add(task_id)
|
||||||
|
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def is_done(self, task_id):
|
def is_done(self, task_id):
|
||||||
"""检查任务是否已完成"""
|
"""检查任务是否已完成"""
|
||||||
return task_id in self.completed_tasks
|
return task_id in self.completed_set
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
with open(self.state_file, 'w') as f:
|
with open(self.state_file, 'w') as f:
|
||||||
json.dump({"completed": list(self.completed_tasks)}, f, indent=2)
|
# indent=2 让文件更易读
|
||||||
|
json.dump({"completed": self.history}, f, indent=2, ensure_ascii=False)
|
||||||
Reference in New Issue
Block a user