Build a ResUnet model

Hello community!
I recently started working on a new computer vision project and I'm learning PyTorch to build my model from scratch. I've already built a UNet model, but I'd like to try a ResUNet model for better results. I'm only working from a schematic, and I'm not sure I understand it or know how to build it.
Can anyone help me?

If you are stuck, you could check the ResNet reference implementation to see how the blocks are built.
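For reference, the core idea is the residual block: two 3x3 conv/BN layers plus a skip connection that is either an identity or a 1x1 projection when the shape changes. A rough sketch of that block (simplified, not the exact torchvision code):

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    # Simplified ResNet-style block: two 3x3 convs plus an identity/projection skip.
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Project the skip with a 1x1 conv when the shape changes; otherwise pass it through.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.downsample(x))

The key point is that the block outputs conv_path(x) + skip(x), so the network only has to learn the residual.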

To put it simply, can I think of a ResUNet model as a blend between UNet and ResNet? I didn't know the ResNet model before this, thanks.

I'm having some difficulty building my ResUNet model. Can I get feedback on my code, please?

Hi @Lau0711! Can you post your code and the problem you're having?

@J_Johnson I made this:

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualBlock, self).__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)


class UpBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(UpBlock, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)
    
    
class ResUnet(nn.Module):
    def __init__(self, in_channels, features=[64,128,256]):
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(features[0]),
            nn.ReLU(),
            nn.Conv2d(features[0], features[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1)
        )

        self.residual_conv1 = ResidualBlock(features[0], features[1], 2, 1)
        self.residual_conv2 = ResidualBlock(features[1], features[2], 2, 1)

        self.bridge = ResidualBlock(features[2], features[2], 2, 1)

        self.upsample1 = UpBlock(features[2], features[2], 3, 2)
        self.up_residual_conv1 = ResidualBlock(features[2] + features[1], features[1], 1, 1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(features[0], 1, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv1(x1)
        x3 = self.residual_conv2(x2)

        x4 = self.bridge(x3)

        x4 = self.upsample1(x4)
        x5 = torch.cat([x4, x3], dim=1)

        x6 = self.up_residual_conv1(x5)

        output = self.output_layer(x6)

        return output
    

if __name__ == "__main__":
    # quick shape check with a dummy 2-channel batch
    dummy_batch = torch.randn(2, 2, 128, 128)
    resmodel = ResUnet(in_channels=2, features=[64, 128, 256])
    result = resmodel(dummy_batch)

    print(result.shape)

And I get this error:

Exception has occurred: RuntimeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Sizes of tensors must match except in dimension 1. Expected size 33 but got size 32 for tensor number 1 in the list.

Sorry, it's probably something silly, but I've only been learning PyTorch for 4 days and I don't understand everything yet.

It’s just a size issue on your concatenation for x5.

        x1 = self.input_layer(x) + self.input_skip(x)
        x2 = self.residual_conv1(x1)
        x3 = self.residual_conv2(x2)

        x4 = self.bridge(x3)

        x4 = self.upsample1(x4)
        print(x4.size(), x3.size()) # see both sizes are torch.Size([2, 256, 33, 33]) torch.Size([2, 256, 32, 32])
        x5 = torch.cat([x4, x3], dim=1) # this wants both sizes to match, except on the dim you are combining

        x6 = self.up_residual_conv1(x5)

        output = self.output_layer(x6)

        return output

So you will need to adjust your conv layers so the outputs match.
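For example, one way to make the spatial sizes line up (assuming you want the transposed conv to exactly double the spatial size, so the 16x16 bridge output becomes 32x32 like x3) is to change the upsample kernel/stride, or to resize x4 to x3's size right before the cat:

# option 1: in __init__, make the transposed conv double the size exactly
self.upsample1 = UpBlock(features[2], features[2], kernel=2, stride=2)

# option 2: in forward, resize x4 to x3's spatial size before concatenating
import torch.nn.functional as F
x4 = F.interpolate(x4, size=x3.shape[2:], mode="nearest")
x5 = torch.cat([x4, x3], dim=1)

Either option makes x4 and x3 agree on every dimension except the channel one, which is all torch.cat needs.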

I fixed this error, but now I have this one:

Exception has occurred: RuntimeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Given groups=1, weight of size [128, 384, 3, 3], expected input[2, 512, 64, 64] to have 384 channels, but got 512 channels instead

Similar situation. This is just telling you that x5 has 512 channels. That’s because when you torch.cat x4 and x3, which are 256 and 256, on the channels dim, you have 512 channels. But your self.up_residual_conv1 block is set to take an input of features[2] + features[1] which is 384. And so you need to adjust the channels there to match what’s coming into that layer. For example:

self.up_residual_conv1 = ResidualBlock(features[2] * 2, features[1], 1, 1)

You’ll get a similar error on your self.output_layer, because the output of the above up_residual_conv1 is features[1] for channels, yet the input for the output_layer expects features[0] channels.
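Putting those adjustments together, a shape-consistent decoder could look roughly like this (a sketch, assuming you also want a second up stage that uses the x2 skip connection; the names upsample2, up_residual_conv2 and the extra forward lines are my additions, not from your code):

# in __init__
self.upsample1 = UpBlock(features[2], features[2], 2, 2)
self.up_residual_conv1 = ResidualBlock(features[2] * 2, features[1], 1, 1)   # cat(256, 256) -> 512 in

self.upsample2 = UpBlock(features[1], features[1], 2, 2)
self.up_residual_conv2 = ResidualBlock(features[1] * 2, features[0], 1, 1)   # cat(128, 128) -> 256 in

self.output_layer = nn.Sequential(
    nn.Conv2d(features[0], 1, 1, 1),
    nn.Sigmoid(),
)

# in forward
x4 = self.upsample1(self.bridge(x3))
x5 = torch.cat([x4, x3], dim=1)      # 256 + 256 = 512 channels
x6 = self.up_residual_conv1(x5)      # -> 128 channels

x7 = self.upsample2(x6)
x8 = torch.cat([x7, x2], dim=1)      # 128 + 128 = 256 channels
x9 = self.up_residual_conv2(x8)      # -> 64 channels, which matches output_layer

output = self.output_layer(x9)

Note that this still returns a map at the x2 resolution (64x64 for a 128x128 input); one more up stage using the x1 skip would be needed to get back to the input resolution.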

Okay, I see…
I checked the different values and dimensions.
I will correct x5 and x6.
I have a doubt: is x4 useful?

It’s fairly standard in UNets to allow skip connections. This allows the model to choose where in the image to “pass through” and what needs more attention to detail during the “denoising diffusion” process.