Files
innovate_project/linknet/predict.py
2025-11-02 21:36:35 +08:00

117 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from linknet.main import LinkNet
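# --- Model definition ---
# LinkNet is imported from linknet.main above. If that import is not available,
# copy the LinkNet and DecoderBlock class definitions from main.py here instead:
# class DecoderBlock(nn.Module): ...
# class LinkNet(nn.Module): ...
# --- The code below assumes LinkNet is defined at this point ---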
def predict_single_image(model_path, image_path, image_size=256, threshold=0.5):
    """Run a trained LinkNet model on a single image and return its mask.

    Args:
        model_path (str): Path to the saved model weights file (.pth).
        image_path (str): Path to the image to run prediction on.
        image_size (int): Side length the image is resized to before
            inference; should match the size used during training.
        threshold (float): Cut-off applied to the model output. Pixels whose
            output is strictly greater than this value become foreground.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        tuple: ``(predicted_mask, overlay_image)`` where

        - predicted_mask (numpy.ndarray): binary mask with the same height
          and width as the input image; pixel values are 0 or 255.
        - overlay_image (numpy.ndarray): the input image converted to BGR
          with the predicted mask pixels painted red.

    Raises:
        FileNotFoundError: If ``cv2.imread`` could not load ``image_path``.
    """
    # 1. Pick the compute device: GPU when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. Build the network, load the saved weights and switch to eval mode.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    model = LinkNet(num_classes=1)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # 3. Load and preprocess the image.
    # Read as single-channel grayscale; cv2.imread returns None on failure
    # instead of raising, so check explicitly.
    img_original = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img_original is None:
        raise FileNotFoundError(f"Image not found at {image_path}")
    original_h, original_w = img_original.shape

    # Resize to the square input size expected by the model.
    img_resized = cv2.resize(img_original, (image_size, image_size))
    # Scale pixel values from [0, 255] to [0, 1].
    img_normalized = img_resized / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W), as float32.
    img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float()
    img_tensor = img_tensor.to(device)

    # 4. Inference. Gradients are not needed, so disable them.
    with torch.no_grad():
        output = model(img_tensor)

    # 5. Post-processing.
    # Back to NumPy and drop batch/channel dimensions: (1, 1, H, W) -> (H, W).
    pred_mask_resized = output.cpu().numpy()[0, 0]
    # Binarize against the threshold.
    # NOTE(review): this presumes the model output is already a probability
    # map in [0, 1] — confirm against LinkNet's final activation.
    pred_mask_binary = (pred_mask_resized > threshold).astype(np.uint8)
    # Restore the original resolution. Nearest-neighbour keeps the mask
    # strictly binary; then scale {0, 1} to {0, 255}.
    predicted_mask = cv2.resize(pred_mask_binary, (original_w, original_h),
                                interpolation=cv2.INTER_NEAREST) * 255

    # 6. Visualization: paint the predicted pixels red on a BGR copy of the
    # original image ([0, 0, 255] is red in BGR channel order).
    overlay_image = cv2.cvtColor(img_original, cv2.COLOR_GRAY2BGR)
    overlay_image[predicted_mask == 255] = [0, 0, 255]

    return predicted_mask, overlay_image
# --- Example usage ---
if __name__ == '__main__':
    # LinkNet is imported at the top of this file, so the model definition
    # does not need to be copied here.
    MODEL_FILE = 'best_linknet_up_model_line2.pth'
    IMAGE_TO_TEST = 'test/004/input/004.jpg'  # <--- change to your image path
    OVERLAY_FILE = 'overlay_result.png'

    try:
        # Run the prediction.
        final_mask, overlay = predict_single_image(MODEL_FILE, IMAGE_TO_TEST)

        # To display the results interactively instead of saving them:
        # cv2.imshow('Original Image', cv2.imread(IMAGE_TO_TEST))
        # cv2.imshow('Predicted Mask', final_mask)
        # cv2.imshow('Overlay Result', overlay)
        #
        # # Press any key to exit
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()

        # The mask itself can be saved as well:
        # cv2.imwrite('predicted_mask.png', final_mask)

        # cv2.imwrite reports a failed write by returning False rather than
        # raising, so turn that into an error the handler below will report.
        if not cv2.imwrite(OVERLAY_FILE, overlay):
            raise OSError(f"Failed to write {OVERLAY_FILE}")
    except FileNotFoundError as e:
        print(e)
    except Exception as e:
        # Top-level boundary of the script: report the error and exit.
        print(f"An error occurred: {e}")