Encoder-Bottleneck-Decoder Unet Based Semantic Segmentation

I have a grayscale image of size 256 x 256. I am using four convolutional layers with 96 filters per layer. However, I am getting an error. Please see my code:

#-----------------------------------------------------------------------------#
#-----------------------------------------------------------------------------#
import torch
import torch.nn as nn
import cv2
import torch
from torchvision import transforms
from torchviz import make_dot
import numpy as np
#-----------------------------------------------------------------------------#
#-----------------------------------------------------------------------------#
class conv_block(nn.Module):
    
    def __init__(self, in_c, out_c, kernel_size=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, padding=0)
        self.in1   = nn.InstanceNorm2d(out_c)
        self.relu  = nn.ReLU(inplace=True)
        
    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.in1(x)
        x = self.relu(x)
        return x
    
class encoder_block(nn.Module):
    
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))
        
    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

class decoder_block(nn.Module):
    
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=3, stride=2, padding=0)
        self.conv = conv_block(out_c+ out_c, out_c)
        
    def forward(self, inputs, skip):
        x = self.up(inputs)
        x = torch.cat([x, skip], axis=1)
        x = self.conv(x)
        return x
        
    
class build_unet(nn.Module):
    
    def __init__(self):
        super().__init__()
#-----------------------------------------------------------------------------#
#                        ENCODER                                              #
#-----------------------------------------------------------------------------#
        self.e1 = encoder_block(1 ,96)
        self.e2 = encoder_block(96,96)
        self.e3 = encoder_block(96,96)
        self.e4 = encoder_block(96,96)
#-----------------------------------------------------------------------------#
#                 BOTTLENECK                                                  #
#-----------------------------------------------------------------------------#
        self.b = conv_block(96,96,1)
#-----------------------------------------------------------------------------#
#                   DECODER                                                   #
#-----------------------------------------------------------------------------#
        self.d1 = decoder_block(96,96)
        self.d2 = decoder_block(96,96)
        self.d3 = decoder_block(96,96)
        self.d4 = decoder_block(96,3)
#----------------------------------------------------#
#              """ Classifier """
#----------------------------------------------------#
        # self.outputs = nn.Conv2d(1, 96, kernel_size=1, padding=1)
#------------------------------------------------------------------------------#
#                   """ Encoder """
#-----------------------------------------------------------------------------#        
    def forward(self, inputs):
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)
#------------------------------------------------------------------------------#
#                     """ Bottleneck """
#-----------------------------------------------------------------------------#
        b = self.b(p4)
#------------------------------------------------------------------------------#
#                     """ Decoder """
#------------------------------------------------------------------------------#
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)
        return d4
#-------------------------------------------#
#          """ Classifier """               #
#-------------------------------------------#
        # outputs = self.outputs(d4)
        # return outputs
#-------------------------------------------------------------------------------#

if __name__ == "__main__":
    # inputs = torch.randn([1,1,256,256])
    image = cv2.imread (r'C:\Users\Idrees Bhat\Desktop\Research\Insha\Dataset\o_gray\1.png',0)
    convert_tensor = transforms.ToTensor()
    inputs=convert_tensor(image)
    # print (inputs.shape)
    # print (type(inputs)) # [1,256,256]
    model = build_unet()
    y = model (inputs)
    # print(y.shape)

#-----------------------------------------------------------------------------------------------------------------------
ERROR:
File ~\AppData\Local\anaconda3\Lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
exec(code, globals, locals)

File c:\users\idrees bhat\desktop\insha\generator_aiwa.py:113
y = model (inputs)

File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518 in _wrapped_call_impl
return self._call_impl(*args, **kwargs)

File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527 in _call_impl
return forward_call(*args, **kwargs)

File c:\users\idrees bhat\desktop\insha\generator_aiwa.py:93 in forward
d1 = self.d1(b, s4)

File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518 in _wrapped_call_impl
return self._call_impl(*args, **kwargs)

File ~\AppData\Local\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527 in _call_impl
return forward_call(*args, **kwargs)

File c:\users\idrees bhat\desktop\insha\generator_aiwa.py:47 in forward
x = torch.cat([x, skip], axis=1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 29 but got size 28 for tensor number 1 in the list.

Hi Idrees!

I haven’t looked at your code in detail, but I believe what I am saying is correct.

There are two issues:

First, as your error message is telling you, the tensor from the skip connection
from the “downward,” contracting side of the U doesn’t have the same spatial
size as the tensor you are trying to cat() it with on the “upward,” expansive
side of the U. Following the original U-Net architecture, you need to crop the
skip-connection tensor to match the size of the “upward” tensor.
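
As a rough sketch of that crop (an illustration, not code from the post): the
helper name center_crop is hypothetical, and it assumes the skip tensor is at
least as large as the “upward” tensor in both spatial dimensions.

```python
import torch

def center_crop(skip: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Center-crop `skip` so its last two (spatial) dims match those of `target`."""
    th, tw = target.shape[-2:]
    sh, sw = skip.shape[-2:]
    if sh < th or sw < tw:
        raise ValueError("skip tensor is smaller than target; cannot crop")
    top = (sh - th) // 2
    left = (sw - tw) // 2
    return skip[..., top:top + th, left:left + tw]

# Example with made-up sizes: a 64x64 skip tensor and a 56x56 "upward" tensor.
skip = torch.randn(1, 96, 64, 64)
up = torch.randn(1, 96, 56, 56)
cropped = center_crop(skip, up)
merged = torch.cat([up, cropped], dim=1)  # channels add up: 96 + 96 = 192
```

In a decoder block, the crop would go right before the torch.cat() call.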

But there is some nuance to this.

First, you are using padding = 0 in your Conv2ds and ConvTranspose2ds.
This is correct (so leave it that way).

As your input tensor passes through the U-Net, its spatial size is being changed
in three ways: The convolutions (with a 3x3 kernel and no padding) trim a pixel
off of each edge. That is, each spatial dimension is reduced by 2. So your image
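shrinks going both down and up the U. Your MaxPool2d layers reduce the
spatial size (of both dimensions) by a factor of two going down the U and the
stride = 2 ConvTranspose2d layers increase the size of the image by a
compensating factor of two going back up the U. (The doubling is exact only
for a 2x2 kernel, as in the original U-Net; with your kernel_size=3 and no
padding, a size of n becomes 2n + 1.)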

But because of the trimming, your skip-connection tensors are larger than
the corresponding “upward” tensors, leading to the need to crop.
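
The size bookkeeping can be traced for the posted architecture in plain Python.
This is only a sketch: the helper names are hypothetical, and the formulas
assume no padding, floor-mode 2x2 pooling (the MaxPool2d default) and the
posted ConvTranspose2d(kernel_size=3, stride=2).

```python
def conv_out(n, k=3):
    """Unpadded, stride-1 convolution: trims (k - 1) pixels in total."""
    return n - (k - 1)

def pool_out(n):
    """2x2 max pool with stride 2, floor mode."""
    return n // 2

def up_out(n, k=3, stride=2):
    """Unpadded transposed convolution: (n - 1) * stride + k."""
    return (n - 1) * stride + k

def trace(n, depth=4):
    """Return (skip sizes, bottleneck size, size after the first up-conv)."""
    skips = []
    for _ in range(depth):
        n = conv_out(n)      # one 3x3 conv per encoder block
        skips.append(n)      # this is the skip-connection size
        n = pool_out(n)
    n = conv_out(n, k=1)     # the 1x1 bottleneck conv keeps the size
    return skips, n, up_out(n)

skips, bottleneck, first_up = trace(256)
print(skips)       # [254, 125, 60, 28]
print(bottleneck)  # 14
print(first_up)    # 29
```

The last two numbers to compare are 29 (first “upward” tensor) and 28 (deepest
skip tensor), which are the sizes reported in the RuntimeError.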

More subtly, the input to each MaxPool2d needs to have spatial sizes that are
multiples of two (so that the size can be reduced by a factor of two without any
misalignment or interpolation). This condition restricts the U-Net to accept input
tensors of specific spatial sizes.

To illustrate this, when you pass a 256x256 image to your U-Net, the first
Conv2d trims the image to 254x254 and the first MaxPool2d layer divides it
in two to 127x127. Then the second Conv2d trims it to 125x125, but these
spatial extents are not multiples of two, so subsequent MaxPool2ds can’t be
performed cleanly without introducing artifacts. (Attempts to fix this up by
cropping or interpolating would also introduce artifacts.)
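
One way to see which input sizes pass this test is to check them by brute
force. The sketch below uses a hypothetical helper, is_legal_input, and assumes
the posted encoder (four levels, one unpadded 3x3 convolution before each 2x2
pool). It checks only the pooling constraint on the contracting side, not the
decoder.

```python
def is_legal_input(n, depth=4, k=3):
    """True if every 2x2 pooling step receives an even, positive size."""
    for _ in range(depth):
        n = n - (k - 1)       # unpadded convolution
        if n <= 0 or n % 2:   # odd size: pooling would drop a row and a column
            return False
        n //= 2
    return True

legal = [n for n in range(240, 290) if is_legal_input(n)]
print(legal)                 # [254, 270, 286]
print(is_legal_input(256))   # False
```

Under these assumptions the legal sizes are spaced 16 apart, and 256 falls
between 254 and 270.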

The sound way to address this is to work through your specific U-Net architecture
layer by layer to determine the “legal” input sizes and their corresponding output
sizes. Pad your ground-truth target data to the next largest legal output size, find
the legal input size that corresponds to that output size and reflection-pad your
input tensor to that legal size.
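
The input-padding step might look like the sketch below. The helper name
reflect_pad_to is hypothetical, and the target size of 270 is an assumption:
it is a size for which every pooling input in the posted four-level encoder is
even (268, 132, 64, 30). Work out the right size for your final architecture.

```python
import torch
import torch.nn.functional as F

def reflect_pad_to(x: torch.Tensor, size: int) -> torch.Tensor:
    """Reflection-pad the last two dims of an (N, C, H, W) tensor to size x size."""
    h, w = x.shape[-2:]
    dh, dw = size - h, size - w
    if dh < 0 or dw < 0:
        raise ValueError("target size is smaller than the input")
    # F.pad takes the padding for the last dim first: (left, right, top, bottom).
    pad = (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
    return F.pad(x, pad, mode="reflect")

x = torch.randn(1, 1, 256, 256)   # batch of one grayscale image
padded = reflect_pad_to(x, 270)   # 7 reflected pixels on each edge
print(padded.shape)               # torch.Size([1, 1, 270, 270])
```

The original image sits untouched in the middle of the padded tensor, so the
same offsets can later be used to locate the valid region of the output.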

Lastly, you would typically use a per-pixel loss criterion for training (for example,
BCEWithLogitsLoss). It’s best to include in the loss only contributions from the
pixels in the original target tensor, ignoring the padded pixels.
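
A minimal sketch of such a masked loss, assuming a binary target and a mask
that is 1 on original pixels and 0 on padded ones. The function name
masked_bce_with_logits and the toy 8x8 tensors are made up for illustration.

```python
import torch
import torch.nn.functional as F

def masked_bce_with_logits(logits, target, mask):
    """Mean BCE-with-logits over the pixels where mask == 1."""
    per_pixel = F.binary_cross_entropy_with_logits(
        logits, target, reduction="none")
    # Zero out padded pixels, then average over the valid ones only.
    return (per_pixel * mask).sum() / mask.sum().clamp(min=1)

torch.manual_seed(0)
logits = torch.randn(1, 1, 8, 8)                     # raw network output
target = torch.randint(0, 2, (1, 1, 8, 8)).float()   # padded ground truth
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0                            # original, unpadded region

loss = masked_bce_with_logits(logits, target, mask)

# Same value as computing the loss on the cropped region directly.
reference = F.binary_cross_entropy_with_logits(
    logits[..., 2:6, 2:6], target[..., 2:6, 2:6])
```

Cropping the prediction and target to the valid region before calling the loss
gives the same result; the mask form is handy when the valid region differs
from sample to sample within a batch.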

“But,” you say, “this is all very complicated!” Yes, but it’s the right thing to do.
The inventors of U-Net made sure to get it right in their original U-Net paper,
so you can too.

Best.

K. Frank

Thank you, Sir, for your encouragement.