nn.ModuleList not working as I think it should

I created a network of linear layers using nn.ModuleList, but when I run forward, the output shape of each layer does not match what I expect. Concretely, I used the code below to create the network:

import torch
import torch.nn as nn

class decoder_with_fc_only(nn.Module):
    
    def __init__(self, latent_size, layer_sizes, non_linearity=nn.ReLU):
        super(decoder_with_fc_only, self).__init__()
        n_layers = len(layer_sizes)
        if n_layers < 2:
            raise ValueError('For an FC decoder with a single layer, use simpler code.')
            
        size_list = [latent_size] + layer_sizes[0:-1]
        list_len = len(size_list)
        self.linear_layers = nn.ModuleList([non_linearity(nn.Linear(size_list[i-1], size_list[i]))  \
                                             for i in range(1,list_len) ])
        
        self.final_layer = nn.Linear(layer_sizes[-2], layer_sizes[-1])
 
        
    def forward(self, x):
        for layer in self.linear_layers:
            x = layer(x)
        x = self.final_layer(x)
        x = torch.reshape(x, (-1, 5000, 3))
        return x

Upon construction, I pass in latent_size = 256 and layer_sizes = [256, 256, 512, 1025, 15000].

The layers are then:
layer1 with size (256 x 256)
layer2 with size (256 x 256)
layer3 with size (256 x 512)
layer4 with size (512 x 1025)
final_layer with size (1025 x 15000)

I would then expect that, with batch size 32 and an input of size 256, the layer outputs would have sizes
[32x256, 32x256, 32x512, 32x1025, 32x15000].

However, the sizes I actually get are
[32x256, 32x256, 32x256, 32x256, -].
The - is there because final_layer fails with a size-mismatch error on its input.
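
For reference, one way to see these shapes is to print inside the loop of the first version (a sketch):

    def forward(self, x):
        for layer in self.linear_layers:
            x = layer(x)
            print(x.shape)               # prints torch.Size([32, 256]) on every iteration
        x = self.final_layer(x)          # raises a size-mismatch RuntimeError here
        x = torch.reshape(x, (-1, 5000, 3))
        return x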

However, with the following code, it works:

import torch
import torch.nn as nn
import torch.nn.functional as F

class decoder_with_fc_only(nn.Module):
    
    def __init__(self, latent_size, layer_sizes, non_linearity=nn.ReLU):
        super(decoder_with_fc_only, self).__init__()
        n_layers = len(layer_sizes)
        if n_layers < 2:
            raise ValueError('For an FC decoder with a single layer, use simpler code.')
            
        size_list = [latent_size] + layer_sizes[0:-1]
        list_len = len(size_list)
        self.linear_layers = nn.ModuleList([nn.Linear(size_list[i-1], size_list[i])  \
                                            for i in range(1,list_len) ])
        self.final_layer = nn.Linear(layer_sizes[-2], layer_sizes[-1])

    def forward(self, x):
        for layer in self.linear_layers:
            x = F.relu(layer(x))
        x = self.final_layer(x)
        x = torch.reshape(x, (-1, 5000, 3))
        return x
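
For completeness, a quick sanity check of the working version with the sizes from my example (a sketch):

decoder = decoder_with_fc_only(256, [256, 256, 512, 1025, 15000])
x = torch.randn(32, 256)       # batch of 32 latent vectors
out = decoder(x)
print(out.shape)               # torch.Size([32, 5000, 3])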

I’m curious to see if anyone else has encountered this problem or can spot the bug. Thanks!

Your non_linearity is the module class nn.ReLU, whose first constructor argument is inplace, a boolean flag. nn.ReLU(nn.Linear(...)) therefore just builds an nn.ReLU, with the Linear swallowed as a truthy inplace value, so your ModuleList ends up containing only ReLU modules; that is why the shape never changes and final_layer then fails. To chain modules together, use nn.Sequential instead.
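
For example, a sketch of the constructor line, reusing size_list, list_len, and non_linearity from your __init__:

        self.linear_layers = nn.ModuleList([
            nn.Sequential(nn.Linear(size_list[i - 1], size_list[i]), non_linearity())
            for i in range(1, list_len)
        ])

With this, the forward loop from your first version works unchanged and produces the shapes you expected.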