I created a network with linear layers using nn.ModuleList, but when I run forward, the output shape of each layer does not match what I would expect. Concretely, I used the code below to create the network:
import torch
import torch.nn as nn

class decoder_with_fc_only(nn.Module):
    def __init__(self, latent_size, layer_sizes, non_linearity=nn.ReLU):
        super(decoder_with_fc_only, self).__init__()
        n_layers = len(layer_sizes)
        if n_layers < 2:
            raise ValueError('For an FC decoder with a single layer use simpler code.')
        size_list = [latent_size] + layer_sizes[0:-1]
        list_len = len(size_list)
        # Build one module per consecutive pair of sizes in size_list
        self.linear_layers = nn.ModuleList([non_linearity(nn.Linear(size_list[i - 1], size_list[i]))
                                            for i in range(1, list_len)])
        self.final_layer = nn.Linear(layer_sizes[-2], layer_sizes[-1])

    def forward(self, x):
        for i, l in enumerate(self.linear_layers):
            x = self.linear_layers[i](x)
        x = self.final_layer(x)
        x = torch.reshape(x, (-1, 5000, 3))
        return x
Upon construction, I pass in latent_size = 256 and layer_sizes = [256, 256, 512, 1025, 15000] (a small construction snippet is shown after the list of layers below).
The layers are then
layer1 with size (256 x 256)
layer2 with size (256 x 256)
layer3 with size (256 x 512)
layer4 with size (512 x 1025)
final_layer with size (1025 x 15000)
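For reference, this is roughly how I construct the model and look at the registered layers (just a quick sketch, the variable names are only for this example):

latent_size = 256
layer_sizes = [256, 256, 512, 1025, 15000]
model = decoder_with_fc_only(latent_size, layer_sizes)

# Print the registered sub-modules
print(model.linear_layers)
print(model.final_layer)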
I would then expect that, with batch size 32 and an input of size 256, the per-layer outputs would have shapes
[32x256, 32x256, 32x512, 32x1025, 32x15000].
However, the sizes I get are
[32x256, 32x256, 32x256, 32x256, -].
The '-' is there because final_layer then fails due to the incompatible input size.
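The way I checked the per-layer shapes was roughly this (a quick sketch, not the exact code I ran):

model = decoder_with_fc_only(256, [256, 256, 512, 1025, 15000])
x = torch.randn(32, 256)  # batch of 32 latent vectors

# Run the layers one at a time and print the shape after each
for layer in model.linear_layers:
    x = layer(x)
    print(x.shape)  # with the class above I see torch.Size([32, 256]) every time

x = model.final_layer(x)  # this call then raises a size-mismatch error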
However, with the following code, it works:
import torch
import torch.nn as nn
import torch.nn.functional as F

class decoder_with_fc_only(nn.Module):
    def __init__(self, latent_size, layer_sizes, non_linearity=nn.ReLU):
        super(decoder_with_fc_only, self).__init__()
        n_layers = len(layer_sizes)
        if n_layers < 2:
            raise ValueError('For an FC decoder with a single layer use simpler code.')
        size_list = [latent_size] + layer_sizes[0:-1]
        list_len = len(size_list)
        # Plain Linear layers; the ReLU is applied in forward
        self.linear_layers = nn.ModuleList([nn.Linear(size_list[i - 1], size_list[i])
                                            for i in range(1, list_len)])
        self.final_layer = nn.Linear(layer_sizes[-2], layer_sizes[-1])

    def forward(self, x):
        for i, l in enumerate(self.linear_layers):
            x = F.relu(l(x))
        x = self.final_layer(x)
        x = torch.reshape(x, (-1, 5000, 3))
        return x
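With this second version, a quick smoke test behaves the way I expect:

model = decoder_with_fc_only(256, [256, 256, 512, 1025, 15000])
x = torch.randn(32, 256)
out = model(x)
print(out.shape)  # torch.Size([32, 5000, 3]) after the final reshape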
I’m curious to see if anyone else has encountered this problem or can spot the bug. Thanks!