import os
import random
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import time


# --- 1. Configuration ---
class Config:
    IMAGE_DIR = "data_up/images"
    MASK_DIR = "data_up/masks_line1"
    IMAGE_SIZE = 256       # resize all images to 256x256
    BATCH_SIZE = 4
    EPOCHS = 50            # number of training epochs
    LEARNING_RATE = 1e-4
    TEST_SPLIT = 0.1       # 10% of the data is held out for validation


# --- 2. Dataset loading and preprocessing ---
class WeldSeamDataset(Dataset):
    def __init__(self, image_paths, mask_paths, size):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.size = size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read the image as grayscale, resize, scale to [0, 1],
        # and add a channel dimension: (H, W) -> (C, H, W)
        img = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (self.size, self.size))
        img = img / 255.0
        img = np.expand_dims(img, axis=0)
        img_tensor = torch.from_numpy(img).float()

        # Read the mask, resize, and binarize to {0, 1}
        mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (self.size, self.size))
        mask = mask / 255.0
        mask[mask > 0.5] = 1.0
        mask[mask <= 0.5] = 0.0
        mask = np.expand_dims(mask, axis=0)
        mask_tensor = torch.from_numpy(mask).float()

        return img_tensor, mask_tensor


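# Illustrative helper, not part of the original script: a minimal sketch of how the dataset is
# meant to be consumed, assuming at least one image/mask pair exists under Config's directories.
# It is never called automatically; run it manually to confirm that each sample comes back as a
# pair of float tensors shaped (1, IMAGE_SIZE, IMAGE_SIZE) with a binary mask.
def _preview_first_sample(cfg=Config()):
    image_paths = sorted(os.path.join(cfg.IMAGE_DIR, f) for f in os.listdir(cfg.IMAGE_DIR))
    mask_paths = sorted(os.path.join(cfg.MASK_DIR, f) for f in os.listdir(cfg.MASK_DIR))
    dataset = WeldSeamDataset(image_paths, mask_paths, cfg.IMAGE_SIZE)
    img, mask = dataset[0]
    print("image:", tuple(img.shape), "mask:", tuple(mask.shape),
          "mask values:", mask.unique().tolist())

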
# --- 3. LinkNet model definition ---
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 conv to reduce channels, transposed conv to upsample 2x, 1x1 conv to project to out_channels
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class LinkNet(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        # Pretrained ResNet18 as the encoder
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # The stem conv is replaced with a fresh 1-channel conv for grayscale input
        # (trained from scratch); the remaining encoder layers keep their pretrained weights
        self.firstconv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool

        # Encoder stages
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder stages
        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)

        # Final output layers
        self.final_deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder (spatial sizes below assume a 256x256 input)
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)     # -> 64x64
        e1 = self.encoder1(x)        # -> 64x64
        e2 = self.encoder2(e1)       # -> 32x32
        e3 = self.encoder3(e2)       # -> 16x16
        e4 = self.encoder4(e3)       # -> 8x8

        # Decoder with LinkNet-style additive skip connections
        d4 = self.decoder4(e4) + e3  # -> 16x16
        d3 = self.decoder3(d4) + e2  # -> 32x32
        d2 = self.decoder2(d3) + e1  # -> 64x64
        d1 = self.decoder1(d2)       # -> 128x128

        f = self.final_deconv(d1)    # -> 256x256
        f = self.final_relu(f)
        f = self.final_conv(f)

        # Sigmoid turns the logits into a per-pixel probability map
        return torch.sigmoid(f)


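# Illustrative helper, not part of the original script: a quick, uncalled shape check on a random
# tensor shaped like one preprocessed grayscale sample. With the 256x256 input assumed above, the
# decoder stages plus the final deconv should restore the full resolution with one output channel.
def _check_linknet_shapes(size=Config.IMAGE_SIZE):
    model = LinkNet(num_classes=1)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 1, size, size))
    print("output shape:", tuple(out.shape))  # expected: (1, 1, size, size), values in (0, 1)

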
# --- 4. Loss function (Dice Loss + BCE Loss) ---
def dice_loss(pred, target, smooth=1.):
    # Soft Dice loss per sample/channel over the spatial dimensions, averaged over the batch
    pred = pred.contiguous()
    target = target.contiguous()
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    loss = (1 - ((2. * intersection + smooth) /
                 (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    return loss.mean()


def bce_dice_loss(pred, target):
    # BCE handles per-pixel classification; Dice counters the heavy foreground/background imbalance
    bce = nn.BCELoss()(pred, target)
    dice = dice_loss(pred, target)
    return bce + dice


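# Illustrative helper, not part of the original script: a toy sanity check of the combined loss.
# With a prediction that matches the target the loss should be close to 0, and with the prediction
# inverted it should be large; predictions are clamped away from 0/1 so the BCE term stays finite.
def _sanity_check_losses():
    target = torch.zeros(1, 1, 8, 8)
    target[:, :, 2:6, 2:6] = 1.0
    good = target.clamp(1e-6, 1 - 1e-6)
    bad = (1.0 - target).clamp(1e-6, 1 - 1e-6)
    print("matching prediction:", bce_dice_loss(good, target).item())
    print("inverted prediction:", bce_dice_loss(bad, target).item())

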
# --- 5. Training and evaluation ---
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"Training on {device}")

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0

        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device)
                masks = masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)

        val_loss /= len(val_loader.dataset)

        duration = time.time() - start_time
        print(f"Epoch {epoch + 1}/{num_epochs}.. "
              f"Train Loss: {epoch_loss:.4f}.. "
              f"Val Loss: {val_loss:.4f}.. "
              f"Time: {duration:.2f}s")

        # Save the model with the best validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_linknet_model.pth')
            print("Model saved!")

    print("Training complete.")


def predict_and_visualize(model, image_path, model_path, size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Load and preprocess a single image
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    original_size = (img.shape[1], img.shape[0])  # (width, height)
    img_resized = cv2.resize(img, (size, size))
    img_normalized = img_resized / 255.0
    img_tensor = torch.from_numpy(np.expand_dims(np.expand_dims(img_normalized, axis=0), axis=0)).float()
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        output = model(img_tensor)

    # Post-processing
    pred_mask = output.cpu().numpy()[0, 0]                 # (B, C, H, W) -> (H, W)
    pred_mask = (pred_mask > 0.5).astype(np.uint8) * 255   # binarize
    pred_mask = cv2.resize(pred_mask, original_size)       # restore the original resolution

    # Visualization
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.title("Original Image")
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_GRAY2RGB))

    plt.subplot(1, 3, 2)
    plt.title("Predicted Mask")
    plt.imshow(pred_mask, cmap='gray')

    # Overlay the predicted mask on the original image in red
    # (GRAY2RGB so the channel order matches what plt.imshow expects)
    overlay = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    overlay[pred_mask == 255] = [255, 0, 0]  # red
    plt.subplot(1, 3, 3)
    plt.title("Overlay")
    plt.imshow(overlay)

    plt.show()


# --- 6. Main entry point ---
if __name__ == '__main__':
    cfg = Config()

    # Build the dataset file lists (sorted so images and masks pair up by filename)
    image_files = sorted([os.path.join(cfg.IMAGE_DIR, f) for f in os.listdir(cfg.IMAGE_DIR)])
    mask_files = sorted([os.path.join(cfg.MASK_DIR, f) for f in os.listdir(cfg.MASK_DIR)])

    # Split into training and validation sets
    train_imgs, val_imgs, train_masks, val_masks = train_test_split(
        image_files, mask_files, test_size=cfg.TEST_SPLIT, random_state=42
    )

    train_dataset = WeldSeamDataset(train_imgs, train_masks, cfg.IMAGE_SIZE)
    val_dataset = WeldSeamDataset(val_imgs, val_masks, cfg.IMAGE_SIZE)

    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False)

    # Initialize the model, loss function and optimizer
    model = LinkNet(num_classes=1)
    criterion = bce_dice_loss
    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE)

    # --- Train the model (comment this call out if you only want to run prediction) ---
    train_model(model, train_loader, val_loader, criterion, optimizer, cfg.EPOCHS)

    # --- Predict with the trained model ---
    # After training, uncomment the block below to test; it expects the checkpoint
    # saved by train_model ('best_linknet_model.pth') to exist.
    # print("\n--- Running Prediction ---")
    # # Pick a random validation image to test on
    # test_image_path = random.choice(val_imgs)
    # print(f"Predicting on image: {test_image_path}")
    # predict_and_visualize(model, test_image_path, 'best_linknet_model.pth', cfg.IMAGE_SIZE)