Dear programmers,
I am trying to define a simplified version of the original 3D UNet. When I ran the original architecture, the GPU ran out of memory, so I reduced the number of layers and the filter sizes. However, when I try to train the simplified model, the following error occurs:
File "/home/gaofei/newResearch/codes2/test1/project_fei-master/code/networks/unet.py", line 192, in forward
up3 = self.conv_128_128_UpConv(block2, block1)
File "/home/gaofei/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/gaofei/newResearch/codes2/test1/project_fei-master/code/networks/unet.py", line 140, in forward
out = torch.cat((crop1, up), 1)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 238 and 57 in dimension 2 at /opt/conda/conda-bld/pytorch_1573049304260/work/aten/src/THC/generic/THCTensorMath.cu:71
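If I read the message correctly, torch.cat only relaxes the size check along the concatenation dimension (here dimension 1, the channels); every other dimension must match exactly. The same failure can be reproduced in isolation with the sizes from my error:

import torch

# Dimension 1 (channels) may differ between the two tensors,
# but dimensions 2-4 (depth, height, width) must match exactly.
crop1 = torch.randn(1, 16, 57, 57, 57)   # channel counts as in my model below
up = torch.randn(1, 32, 238, 238, 238)
out = torch.cat((crop1, up), 1)          # RuntimeError: sizes must match except in dimension 1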
I am training on a GPU with 10.92 GiB of total memory, and my UNet code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Convolution(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Convolution, self).__init__()
        # Valid 3x3x3 convolution (no padding), so each call shrinks
        # every spatial dimension by 2 voxels.
        self.convolution = nn.Conv3d(in_channels, out_channels, kernel_size=3)
        self.batch = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        out = F.relu(self.batch(self.convolution(x)))
        return out
class UpConvolution(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpConvolution, self).__init__()
        # Stride-2 transposed convolution: doubles every spatial dimension.
        self.up_convolution = nn.ConvTranspose3d(in_channels, out_channels,
                                                 kernel_size=2, stride=2)

    # Center crop: cut `bridge` down to the spatial size of `up`.
    def crop(self, bridge, up):
        batch_size, n_channels, depth, layer_width, layer_height = bridge.size()
        target_batch_size, target_n_channels, target_depth, target_layer_width, target_layer_height = up.size()
        xy = (layer_width - target_layer_width) // 2
        zxy = (depth - target_depth) // 2
        # Return a crop of bridge with the same spatial size as the up part
        # (this assumes bridge is at least as large as up in every dimension).
        return bridge[:, :, zxy:(zxy + target_depth),
                      xy:(xy + target_layer_width),
                      xy:(xy + target_layer_width)]

    def forward(self, x, bridge):
        up = self.up_convolution(x)
        # bridge is the skip connection from the corresponding down block
        crop1 = self.crop(bridge, up)
        out = torch.cat((crop1, up), 1)
        return out
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Note: stride=1 means each pooling step shrinks the spatial
        # dimensions by only one voxel instead of halving them.
        self.pooling = nn.MaxPool3d(kernel_size=2, stride=1)

        # Down part of the UNet (layer names are kept from the original
        # filter counts, but the channels are reduced)
        self.conv_1_32 = Convolution(1, 8)
        self.conv_32_64 = Convolution(8, 16)
        self.conv_64_64 = Convolution(16, 16)
        self.conv_64_128 = Convolution(16, 32)
        #self.conv_128_128 = Convolution(128, 128)
        #self.conv_128_256 = Convolution(128, 256)
        #self.conv_256_256 = Convolution(256, 256)
        #self.conv_256_512 = Convolution(256, 512)

        # Up part of the UNet
        #self.conv_512_512_UpConv = UpConvolution(512, 512)
        #self.conv_768_256_Conv = Convolution(768, 256)
        #self.conv_256_256_Conv = Convolution(256, 256)
        #self.conv_256_256_UpConv = UpConvolution(256, 256)
        #self.conv_384_128_Conv = Convolution(384, 128)
        #self.conv_128_128_Conv = Convolution(128, 128)
        self.conv_128_128_UpConv = UpConvolution(32, 32)
        self.conv_192_64_Conv = Convolution(48, 16)  # 48 = 16 (skip) + 32 (up)
        self.conv_64_64_Conv = Convolution(16, 16)
        self.conv_64_1 = nn.Conv3d(16, 1, 1)
    def forward(self, x):
        start = self.conv_1_32(x)
        block1 = self.conv_32_64(start)
        block1_pool = self.pooling(block1)
        block2 = self.conv_64_64(block1_pool)
        block2 = self.conv_64_128(block2)
        #block2_pool = self.pooling(block2)
        #block3 = self.conv_128_128(block2_pool)
        #block3 = self.conv_128_256(block3)
        #block3_pool = self.pooling(block3)
        #block4 = self.conv_256_256(block3_pool)
        #block4 = self.conv_256_512(block4)
        #up1 = self.conv_512_512_UpConv(block4, block3)
        #up1_conv = self.conv_768_256_Conv(up1)
        #up1_conv = self.conv_256_256_Conv(up1_conv)
        #up2 = self.conv_256_256_UpConv(block3, block2)
        #up2_conv = self.conv_384_128_Conv(up2)
        #up2_conv = self.conv_128_128_Conv(up2_conv)
        up3 = self.conv_128_128_UpConv(block2, block1)  # <- error raised here
        up3_conv = self.conv_192_64_Conv(up3)
        up3_conv = self.conv_64_64_Conv(up3_conv)
        output = self.conv_64_1(up3_conv)
        output = torch.sigmoid(output)
        return output
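For what it is worth, the numbers in the error message are consistent with the layer arithmetic. Working backwards from the message, the input depth must be 128 here (238 = 2 × (128 − 9)), so for a single-channel 128³ volume the shapes evolve like this:

import torch

# Minimal reproduction; the 128^3 input size is inferred from the error.
# Each valid 3x3x3 convolution removes 2 voxels per dimension,
# MaxPool3d(kernel_size=2, stride=1) removes 1, and the stride-2
# ConvTranspose3d doubles the size.
net = UNet()
x = torch.randn(1, 1, 128, 128, 128)
# start:       (1,  8, 126, 126, 126)
# block1:      (1, 16, 124, 124, 124)
# block1_pool: (1, 16, 123, 123, 123)
# block2:      (1, 32, 119, 119, 119)
# up:          (1, 32, 238, 238, 238)  -> larger than block1!
# crop(block1, up) computes (124 - 238) // 2 = -57; the negative slice
# silently truncates block1 to (1, 16, 57, 57, 57).
out = net(x)  # RuntimeError: Got 238 and 57 in dimension 2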
Any suggestions or remarks would be highly appreciated. The error is raised at this line of UNet.forward:
up3 = self.conv_128_128_UpConv(block2, block1)
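My current guess is that the decoder doubles the spatial size (ConvTranspose3d with stride=2) while the encoder pooling only shrinks it by one voxel (stride=1), so `up` comes out larger than the skip tensor and the crop offset goes negative. If that is right, matching the pooling stride to the upsampling factor should restore the usual UNet symmetry, e.g.:

# Sketch of the change I am considering (not yet verified on real data):
# halve the spatial size in the encoder so that the stride-2 transposed
# convolution in the decoder can (approximately) undo it.
self.pooling = nn.MaxPool3d(kernel_size=2, stride=2)

With stride=2 pooling and the same 128³ input, block1 is 124³, block2 is 58³, up is 116³, and the crop succeeds with an offset of (124 − 116) // 2 = 4. Is this the right way to keep the simplified network consistent, or am I missing something else?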