Defined two nn.ModuleList()s in the same way, but only one works fine

I’m building a multi-task learning network (segmentation and depth), and for that I chose the U-Net architecture with a shared encoder and a separate bottleneck and decoder for each task. The issue I found after debugging is this: I defined two nn.ModuleList()s, one for decoder_seg and another for decoder_depth, and appended layers to both. But when looping through the layers in the forward() method, it only loops through decoder_seg and not through decoder_depth (its length shows as 0), even though both are built in the same way.

Here’s the code:


import torch
import torch.nn as nn
import torchvision.transforms.functional as F
import torch.optim


class IntermediateBlocks(nn.Module):
    def __init__(self, block_in_channels, block_out_channels):
        super(IntermediateBlocks, self).__init__()
        self.block = nn.Sequential(

            nn.Conv2d(block_in_channels, block_out_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(block_out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(block_out_channels, block_out_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(block_out_channels),
            nn.ReLU(inplace=True),

        )

    def forward(self, x):
        return self.block(x)
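
# Note: each IntermediateBlocks maps (N, C_in, H, W) -> (N, C_out, H, W);
# kernel_size=3 with padding=1 keeps the spatial size unchanged, e.g.
# IntermediateBlocks(3, 64)(torch.randn(1, 3, 160, 160)) gives shape (1, 64, 160, 160)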


class DepthSegmentation(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, intermediate_channels=None):
        super(DepthSegmentation, self).__init__()
        self.out_channels = out_channels

        if intermediate_channels is None:
            intermediate_channels = [64, 128, 256, 512]

        """ ---------------- Down-Sampling Layers --------------- """
        self.encoder = nn.ModuleList()
        for num_channels in intermediate_channels:
            self.encoder.append(IntermediateBlocks(block_in_channels=in_channels, block_out_channels=num_channels))
            in_channels = num_channels
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        """ ---------------- Bottle Neck Layers ------------------ """
        # One for Segmentation
        self.bottleneck_seg = IntermediateBlocks(intermediate_channels[-1], intermediate_channels[-1] * 2)
        # one for Depth Estimation
        self.bottleneck_depth = IntermediateBlocks(intermediate_channels[-1], intermediate_channels[-1] * 2)

        """ ----------------- Up-Sampling Layers ---------------- """
        self.decoder_intermediate_channels = reversed(intermediate_channels)  # [512, 256, 128, 64]

        # for segmentation
        self.decoder_seg = nn.ModuleList()
        for num_channels in self.decoder_intermediate_channels:
            self.decoder_seg.append(
                nn.ConvTranspose2d(
                    num_channels * 2, num_channels, kernel_size=2, stride=2
                )
            )
            self.decoder_seg.append(IntermediateBlocks(num_channels * 2, num_channels))

        # for depth estimation
        self.decoder_depth = nn.ModuleList()
        for num_channels in self.decoder_intermediate_channels:
            self.decoder_depth.append(
                nn.ConvTranspose2d(
                    num_channels * 2, num_channels, kernel_size=2, stride=2
                )
            )
            self.decoder_depth.append(IntermediateBlocks(num_channels * 2, num_channels))

        """ -------------------- Final Layers ---------------------"""
        self.final_seg = nn.Conv2d(
            in_channels=intermediate_channels[0], out_channels=self.out_channels,
            kernel_size=1, stride=1, padding=0)
        self.final_depth = nn.Conv2d(
            in_channels=intermediate_channels[0], out_channels=self.out_channels,
            kernel_size=1, stride=1, padding=0)


    def forward(self, x):
        skip_connections_layers = []

        for layers in self.encoder:
            # First, process x through an IntermediateBlock (a few conv layers)
            x = layers(x)

            """ Since, the encoder was made from few IntermediateBlocks,--- 
            so for getting the skip_connections_layers (for concatenation in Decoder part),
            --- we are going to append the last layer of each IntermediateBlocks """
            skip_connections_layers.append(x)  # here the x supplied is from last layer of each Intermediate Block

            # MaxPooling is applied after every IntermediateBlock (for down-sampling)
            x = self.pool(x)

        common_encoder_seg_output = x
        common_encoder_depth_output = x

        seg_out = self.bottleneck_seg(common_encoder_seg_output)
        depth_out = self.bottleneck_depth(common_encoder_depth_output)

        # Every up-sampled layer needs to be concatenated with the corresponding
        # entry in skip_connections_layers, so it's easiest to reverse the list
        skip_connections_layers = skip_connections_layers[::-1]

        """ Since, self.decoder_seg is like ==>
        # [convTranspose2D, IntermediateBlocks, convTranspose2D, IntermediateBlocks, ..... ]
        # Here, the concatenation will be happening with only convTranspose2D layers, therefore
        # while looping, we have to use step = 2 """

        print(f'Length of decoder_depth: {len(self.decoder_depth)}')

        print(f'Length of decoder_seg: {len(self.decoder_seg)}',"\n")

        for i_seg in range(0, len(self.decoder_seg), 2):
            print(f"seg_out - Before {i_seg // 2} convTranspose2D: {seg_out.shape}")
            # First, up-sample with ConvTranspose2d
            seg_out = self.decoder_seg[i_seg](seg_out)
            print(f"seg_out - After {i_seg // 2} convTranspose2D: {seg_out.shape}")
            required_skip_layer = skip_connections_layers[i_seg // 2]

            # Concatenate along dim=1 (the channel dimension), since a batch
            # is shaped (batch_size, channels, height, width).
            # We also need to make sure the spatial shapes match first
            if seg_out.shape != required_skip_layer.shape:
                # [2:]  ==>  height, width
                seg_out = F.resize(seg_out, size=required_skip_layer.shape[2:])
            concatenated_layer_seg = torch.cat((required_skip_layer, seg_out), dim=1)

            print(f"seg_out - Before {i_seg // 2} IntermediateBlocks: {seg_out.shape}")
            # After that, process with an IntermediateBlocks
            seg_out = self.decoder_seg[i_seg + 1](concatenated_layer_seg)
            print(f"seg_out - After {i_seg // 2} IntermediateBlocks: {seg_out.shape}", '\n')


        print(f'seg_out - Before passing to final layer: {seg_out.shape}')

        print("\n",f'Length of decoder_depth: {len(self.decoder_depth)}')


        # Similarly for the depth estimation head
        for i_depth in range(0, len(self.decoder_depth), 2):
            print(f"Before depth_out: {depth_out.shape}")
            # First, up-sample with ConvTranspose2d
            depth_out = self.decoder_depth[i_depth](depth_out)
            print(f"After depth_out: {depth_out.shape}")
            required_skip_layer = skip_connections_layers[i_depth // 2]

            if depth_out.shape != required_skip_layer.shape:
                # [2:]  ==> height, width
                depth_out = F.resize(depth_out, size=required_skip_layer.shape[2:])
            concatenated_layer_depth = torch.cat((required_skip_layer, depth_out), dim=1)

            # After that, process with an IntermediateBlocks
            depth_out = self.decoder_depth[i_depth + 1](concatenated_layer_depth)
            print(f"After concatenated_layer_depth: {depth_out.shape}")


        return self.final_seg(seg_out), self.final_depth(depth_out)



# Dummy Input
input_batch = torch.randn((16, 3, 160, 160))
model = DepthSegmentation(in_channels=3, out_channels=3, intermediate_channels=[64,128,256,512])
seg_output, depth_output = model(input_batch)

print(seg_output.shape)
print(depth_output.shape)
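
For reference, the asymmetry already shows up right after construction, before any forward pass (a quick check on the model object defined above):

print(len(model.decoder_seg))    # 8  (4 ConvTranspose2d + 4 IntermediateBlocks)
print(len(model.decoder_depth))  # 0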

This is the error I got:

RuntimeError: Given groups=1, weight of size [3, 64, 1, 1], expected input[16, 1024, 10, 10] to have 64 channels, but got 1024 channels instead.

@ptrblck Please have a look.

self.final_depth(depth_out) fails since depth_out has a shape of [16, 1024, 10, 10] while self.final_depth expects an activation with 64 channels.
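
For anyone hitting the same issue: the reason len(self.decoder_depth) is 0 in the first place is this line:

self.decoder_intermediate_channels = reversed(intermediate_channels)

reversed() returns a one-shot iterator, not a list. The first for loop (the one building decoder_seg) consumes it completely, so the second loop (building decoder_depth) iterates over an exhausted iterator and appends nothing. The depth decoder loop in forward() then runs zero times, and depth_out reaches final_depth still carrying the bottleneck's 1024 channels, which is exactly the error above. A minimal sketch of the iterator behaviour:

channels = [64, 128, 256, 512]
rev = reversed(channels)  # one-shot iterator, not a list
print(list(rev))          # [512, 256, 128, 64]
print(list(rev))          # []  <- already exhausted

Materializing the reversal fixes both decoders; either form works:

self.decoder_intermediate_channels = intermediate_channels[::-1]
# or: self.decoder_intermediate_channels = list(reversed(intermediate_channels))

With that change both ModuleLists hold 8 modules, and for the dummy input above the model should return two outputs of shape [16, 3, 160, 160].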