三维重构终版
This commit is contained in:
		
							
								
								
									
										116
									
								
								linknet/predict.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										116
									
								
								linknet/predict.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,116 @@
 | 
			
		||||
import torch
 | 
			
		||||
import cv2
 | 
			
		||||
import numpy as np
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
from linknet.main import LinkNet
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# --- 首先,确保模型定义代码在这里 ---
 | 
			
		||||
# (你需要从之前的 main.py 文件中复制 LinkNet 和 DecoderBlock 类的定义)
 | 
			
		||||
# class DecoderBlock(nn.Module): ...
 | 
			
		||||
# class LinkNet(nn.Module): ...
 | 
			
		||||
# --- 假设模型定义代码已经复制过来了 ---
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict_single_image(model_path, image_path, image_size=256):
 | 
			
		||||
    """
 | 
			
		||||
    加载训练好的LinkNet模型,对单张新图片进行预测。
 | 
			
		||||
 | 
			
		||||
    参数:
 | 
			
		||||
    - model_path (str): 已保存的模型权重文件路径 (.pth)。
 | 
			
		||||
    - image_path (str): 待预测的图片文件路径。
 | 
			
		||||
    - image_size (int): 模型训练时使用的图片尺寸。
 | 
			
		||||
 | 
			
		||||
    返回:
 | 
			
		||||
    - predicted_mask (numpy.ndarray): 预测出的二值化掩码图,与原图尺寸相同,
 | 
			
		||||
                                      像素值为0或255。
 | 
			
		||||
    - overlay_image (numpy.ndarray): 将预测掩码(红色)叠加在原图上的结果图。
 | 
			
		||||
    """
 | 
			
		||||
    # 1. 设备选择
 | 
			
		||||
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
			
		||||
    print(f"Using device: {device}")
 | 
			
		||||
 | 
			
		||||
    # 2. 加载模型结构并载入权重
 | 
			
		||||
    #   - 实例化你的模型结构
 | 
			
		||||
    #   - 使用 load_state_dict 载入已保存的权重
 | 
			
		||||
    #   - 将模型切换到评估模式 .eval()
 | 
			
		||||
    model = LinkNet(num_classes=1)
 | 
			
		||||
    model.load_state_dict(torch.load(model_path, map_location=device))
 | 
			
		||||
    model.to(device)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    # 3. 加载并预处理图片
 | 
			
		||||
    #   - 读取图片(灰度模式)
 | 
			
		||||
    img_original = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
 | 
			
		||||
    if img_original is None:
 | 
			
		||||
        raise FileNotFoundError(f"Image not found at {image_path}")
 | 
			
		||||
 | 
			
		||||
    original_h, original_w = img_original.shape
 | 
			
		||||
 | 
			
		||||
    #   - 缩放到模型需要的尺寸 (和训练时一样)
 | 
			
		||||
    img_resized = cv2.resize(img_original, (image_size, image_size))
 | 
			
		||||
 | 
			
		||||
    #   - 归一化 (和训练时一样)
 | 
			
		||||
    img_normalized = img_resized / 255.0
 | 
			
		||||
 | 
			
		||||
    #   - 增加批次和通道维度 (H, W) -> (1, 1, H, W)
 | 
			
		||||
    img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float()
 | 
			
		||||
 | 
			
		||||
    #   - 将Tensor发送到设备
 | 
			
		||||
    img_tensor = img_tensor.to(device)
 | 
			
		||||
 | 
			
		||||
    # 4. 模型推理
 | 
			
		||||
    #    使用 torch.no_grad() 来关闭梯度计算,节省显存并加速
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        output = model(img_tensor)
 | 
			
		||||
 | 
			
		||||
    # 5. 后处理
 | 
			
		||||
    #   - 将输出Tensor转回Numpy数组
 | 
			
		||||
    #   - 去掉批次和通道维度 (1, 1, H, W) -> (H, W)
 | 
			
		||||
    pred_mask_resized = output.cpu().numpy()[0, 0]
 | 
			
		||||
 | 
			
		||||
    #   - 二值化:将概率图 (0-1) 转换为掩码图 (0或1)
 | 
			
		||||
    #     这里的 0.5 是阈值,可以根据实际情况调整
 | 
			
		||||
    pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)
 | 
			
		||||
 | 
			
		||||
    #   - 将掩码图恢复到原始图片的尺寸
 | 
			
		||||
    predicted_mask = cv2.resize(pred_mask_binary, (original_w, original_h),
 | 
			
		||||
                                interpolation=cv2.INTER_NEAREST) * 255
 | 
			
		||||
 | 
			
		||||
    # 6. (可选) 创建可视化叠加图
 | 
			
		||||
    overlay_image = cv2.cvtColor(img_original, cv2.COLOR_GRAY2BGR)
 | 
			
		||||
    overlay_image[predicted_mask == 255] = [0, 0, 255]  # 在焊缝位置标记为红色 (BGR)
 | 
			
		||||
 | 
			
		||||
    return predicted_mask, overlay_image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# --- 如何调用这个函数 ---
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    # 从 main.py 中复制模型定义代码到这里
 | 
			
		||||
    # ...
 | 
			
		||||
 | 
			
		||||
    MODEL_FILE = 'best_linknet_up_model_line2.pth'
 | 
			
		||||
    IMAGE_TO_TEST = 'test/004/input/004.jpg'  # <--- 修改为你的图片路径
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        # 调用预测函数
 | 
			
		||||
        final_mask, overlay = predict_single_image(MODEL_FILE, IMAGE_TO_TEST)
 | 
			
		||||
 | 
			
		||||
        # 显示结果
 | 
			
		||||
        # cv2.imshow('Original Image', cv2.imread(IMAGE_TO_TEST))
 | 
			
		||||
        # cv2.imshow('Predicted Mask', final_mask)
 | 
			
		||||
        # cv2.imshow('Overlay Result', overlay)
 | 
			
		||||
        #
 | 
			
		||||
        # # 按任意键退出
 | 
			
		||||
        # cv2.waitKey(0)
 | 
			
		||||
        # cv2.destroyAllWindows()
 | 
			
		||||
 | 
			
		||||
        # 也可以保存结果
 | 
			
		||||
        # cv2.imwrite('predicted_mask.png', final_mask)
 | 
			
		||||
        cv2.imwrite('overlay_result.png', overlay)
 | 
			
		||||
 | 
			
		||||
    except FileNotFoundError as e:
 | 
			
		||||
        print(e)
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"An error occurred: {e}")
 | 
			
		||||
		Reference in New Issue
	
	Block a user