import torch
import torch.nn as nn
from torchvision import models

# This file contains only the LinkNet model architecture definition,
# copied over in full from the training script.


class DecoderBlock(nn.Module):
    """LinkNet decoder block: 1x1 channel reduce -> 2x upsample -> 1x1 expand.

    The bottleneck width is ``in_channels // 4``, matching the decoder
    design of the original LinkNet paper (Chaurasia & Culurciello, 2017).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
            nn.ReLU(inplace=True),
            # Doubles the spatial resolution (stride-2 transposed conv).
            nn.ConvTranspose2d(in_channels // 4, in_channels // 4,
                               kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, out_channels, kernel_size=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class LinkNet(nn.Module):
    """LinkNet segmentation network with a ResNet-18 encoder.

    Takes a single-channel (grayscale) image and returns a per-pixel
    probability map of shape ``(N, num_classes, H, W)`` — ``forward``
    ends with a sigmoid, so values lie in [0, 1].

    NOTE(review): the additive skip connections require the encoder and
    decoder feature maps to line up, which presumes input H and W are
    divisible by 32 — confirm against the preprocessing pipeline.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # ResNet-18 is used as the encoder. No pretrained weights are
        # loaded here: at inference time the fully trained state_dict is
        # loaded on top of this module, so ImageNet initialization would
        # be overwritten anyway.
        resnet = models.resnet18()  # weights=models.ResNet18_Weights.DEFAULT

        # The model was trained on single-channel grayscale images, so the
        # stem conv takes 1 input channel instead of ResNet's usual 3.
        self.firstconv = nn.Conv2d(1, 64, kernel_size=7, stride=2,
                                   padding=3, bias=False)
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool

        # Encoder stages (ResNet-18 residual layers).
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4

        # Decoder stages.
        self.decoder4 = DecoderBlock(512, 256)
        self.decoder3 = DecoderBlock(256, 128)
        self.decoder2 = DecoderBlock(128, 64)
        self.decoder1 = DecoderBlock(64, 64)

        # Final upsampling head back to the input resolution.
        self.final_deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_relu = nn.ReLU(inplace=True)
        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)

        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        # Decoder with additive skip connections from the encoder.
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)

        f = self.final_deconv(d1)
        f = self.final_relu(f)
        f = self.final_conv(f)
        # Sigmoid is baked into forward: callers receive probabilities,
        # not logits (binary segmentation when num_classes == 1).
        return torch.sigmoid(f)