3D reconstruction, final version
linknet/main.py — new file, 273 lines
@@ -0,0 +1,273 @@
import os
import random
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import time


# --- 1. Configuration ---
class Config:
    IMAGE_DIR = "data_up/images"
    MASK_DIR = "data_up/masks_line1"
    IMAGE_SIZE = 256  # resize all images to 256x256
    BATCH_SIZE = 4
    EPOCHS = 50  # number of training epochs
    LEARNING_RATE = 1e-4
    TEST_SPLIT = 0.1  # 10% of the data is used for validation


# --- 2. Dataset loading and preprocessing ---
class WeldSeamDataset(Dataset):
    def __init__(self, image_paths, mask_paths, size):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.size = size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the image
        img = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.size, self.size))
        img = img / 255.0  # normalize to [0, 1]
        img = np.expand_dims(img, axis=0)  # add channel dimension: (H, W) -> (C, H, W)
        img_tensor = torch.from_numpy(img).float()

        # Read the mask
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (self.size, self.size))
        mask = mask / 255.0  # normalize, then binarize to {0, 1}
        mask[mask > 0.5] = 1.0
        mask[mask <= 0.5] = 0.0
        mask = np.expand_dims(mask, axis=0)
        mask_tensor = torch.from_numpy(mask).float()

        return img_tensor, mask_tensor


# --- 3. LinkNet model definition ---
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class LinkNet(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        # Use a pretrained ResNet18 as the encoder
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Replace the ResNet stem conv so the network accepts single-channel (grayscale) input
        self.firstconv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool

        # Encoder layers
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder layers
        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)

        # Final output layers
        self.final_deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)  # -> 64x64
        e1 = self.encoder1(x)  # -> 64x64
        e2 = self.encoder2(e1)  # -> 32x32
        e3 = self.encoder3(e2)  # -> 16x16
        e4 = self.encoder4(e3)  # -> 8x8

        # Decoder with additive skip connections
        d4 = self.decoder4(e4) + e3  # -> 16x16
        d3 = self.decoder3(d4) + e2  # -> 32x32
        d2 = self.decoder2(d3) + e1  # -> 64x64
        d1 = self.decoder1(d2)  # -> 128x128

        f = self.final_deconv(d1)  # -> 256x256
        f = self.final_relu(f)
        f = self.final_conv(f)

        return torch.sigmoid(f)  # sigmoid produces a probability map


# --- 4. Loss function (Dice Loss + BCE Loss) ---
def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()


def bce_dice_loss(pred, target):
    bce = nn.BCELoss()(pred, target)
    dice = dice_loss(pred, target)
    return bce + dice


# --- 5. Training and evaluation ---
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"Training on {device}")

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0

        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)

        val_loss /= len(val_loader.dataset)

        duration = time.time() - start_time
        print(f"Epoch {epoch + 1}/{num_epochs}.. "
              f"Train Loss: {epoch_loss:.4f}.. "
              f"Val Loss: {val_loss:.4f}.. "
              f"Time: {duration:.2f}s")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_linknet_model.pth')
            print("Model saved!")

    print("Training complete.")


def predict_and_visualize(model, image_path, model_path, size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Load and preprocess a single image
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    original_size = (img.shape[1], img.shape[0])  # (width, height)
    img_resized = cv2.resize(img, (size, size))
    img_normalized = img_resized / 255.0
    img_tensor = torch.from_numpy(np.expand_dims(np.expand_dims(img_normalized, axis=0), axis=0)).float()
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        output = model(img_tensor)

    # Post-processing
    pred_mask = output.cpu().numpy()[0, 0]  # (B, C, H, W) -> (H, W)
    pred_mask = (pred_mask > 0.5).astype(np.uint8) * 255  # binarize
    pred_mask = cv2.resize(pred_mask, original_size)  # restore the original size

    # Visualization
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.title("Original Image")
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_GRAY2RGB))

    plt.subplot(1, 3, 2)
    plt.title("Predicted Mask")
    plt.imshow(pred_mask, cmap='gray')

    # Overlay the mask on the original image (RGB order, since it is shown with matplotlib)
    overlay = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    overlay[pred_mask == 255] = [255, 0, 0]  # red
    plt.subplot(1, 3, 3)
    plt.title("Overlay")
    plt.imshow(overlay)

    plt.show()


# --- 6. Main entry point ---
if __name__ == '__main__':
    cfg = Config()

    # Prepare the dataset
    image_files = sorted([os.path.join(cfg.IMAGE_DIR, f) for f in os.listdir(cfg.IMAGE_DIR)])
    mask_files = sorted([os.path.join(cfg.MASK_DIR, f) for f in os.listdir(cfg.MASK_DIR)])

    # Split into training and validation sets
    train_imgs, val_imgs, train_masks, val_masks = train_test_split(
        image_files, mask_files, test_size=cfg.TEST_SPLIT, random_state=42
    )

    train_dataset = WeldSeamDataset(train_imgs, train_masks, cfg.IMAGE_SIZE)
    val_dataset = WeldSeamDataset(val_imgs, val_masks, cfg.IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False)

    # Initialize the model, loss function and optimizer
    model = LinkNet(num_classes=1)
    criterion = bce_dice_loss
    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE)

    # --- Train the model ---
    # The call below starts training; comment it out if you only want to run prediction
    train_model(model, train_loader, val_loader, criterion, optimizer, cfg.EPOCHS)

    # --- Predict with the trained model ---
    # After training, use this function for testing
    # Make sure the file 'best_linknet_bottom_model_line1.pth' exists
    # print("\n--- Running Prediction ---")
    # # Randomly pick a validation image for testing
    # test_image_path = random.choice(val_imgs)
    # print(f"Predicting on image: {test_image_path}")
    # predict_and_visualize(model, test_image_path, 'best_linknet_up_model_line1.pth', cfg.IMAGE_SIZE)
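
A minimal smoke test of the pieces defined above (not part of the committed files; it assumes linknet/ is importable as a package, as linknet/predict.py below already assumes, and that the pretrained ResNet-18 weights can be downloaded):

# Illustrative sketch: verify LinkNet output shape and loss on dummy 256x256 grayscale tensors
import torch
from linknet.main import LinkNet, bce_dice_loss

model = LinkNet(num_classes=1)
model.eval()
dummy_images = torch.rand(2, 1, 256, 256)                  # batch of 2 grayscale images
dummy_masks = (torch.rand(2, 1, 256, 256) > 0.5).float()   # binary targets

with torch.no_grad():
    preds = model(dummy_images)

print(preds.shape)                        # expected: torch.Size([2, 1, 256, 256])
print(bce_dice_loss(preds, dummy_masks))  # scalar loss, roughly in [0, 2]
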
							
								
								
									
linknet/predict.py — new file, 116 lines
@@ -0,0 +1,116 @@
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

from linknet.main import LinkNet


# --- Model definition ---
# LinkNet (and DecoderBlock, which it uses) are imported from linknet/main.py above,
# so their definitions do not need to be copied into this file.


def predict_single_image(model_path, image_path, image_size=256):
    """
    Load a trained LinkNet model and run prediction on a single new image.

    Args:
    - model_path (str): path to the saved model weights (.pth).
    - image_path (str): path to the image to predict on.
    - image_size (int): image size used during training.

    Returns:
    - predicted_mask (numpy.ndarray): predicted binary mask at the original image size,
                                      with pixel values 0 or 255.
    - overlay_image (numpy.ndarray): the original image with the predicted mask overlaid in red.
    """
    # 1. Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. Build the model and load the weights
    #   - instantiate the model architecture
    #   - load the saved weights with load_state_dict
    #   - switch the model to evaluation mode with .eval()
    model = LinkNet(num_classes=1)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # 3. Load and preprocess the image
    #   - read the image in grayscale
    img_original = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img_original is None:
        raise FileNotFoundError(f"Image not found at {image_path}")

    original_h, original_w = img_original.shape

    #   - resize to the size the model expects (same as during training)
    img_resized = cv2.resize(img_original, (image_size, image_size))

    #   - normalize (same as during training)
    img_normalized = img_resized / 255.0

    #   - add batch and channel dimensions: (H, W) -> (1, 1, H, W)
    img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).unsqueeze(0).float()

    #   - move the tensor to the device
    img_tensor = img_tensor.to(device)

    # 4. Inference
    #    torch.no_grad() disables gradient tracking, saving memory and time
    with torch.no_grad():
        output = model(img_tensor)

    # 5. Post-processing
    #   - convert the output tensor back to a NumPy array
    #   - drop the batch and channel dimensions: (1, 1, H, W) -> (H, W)
    pred_mask_resized = output.cpu().numpy()[0, 0]

    #   - binarize: turn the probability map (0-1) into a mask (0 or 1)
    #     0.5 is the threshold and can be tuned if needed
    pred_mask_binary = (pred_mask_resized > 0.5).astype(np.uint8)

    #   - resize the mask back to the original image size
    predicted_mask = cv2.resize(pred_mask_binary, (original_w, original_h),
                                interpolation=cv2.INTER_NEAREST) * 255

    # 6. (Optional) build a visualization overlay
    overlay_image = cv2.cvtColor(img_original, cv2.COLOR_GRAY2BGR)
    overlay_image[predicted_mask == 255] = [0, 0, 255]  # mark the weld seam in red (BGR)

    return predicted_mask, overlay_image


# --- How to call this function ---
if __name__ == '__main__':
    MODEL_FILE = 'best_linknet_up_model_line2.pth'
    IMAGE_TO_TEST = 'test/004/input/004.jpg'  # <--- change to your image path

    try:
        # Run prediction
        final_mask, overlay = predict_single_image(MODEL_FILE, IMAGE_TO_TEST)

        # Show the results
        # cv2.imshow('Original Image', cv2.imread(IMAGE_TO_TEST))
        # cv2.imshow('Predicted Mask', final_mask)
        # cv2.imshow('Overlay Result', overlay)
        #
        # # Press any key to exit
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()

        # The results can also be saved
        # cv2.imwrite('predicted_mask.png', final_mask)
        cv2.imwrite('overlay_result.png', overlay)

    except FileNotFoundError as e:
        print(e)
    except Exception as e:
        print(f"An error occurred: {e}")
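
Since matplotlib.pyplot is already imported in predict.py (though unused above), the two return values can also be shown inline instead of only being written to disk; a minimal sketch, assuming the same MODEL_FILE and IMAGE_TO_TEST paths as in the __main__ block:

# Illustrative sketch: display the predict_single_image return values with matplotlib
import cv2
import matplotlib.pyplot as plt
from linknet.predict import predict_single_image

mask, overlay = predict_single_image('best_linknet_up_model_line2.pth',
                                     'test/004/input/004.jpg')

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Predicted Mask")
plt.imshow(mask, cmap='gray')
plt.subplot(1, 2, 2)
plt.title("Overlay")
plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))  # overlay is BGR (OpenCV); matplotlib expects RGB
plt.show()
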