I try to use U-Net for segmentation but the error appears in cat layer
This is down-conv and up-conv part
from torch.nn import Module
import torch.nn.functional as F
class DownConv(Module):
def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
super(DownConv, self).__init__()
self.conv1 = nn.Conv2d(in_feat, out_feat, kernel_size=3, padding=1)
self.conv1_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
self.conv1_drop = nn.Dropout2d(drop_rate)
self.conv2 = nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1)
self.conv2_bn = nn.BatchNorm2d(out_feat, momentum=bn_momentum)
self.conv2_drop = nn.Dropout2d(drop_rate)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.conv1_bn(x)
x = self.conv1_drop(x)
x = F.relu(self.conv2(x))
x = self.conv2_bn(x)
x = self.conv2_drop(x)
return x
class UpConv(Module):
def __init__(self, in_feat, out_feat, drop_rate=0.4, bn_momentum=0.1):
super(UpConv, self).__init__()
self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.downconv = DownConv(in_feat, out_feat, drop_rate, bn_momentum)
def forward(self, x, y):
x = self.up1(x)
x = torch.cat([x, y], dim=1)
x = self.downconv(x)
return x
This is model part
class Unet(Module):
def __init__(self, drop_rate=0.4, bn_momentum=0.1):
super(Unet, self).__init__()
#Downsampling path
self.conv1 = DownConv(1, 64, drop_rate, bn_momentum)
self.mp1 = nn.MaxPool2d(2)
self.conv2 = DownConv(64, 128, drop_rate, bn_momentum)
self.mp2 = nn.MaxPool2d(2)
self.conv3 = DownConv(128, 256, drop_rate, bn_momentum)
self.mp3 = nn.MaxPool2d(2)
# Bottleneck
self.conv4 = DownConv(256, 256, drop_rate, bn_momentum)
# Upsampling path
self.up1 = UpConv(512, 256, drop_rate, bn_momentum)
self.up2 = UpConv(384, 127, drop_rate, bn_momentum)
self.up3 = UpConv(191, 64, drop_rate, bn_momentum)
self.conv9 = nn.Conv2d(64, 1, kernel_size=3, padding=1)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.mp1(x1)
x3 = self.conv2(x2)
x4 = self.mp2(x3)
x5 = self.conv3(x4)
x6 = self.mp3(x5)
# Bottom
x7 = self.conv4(x6)
# print(x7.size(), x5.size())
# Up-sampling
x8 = self.up1(x7, x5)
print(x8.size(), x3.size())
x9 = self.up2(x8, x3)
x10 = self.up3(x9, x1)
x11 = self.conv9(x10)
preds = F.sigmoid(x11)
return preds
In the last line I have no idea why the size have changed to this and my input size have resample from [2,1,512,512] to [2,1,256,256]
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 64, 64]) torch.Size([2, 128, 128, 128])
torch.Size([2, 256, 86, 86]) torch.Size([2, 128, 173, 173])
error exception
32 def forward(self, x, y):
33 x = self.up1(x)
---> 34 x = torch.cat([x, y], dim=1)
35 x = self.downconv(x)
36 return x
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 173 and 172 in dimension 2