3D reconstruction: final version
linknet/main.py (273 lines, Normal file)
@@ -0,0 +1,273 @@
import os
import random
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import time


# --- 1. Configuration ---
class Config:
    IMAGE_DIR = "data_up/images"
    MASK_DIR = "data_up/masks_line1"
    IMAGE_SIZE = 256  # Resize all images to 256x256
    BATCH_SIZE = 4
    EPOCHS = 50  # Number of training epochs
    LEARNING_RATE = 1e-4
    TEST_SPLIT = 0.1  # 10% of the data is used as the validation set


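# Assumed data layout (inferred from the paths above, not checked at runtime):
# IMAGE_DIR and MASK_DIR are expected to hold the same number of files with
# matching names, so that the sorted image list built in the main block lines
# up index-for-index with the sorted mask list.

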
# --- 2. Dataset loading and preprocessing ---
class WeldSeamDataset(Dataset):
    def __init__(self, image_paths, mask_paths, size):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.size = size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the image
        img = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.size, self.size))
        img = img / 255.0  # Normalize to [0, 1]
        img = np.expand_dims(img, axis=0)  # Add a channel dimension: (H, W) -> (C, H, W)
        img_tensor = torch.from_numpy(img).float()

        # Read the mask
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (self.size, self.size))
        mask = mask / 255.0  # Normalize, then binarize to {0, 1}
        mask[mask > 0.5] = 1.0
        mask[mask <= 0.5] = 0.0
        mask = np.expand_dims(mask, axis=0)
        mask_tensor = torch.from_numpy(mask).float()

        return img_tensor, mask_tensor


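# Illustration of what one sample from this dataset looks like (a sketch only:
# the file names below are placeholders, and nothing in this script calls it):
#
#     ds = WeldSeamDataset(["example.png"], ["example_mask.png"], Config.IMAGE_SIZE)
#     img_t, mask_t = ds[0]
#     print(img_t.shape)       # torch.Size([1, 256, 256]), values in [0, 1]
#     print(mask_t.unique())   # tensor([0., 1.])

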
# --- 3. LinkNet model definition ---
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class LinkNet(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        # Use a pretrained ResNet-18 as the encoder
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # The first conv is replaced to accept single-channel (grayscale) input,
        # so it is trained from scratch rather than reusing the pretrained weights
        self.firstconv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool

        # Encoder layers
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder layers
        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)

        # Final output layers
        self.final_deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder (spatial sizes assume a 256x256 input)
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)  # -> 64x64
        e1 = self.encoder1(x)     # -> 64x64
        e2 = self.encoder2(e1)    # -> 32x32
        e3 = self.encoder3(e2)    # -> 16x16
        e4 = self.encoder4(e3)    # -> 8x8

        # Decoder with LinkNet-style additive skip connections
        d4 = self.decoder4(e4) + e3  # -> 16x16
        d3 = self.decoder3(d4) + e2  # -> 32x32
        d2 = self.decoder2(d3) + e1  # -> 64x64
        d1 = self.decoder1(d2)       # -> 128x128

        f = self.final_deconv(d1)  # -> 256x256
        f = self.final_relu(f)
        f = self.final_conv(f)

        return torch.sigmoid(f)  # Sigmoid produces a per-pixel probability map


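# Optional smoke test for the network above (an illustrative helper that is not
# part of the original pipeline and is never called by this script): checks that
# a 256x256 grayscale batch comes back at full resolution. Note that constructing
# LinkNet downloads the pretrained ResNet-18 weights on first use.
def _check_linknet_shapes(image_size=256):
    net = LinkNet(num_classes=1)
    net.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 1, image_size, image_size)  # (B, C, H, W)
        out = net(dummy)
    assert out.shape == (1, 1, image_size, image_size), out.shape
    return out.shape

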
# --- 4. Loss function (Dice Loss + BCE Loss) ---
def dice_loss(pred, target, smooth=1.):
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()


def bce_dice_loss(pred, target):
    bce = nn.BCELoss()(pred, target)
    dice = dice_loss(pred, target)
    return bce + dice


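# Optional sanity check for the combined loss (an illustrative helper, not used
# by the training loop; the tensor shapes here are arbitrary): a near-perfect
# prediction should give a loss close to zero, a random one a much larger value.
def _check_bce_dice_loss():
    target = (torch.rand(2, 1, 64, 64) > 0.5).float()
    near_perfect = target.clamp(1e-4, 1 - 1e-4)  # keep the BCE term finite
    print("near-perfect:", bce_dice_loss(near_perfect, target).item())
    print("random:      ", bce_dice_loss(torch.rand(2, 1, 64, 64), target).item())

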
# --- 5. Training and evaluation ---
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"Training on {device}")

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0

        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)

        val_loss /= len(val_loader.dataset)

        duration = time.time() - start_time
        print(f"Epoch {epoch + 1}/{num_epochs}.. "
              f"Train Loss: {epoch_loss:.4f}.. "
              f"Val Loss: {val_loss:.4f}.. "
              f"Time: {duration:.2f}s")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_linknet_model.pth')
            print("Model saved!")

    print("Training complete.")


def predict_and_visualize(model, image_path, model_path, size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Load and preprocess a single image
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    original_size = (img.shape[1], img.shape[0])  # (width, height)
    img_resized = cv2.resize(img, (size, size))
    img_normalized = img_resized / 255.0
    img_tensor = torch.from_numpy(np.expand_dims(np.expand_dims(img_normalized, axis=0), axis=0)).float()
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        output = model(img_tensor)

    # Post-processing
    pred_mask = output.cpu().numpy()[0, 0]  # (B, C, H, W) -> (H, W)
    pred_mask = (pred_mask > 0.5).astype(np.uint8) * 255  # Binarize
    # Restore the original size; nearest-neighbor keeps the mask strictly binary
    pred_mask = cv2.resize(pred_mask, original_size, interpolation=cv2.INTER_NEAREST)

    # Visualization
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.title("Original Image")
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_GRAY2RGB))

    plt.subplot(1, 3, 2)
    plt.title("Predicted Mask")
    plt.imshow(pred_mask, cmap='gray')

    # Overlay the mask on the original image
    overlay = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # 3-channel copy for matplotlib
    overlay[pred_mask == 255] = [255, 0, 0]  # Mark the predicted seam in red
    plt.subplot(1, 3, 3)
    plt.title("Overlay")
    plt.imshow(overlay)

    plt.show()


# --- 6. Main entry point ---
if __name__ == '__main__':
    cfg = Config()

    # Prepare the dataset
    image_files = sorted([os.path.join(cfg.IMAGE_DIR, f) for f in os.listdir(cfg.IMAGE_DIR)])
    mask_files = sorted([os.path.join(cfg.MASK_DIR, f) for f in os.listdir(cfg.MASK_DIR)])

    # Split into training and validation sets
    train_imgs, val_imgs, train_masks, val_masks = train_test_split(
        image_files, mask_files, test_size=cfg.TEST_SPLIT, random_state=42
    )

    train_dataset = WeldSeamDataset(train_imgs, train_masks, cfg.IMAGE_SIZE)
    val_dataset = WeldSeamDataset(val_imgs, val_masks, cfg.IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False)

    # Initialize the model, loss function and optimizer
    model = LinkNet(num_classes=1)
    criterion = bce_dice_loss
    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE)

    # --- Train the model ---
    # Comment out the line below if you want to skip training
    train_model(model, train_loader, val_loader, criterion, optimizer, cfg.EPOCHS)

    # --- Run prediction with the trained model ---
    # After training, use this function to test the model
    # Make sure the checkpoint file passed below exists (train_model above saves 'best_linknet_model.pth')
    # print("\n--- Running Prediction ---")
    # # Randomly pick a validation image to test on
    # test_image_path = random.choice(val_imgs)
    # print(f"Predicting on image: {test_image_path}")
    # predict_and_visualize(model, test_image_path, 'best_linknet_up_model_line1.pth', cfg.IMAGE_SIZE)