三维重构终版

This commit is contained in:
2025-11-02 21:36:35 +08:00
parent f91b09da9d
commit f39009b853
126 changed files with 2870 additions and 2 deletions

116
linknet/predict.py Normal file
View File

@@ -0,0 +1,116 @@
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from linknet.main import LinkNet
# --- 首先,确保模型定义代码在这里 ---
# (你需要从之前的 main.py 文件中复制 LinkNet 和 DecoderBlock 类的定义)
# class DecoderBlock(nn.Module): ...
# class LinkNet(nn.Module): ...
# --- 假设模型定义代码已经复制过来了 ---
def predict_single_image(model_path, image_path, image_size=256):
"""
加载训练好的LinkNet模型对单张新图片进行预测。
参数:
- model_path (str): 已保存的模型权重文件路径 (.pth)。
- image_path (str): 待预测的图片文件路径。
- image_size (int): 模型训练时使用的图片尺寸。
返回:
- predicted_mask (numpy.ndarray): 预测出的二值化掩码图,与原图尺寸相同,
像素值为0或255。
- overlay_image (numpy.ndarray): 将预测掩码(红色)叠加在原图上的结果图。
"""
# 1. 设备选择
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 2. 加载模型结构并载入权重
# - 实例化你的模型结构
# - 使用 load_state_dict 载入已保存的权重
# - 将模型切换到评估模式 .eval()
model = LinkNet(num_classes=1)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
# 3. 加载并预处理图片
# - 读取图片(灰度模式)
img_original = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
if img_original is None:
raise FileNotFoundError(f"Image not found at {image_path}")
original_h, original_w = img_original.shape
# - 缩放到模型需要的尺寸 (和训练时一样)
img_resized = cv2.resize(img_original, (image_size, image_size))
# - 归一化 (和训练时一样)
img_normalized = img_resized / 255.0
# - 增加批次和通道维度 (H, W) -> (1, 1, H, W)
img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float()
# - 将Tensor发送到设备
img_tensor = img_tensor.to(device)
# 4. 模型推理
# 使用 torch.no_grad() 来关闭梯度计算,节省显存并加速
with torch.no_grad():
output = model(img_tensor)
# 5. 后处理
# - 将输出Tensor转回Numpy数组
# - 去掉批次和通道维度 (1, 1, H, W) -> (H, W)
pred_mask_resized = output.cpu().numpy()[0, 0]
# - 二值化:将概率图 (0-1) 转换为掩码图 (0或1)
# 这里的 0.5 是阈值,可以根据实际情况调整
pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
# - 将掩码图恢复到原始图片的尺寸
predicted_mask = cv2.resize(pred_mask_binary, (original_w, original_h),
interpolation=cv2.INTER_NEAREST) * 255
# 6. (可选) 创建可视化叠加图
overlay_image = cv2.cvtColor(img_original, cv2.COLOR_GRAY2BGR)
overlay_image[predicted_mask == 255] = [0, 0, 255] # 在焊缝位置标记为红色 (BGR)
return predicted_mask, overlay_image
# --- 如何调用这个函数 ---
if __name__ == '__main__':
# 从 main.py 中复制模型定义代码到这里
# ...
MODEL_FILE = 'best_linknet_up_model_line2.pth'
IMAGE_TO_TEST = 'test/004/input/004.jpg' # <--- 修改为你的图片路径
try:
# 调用预测函数
final_mask, overlay = predict_single_image(MODEL_FILE, IMAGE_TO_TEST)
# 显示结果
# cv2.imshow('Original Image', cv2.imread(IMAGE_TO_TEST))
# cv2.imshow('Predicted Mask', final_mask)
# cv2.imshow('Overlay Result', overlay)
#
# # 按任意键退出
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 也可以保存结果
# cv2.imwrite('predicted_mask.png', final_mask)
cv2.imwrite('overlay_result.png', overlay)
except FileNotFoundError as e:
print(e)
except Exception as e:
print(f"An error occurred: {e}")