117 lines
4.0 KiB
Python
117 lines
4.0 KiB
Python
import torch
|
||
import cv2
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
|
||
from linknet.main import LinkNet
|
||
|
||
|
||
# --- 首先,确保模型定义代码在这里 ---
|
||
# (你需要从之前的 main.py 文件中复制 LinkNet 和 DecoderBlock 类的定义)
|
||
# class DecoderBlock(nn.Module): ...
|
||
# class LinkNet(nn.Module): ...
|
||
# --- 假设模型定义代码已经复制过来了 ---
|
||
|
||
|
||
def predict_single_image(model_path, image_path, image_size=256):
|
||
"""
|
||
加载训练好的LinkNet模型,对单张新图片进行预测。
|
||
|
||
参数:
|
||
- model_path (str): 已保存的模型权重文件路径 (.pth)。
|
||
- image_path (str): 待预测的图片文件路径。
|
||
- image_size (int): 模型训练时使用的图片尺寸。
|
||
|
||
返回:
|
||
- predicted_mask (numpy.ndarray): 预测出的二值化掩码图,与原图尺寸相同,
|
||
像素值为0或255。
|
||
- overlay_image (numpy.ndarray): 将预测掩码(红色)叠加在原图上的结果图。
|
||
"""
|
||
# 1. 设备选择
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
print(f"Using device: {device}")
|
||
|
||
# 2. 加载模型结构并载入权重
|
||
# - 实例化你的模型结构
|
||
# - 使用 load_state_dict 载入已保存的权重
|
||
# - 将模型切换到评估模式 .eval()
|
||
model = LinkNet(num_classes=1)
|
||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||
model.to(device)
|
||
model.eval()
|
||
|
||
# 3. 加载并预处理图片
|
||
# - 读取图片(灰度模式)
|
||
img_original = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
||
if img_original is None:
|
||
raise FileNotFoundError(f"Image not found at {image_path}")
|
||
|
||
original_h, original_w = img_original.shape
|
||
|
||
# - 缩放到模型需要的尺寸 (和训练时一样)
|
||
img_resized = cv2.resize(img_original, (image_size, image_size))
|
||
|
||
# - 归一化 (和训练时一样)
|
||
img_normalized = img_resized / 255.0
|
||
|
||
# - 增加批次和通道维度 (H, W) -> (1, 1, H, W)
|
||
img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float()
|
||
|
||
# - 将Tensor发送到设备
|
||
img_tensor = img_tensor.to(device)
|
||
|
||
# 4. 模型推理
|
||
# 使用 torch.no_grad() 来关闭梯度计算,节省显存并加速
|
||
with torch.no_grad():
|
||
output = model(img_tensor)
|
||
|
||
# 5. 后处理
|
||
# - 将输出Tensor转回Numpy数组
|
||
# - 去掉批次和通道维度 (1, 1, H, W) -> (H, W)
|
||
pred_mask_resized = output.cpu().numpy()[0, 0]
|
||
|
||
# - 二值化:将概率图 (0-1) 转换为掩码图 (0或1)
|
||
# 这里的 0.5 是阈值,可以根据实际情况调整
|
||
pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
|
||
|
||
# - 将掩码图恢复到原始图片的尺寸
|
||
predicted_mask = cv2.resize(pred_mask_binary, (original_w, original_h),
|
||
interpolation=cv2.INTER_NEAREST) * 255
|
||
|
||
# 6. (可选) 创建可视化叠加图
|
||
overlay_image = cv2.cvtColor(img_original, cv2.COLOR_GRAY2BGR)
|
||
overlay_image[predicted_mask == 255] = [0, 0, 255] # 在焊缝位置标记为红色 (BGR)
|
||
|
||
return predicted_mask, overlay_image
|
||
|
||
|
||
# --- 如何调用这个函数 ---
|
||
if __name__ == '__main__':
|
||
# 从 main.py 中复制模型定义代码到这里
|
||
# ...
|
||
|
||
MODEL_FILE = 'best_linknet_up_model_line2.pth'
|
||
IMAGE_TO_TEST = 'test/004/input/004.jpg' # <--- 修改为你的图片路径
|
||
|
||
try:
|
||
# 调用预测函数
|
||
final_mask, overlay = predict_single_image(MODEL_FILE, IMAGE_TO_TEST)
|
||
|
||
# 显示结果
|
||
# cv2.imshow('Original Image', cv2.imread(IMAGE_TO_TEST))
|
||
# cv2.imshow('Predicted Mask', final_mask)
|
||
# cv2.imshow('Overlay Result', overlay)
|
||
#
|
||
# # 按任意键退出
|
||
# cv2.waitKey(0)
|
||
# cv2.destroyAllWindows()
|
||
|
||
# 也可以保存结果
|
||
# cv2.imwrite('predicted_mask.png', final_mask)
|
||
cv2.imwrite('overlay_result.png', overlay)
|
||
|
||
except FileNotFoundError as e:
|
||
print(e)
|
||
except Exception as e:
|
||
print(f"An error occurred: {e}")
|