三维重构终版
This commit is contained in:
70
3D_construction/script/linknet_model_def.py
Normal file
70
3D_construction/script/linknet_model_def.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
|
||||
# 这个文件只包含 LinkNet 的模型结构定义
|
||||
# 从你的训练脚本中完整复制过来
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=2, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels // 4, out_channels, kernel_size=1),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class LinkNet(nn.Module):
|
||||
def __init__(self, num_classes=1):
|
||||
super().__init__()
|
||||
# 使用预训练的ResNet18作为编码器
|
||||
# 注意:推理时可以不加载预训练权重,因为我们将加载自己训练好的完整模型权重
|
||||
resnet = models.resnet18() # weights=models.ResNet18_Weights.DEFAULT
|
||||
|
||||
# 你的模型是用单通道灰度图训练的
|
||||
self.firstconv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.firstbn = resnet.bn1
|
||||
self.firstrelu = resnet.relu
|
||||
self.firstmaxpool = resnet.maxpool
|
||||
# 编码器层
|
||||
self.encoder1 = resnet.layer1
|
||||
self.encoder2 = resnet.layer2
|
||||
self.encoder3 = resnet.layer3
|
||||
self.encoder4 = resnet.layer4
|
||||
# 解码器层
|
||||
self.decoder4 = DecoderBlock(512, 256)
|
||||
self.decoder3 = DecoderBlock(256, 128)
|
||||
self.decoder2 = DecoderBlock(128, 64)
|
||||
self.decoder1 = DecoderBlock(64, 64)
|
||||
# 最终输出层
|
||||
self.final_deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
|
||||
self.final_relu = nn.ReLU(inplace=True)
|
||||
self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
# 编码器
|
||||
x = self.firstconv(x)
|
||||
x = self.firstbn(x)
|
||||
x = self.firstrelu(x)
|
||||
x = self.firstmaxpool(x)
|
||||
e1 = self.encoder1(x)
|
||||
e2 = self.encoder2(e1)
|
||||
e3 = self.encoder3(e2)
|
||||
e4 = self.encoder4(e3)
|
||||
# 解码器
|
||||
d4 = self.decoder4(e4) + e3
|
||||
d3 = self.decoder3(d4) + e2
|
||||
d2 = self.decoder2(d3) + e1
|
||||
d1 = self.decoder1(d2)
|
||||
f = self.final_deconv(d1)
|
||||
f = self.final_relu(f)
|
||||
f = self.final_conv(f)
|
||||
return torch.sigmoid(f)
|
||||
Reference in New Issue
Block a user