RuntimeError: size mismatch in ModuleList

Hi everyone,
I have a problem with a network and despite all my efforts, I can’t find a solution:

Here’s an excerpt from the code:
The network is constructed as follows:

from __future__ import print_function
import torch
import random
import math
import time
from torch import nn, optim


class MyModule(nn.Module):
    """Two-branch network producing 12 output values.

    ``x1`` is passed through twelve 14 -> 14 linear layers applied one after
    another and then flattened; ``x2`` goes through a single 5 -> 15 linear
    layer. Both results are concatenated along dimension 0 and fed to the
    fully connected head ``reseau``, whose first layer expects
    14 * 12 + 15 = 183 values.
    """

    def __init__(self):
        super(MyModule, self).__init__()
        # Twelve 14 -> 14 layers; forward() chains them on the same tensor.
        self.linears = nn.ModuleList([nn.Linear(14, 14) for _ in range(12)])

        # Branch for the extra input: 5 values in, 15 values out.
        self.extras = nn.Sequential(nn.Linear(5, 15))

        # Head operating on the flattened x1 branch joined with the extras.
        head_inputs = 14 * 12 + 15
        self.reseau = nn.Sequential(
            nn.Linear(head_inputs, 780),
            nn.PReLU(),
            nn.Linear(780, 520),
            nn.PReLU(),
            nn.Linear(520, 160),
            nn.Tanh(),
            nn.Linear(160, 12),
        )

    def forward(self, x1, x2):
        """Run both branches, join them on dimension 0 and apply the head."""
        hidden = x1
        for layer in self.linears:
            # Each layer consumes the previous layer's output.
            hidden = layer(hidden)

        flat = hidden.view(-1)
        extra = self.extras(x2)
        joined = torch.cat((flat, extra), dim=0)
        return self.reseau(joined)
    

# Build the network and move its parameters to the GPU (cuda() returns the module).
model = MyModule().cuda()
print(model)

# Sample inputs are created on the CPU and copied to the GPU for the call.
x1 = torch.rand(12, 14)
x2 = torch.rand(5)

output = model(x1.cuda(), x2.cuda())
print(output)

Which gives at runtime:

MyModule(
  (linears): ModuleList(
    (0): Linear(in_features=14, out_features=14, bias=True)
    (1): Linear(in_features=14, out_features=14, bias=True)
    (2): Linear(in_features=14, out_features=14, bias=True)
    (3): Linear(in_features=14, out_features=14, bias=True)
    (4): Linear(in_features=14, out_features=14, bias=True)
    (5): Linear(in_features=14, out_features=14, bias=True)
    (6): Linear(in_features=14, out_features=14, bias=True)
    (7): Linear(in_features=14, out_features=14, bias=True)
    (8): Linear(in_features=14, out_features=14, bias=True)
    (9): Linear(in_features=14, out_features=14, bias=True)
    (10): Linear(in_features=14, out_features=14, bias=True)
    (11): Linear(in_features=14, out_features=14, bias=True)
  )
  (extras): Sequential(
    (0): Linear(in_features=5, out_features=15, bias=True)
  )
  (reseau): Sequential(
    (0): Linear(in_features=183, out_features=780, bias=True)   # 183 = 12*14 + 15
    (1): PReLU()
    (2): Linear(in_features=780, out_features=520, bias=True)
    (3): PReLU()
    (4): Linear(in_features=520, out_features=160, bias=True)
    (5): Tanh()
    (6): Linear(in_features=160, out_features=12, bias=True)
  )
)
tensor([ 0.0053, -0.0556,  0.0377,  0.0156, -0.0305, -0.0137, -0.0470, -0.0658,
         0.0241, -0.0616,  0.0356,  0.0459], device='cuda:0',
       grad_fn=<AddBackward0>)

So far, so good. The network takes 14 values as input and has 14 values as output.
There is a second input (extras) which takes 5 values as input and gives (5*3) = 15 values as output.
The two inputs are joined together with x = torch.cat((x1, x2),dim=0)

But there is one problem:

In my model, I have to multiply by 3 the outputs of each linear: the code is :

class MyModule(nn.Module):
    """Variant of the network whose twelve input layers map 14 -> 42 values.

    ``reseau`` is sized for 14*12*3 + 15 = 519 inputs. ``forward`` still
    applies the twelve layers one after another to the same tensor.
    """

    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(14, 14*3) for i in range(12)]) # 12 inputs of 14 values in -> 42 out
        
        # Branch for the extra input: 5 values in, 15 values out.
        self.extras = nn.Sequential(
            nn.Linear(5, 15),
        ) 
        
        # Fully connected head producing the 12 output values.
        self.reseau = nn.Sequential(
            nn.Linear(14*12*3 + 15, 780),    # 519 = 12*42 + 15
            nn.PReLU(),
            nn.Linear(780, 520),
            nn.PReLU(),
            nn.Linear(520, 160),
            nn.Tanh(),
            nn.Linear(160, 12),            
        )

    def forward(self, x1,x2):
        # NOTE: the layers are chained, so each one receives the previous
        # layer's output. The first layer returns 42 features per row while
        # the second still expects 14.
        for i, l in enumerate(self.linears):
            x1 = self.linears[i](x1)
                    
        # Flatten the x1 branch and join it with the extras on dimension 0.
        x1 = x1.view(-1)      
        x2 = self.extras(x2)        
        x = torch.cat((x1, x2),dim=0)  
        x = self.reseau(x)     
        return x
    

# Build the network and move its parameters to the GPU.
model = MyModule()
model.cuda()

print(model)

# Sample inputs, created on the CPU.
x1 = torch.rand(12, 14)
x2 = torch.rand(5)  

# Copy the inputs to the GPU and run the forward pass.
output = model(x1.cuda(),x2.cuda())
print(output)
      
MyModule(
  (linears): ModuleList(
    (0): Linear(in_features=14, out_features=42, bias=True)
    (1): Linear(in_features=14, out_features=42, bias=True)
    (2): Linear(in_features=14, out_features=42, bias=True)
    (3): Linear(in_features=14, out_features=42, bias=True)
    (4): Linear(in_features=14, out_features=42, bias=True)
    (5): Linear(in_features=14, out_features=42, bias=True)
    (6): Linear(in_features=14, out_features=42, bias=True)
    (7): Linear(in_features=14, out_features=42, bias=True)
    (8): Linear(in_features=14, out_features=42, bias=True)
    (9): Linear(in_features=14, out_features=42, bias=True)
    (10): Linear(in_features=14, out_features=42, bias=True)
    (11): Linear(in_features=14, out_features=42, bias=True)
  )
  (extras): Sequential(
    (0): Linear(in_features=5, out_features=15, bias=True)
  )
  (reseau): Sequential(
    (0): Linear(in_features=519, out_features=780, bias=True)  # 519 = 12*42 + 15
    (1): PReLU()
    (2): Linear(in_features=780, out_features=520, bias=True)
    (3): PReLU()
    (4): Linear(in_features=520, out_features=160, bias=True)
    (5): Tanh()
    (6): Linear(in_features=160, out_features=12, bias=True)
  )
)

The pattern seems correct to me. Each linear layer of the input takes 14 values as input and gives (14*3) = 42 values as output.
But I have the following error at runtime:

RuntimeError                              Traceback (most recent call last)
<ipython-input-2-04b30ff8f25f> in <module>
     37 x2 = torch.rand(5)
     38 
---> 39 output = model(x1.cuda(),x2.cuda())
     40 

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

<ipython-input-2-04b30ff8f25f> in forward(self, x1, x2)
     20     def forward(self, x1,x2):
     21         for i, l in enumerate(self.linears):
---> 22             x1 = self.linears[i](x1)
     23 
     24         x1 = x1.view(-1)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1368     if input.dim() == 2 and bias is not None:
   1369         # fused op is marginally faster
-> 1370         ret = torch.addmm(bias, input, weight.t())
   1371     else:
   1372         output = input.matmul(weight.t())

RuntimeError: size mismatch, m1: [12 x 42], m2: [14 x 42] at /tmp/pip-req-build-4baxydiv/aten/src/THC/generic/THCTensorMathBlas.cu:290

Would this multiplication of the linear output in the ModuleList be forbidden? I don’t think so. I don’t know how to solve this problem.
I hope an experienced contributor can give me a lead or a solution.
Thank you for your help.

Yes, the modification in self.linears is creating the error.
The first layer will output an activation in the shape [batch_size, 14*3], while the next one expects an input of the shape [batch_size, 14], which fits the error message you’ve posted.

Since the layers are applied sequentially, you would have to make sure that the output size of one layer fits the expected input shape of the next one.

First of all, thank you for your prompt response!
I understand the problem better but I still don’t know how to implement my model.
My problem is that I have to implement in pyTorch a model that worked perfectly and gave excellent results with TorchNN and LUA language.
This is the TorchNN model:

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> output]
  (1): nn.ParallelTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> output]
      |      (1): nn.SplitTable
      |      (2): nn.ParallelTable {
      |        input
      |          |`-> (1): nn.Linear(14 -> 42)
      |          |`-> (2): nn.Linear(14 -> 42)
      |          |`-> (3): nn.Linear(14 -> 42)
      |          |`-> (4): nn.Linear(14 -> 42)
      |          |`-> (5): nn.Linear(14 -> 42)
      |          |`-> (6): nn.Linear(14 -> 42)
      |          |`-> (7): nn.Linear(14 -> 42)
      |          |`-> (8): nn.Linear(14 -> 42)
      |          |`-> (9): nn.Linear(14 -> 42)
      |          |`-> (10): nn.Linear(14 -> 42)
      |          |`-> (11): nn.Linear(14 -> 42)
      |           `-> (12): nn.Linear(14 -> 42)
      |           ... -> output
      |      }
      |      (3): nn.JoinTable
      |    }
       `-> (2): nn.Sequential {
             [input -> (1) -> output]
             (1): nn.Linear(5 -> 15)
           }
       ... -> output
  }
  (2): nn.JoinTable
  (3): nn.Linear(519 -> 780)
  (4): nn.PReLU
  (5): nn.Linear(780 -> 520)
  (6): nn.PReLU
  (7): nn.Linear(520 -> 160)
  (10): nn.Tanh
  (11): nn.Linear(160 -> 12)

For various reasons I had to switch to pyTorch, but despite my efforts, the performance of the network has completely collapsed, because of the problem I’m stumbling upon: how to make my pyTorch model conform to the TorchNN model.
You understand that I’m not a great pyTorch expert but I’m eager to learn, so I’d be very grateful if you could help me with some good advice, a contribution, or a piece of code that would get me out of this predicament.
Thank you in advance for your help.
Friendly
GP

nn.ParallelTable "applies the i-th member module to the i-th input and outputs a table",

so it seems these layers are not used sequentially but in parallel (which the name also suggests :wink: ).
Given that, I assume you have 12 different inputs, so that you could feed each one to a layer?
Your input shape could therefore be [batch_size, 12, nb_features].

If that’s the case, you could use a loop via:

# Fragment for the body of forward(): feed the idx-th slice of x1 to the
# idx-th layer. Assumes x1 has shape [batch_size, 12, nb_features].
outputs = []
for idx, module in enumerate(self.linears):
    data = x1[:, idx]
    out = module(data)
    outputs.append(out)
# torch.stack joins the per-layer results along a new leading dimension;
# pass dim=... (or use torch.cat) to choose a different layout.
outputs = torch.stack(outputs)

I’m not sure, in which dimension the JoinTable concatenates the output tensors, so you might want to use torch.stack(outputs, dim=?) (or torch.cat) with a specific dimension argument.

Could you post all shapes of all tensors from the Torch model, so that we could try to create the equivalent PyTorch model?

Thanks again for your response and your proposal to translate the Torch model into pyTorch. It gives me some hope… :grinning:

Here’s the complete definition of the TorchNN/LUA network:

require 'torch'
require 'nn'

-- Branch 1: split the first input into a table along dimension 1, apply one
-- Linear(14 -> 42) per table entry, then join the results along dimension 1.
local mlp1 = nn.Sequential()
	mlp1:add(nn.SplitTable(1))

-- ParallelTable applies its i-th member module to the i-th table entry.
local c = nn.ParallelTable()
-- input data	
	for i = 1,12 do 
		c:add(nn.Linear(14,42))
	end	
	
mlp1:add(c)
mlp1:add(nn.JoinTable(1))

-- input extra data
-- Branch 2: a single Linear(5 -> 15) applied to the second input.
local ex = nn.Sequential()	
	ex:add(nn.Linear(5,15))	

-- Run both branches in parallel on the input table {x1, x2}.
local rs1 = nn.ParallelTable()	
rs1:add(mlp1)
rs1:add(ex)

-- Full model: both branches, joined along dimension 1, then the dense head.
local mlp = nn.Sequential()	

mlp:add(rs1)
mlp:add(nn.JoinTable(1))

-- 519 = 12*42 + 15 values enter the head.
mlp:add(nn.Linear(519,780))
mlp:add(nn.PReLU())
mlp:add(nn.Linear(780,520))
mlp:add(nn.PReLU())
mlp:add(nn.Linear(520,160))
mlp:add(nn.Tanh())
-- output
mlp:add(nn.Linear(160,12))

print(mlp)

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> output]
  (1): nn.ParallelTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> output]
      |      (1): nn.SplitTable
      |      (2): nn.ParallelTable {
      |        input
      |          |`-> (1): nn.Linear(14 -> 42)
      |          |`-> (2): nn.Linear(14 -> 42)
      |          |`-> (3): nn.Linear(14 -> 42)
      |          |`-> (4): nn.Linear(14 -> 42)
      |          |`-> (5): nn.Linear(14 -> 42)
      |          |`-> (6): nn.Linear(14 -> 42)
      |          |`-> (7): nn.Linear(14 -> 42)
      |          |`-> (8): nn.Linear(14 -> 42)
      |          |`-> (9): nn.Linear(14 -> 42)
      |          |`-> (10): nn.Linear(14 -> 42)
      |          |`-> (11): nn.Linear(14 -> 42)
      |           `-> (12): nn.Linear(14 -> 42)
      |           ... -> output
      |      }
      |      (3): nn.JoinTable
      |    }
       `-> (2): nn.Sequential {
             [input -> (1) -> output]
             (1): nn.Linear(5 -> 15)
           }
       ... -> output
  }
  (2): nn.JoinTable
  (3): nn.Linear(519 -> 780)
  (4): nn.PReLU
  (5): nn.Linear(780 -> 520)
  (6): nn.PReLU
  (7): nn.Linear(520 -> 160)
  (8): nn.Tanh
  (9): nn.Linear(160 -> 12)
}


-- Sample inputs: a 12 x 14 tensor for the first branch and 5 extra values.
x1 = torch.rand(12, 14)
x2 = torch.rand(5)  
  	
			      					      	
-- The model takes both inputs together as a table.
local output = mlp:forward({x1,x2})     

print(output)

 0.0489
-0.0246
-0.0636
-0.0101
 0.0670
 0.0883
 0.1112
 0.0109
-0.1157
 0.0715
 0.0773
-0.0449
[torch.DoubleTensor of size 12]

In the Torch model we can easily make an nn.ParallelTable() of 12 nn.Linear(14,42) then add an input of the extra data: nn.Linear(5,15), do add() and nn.JoinTable().
I don’t know how to do the pyTorch equivalent. I used the nn.ModuleList() and the torch.cat(), but probably in the wrong way as you explained above, which causes very bad results.
I hope I have given you enough details to understand my problem.
Thank you for taking the time to answer me.
Sincerely
GP

Thanks for the complete code.
I’ve added some minor changes to your code.
Could you compare the outputs of your Lua model and this one:

class MyModel(nn.Module):
    """PyTorch counterpart of the Lua model built from ParallelTable/JoinTable.

    Each of the twelve ``Linear(14, 42)`` layers processes its own slice of
    ``x1`` (the layers run side by side, not chained). Their outputs are
    concatenated with the ``Linear(5, 15)`` output of the extras branch and
    passed to the fully connected head ``reseau``.

    Inputs to ``forward``:
        x1: tensor of shape [batch_size, 12, 14].
        x2: tensor of shape [batch_size, 5].

    Returns:
        Tensor of shape [batch_size, 12].
    """

    def __init__(self):
        super(MyModel, self).__init__()
        # One independent 14 -> 42 layer per input slice.
        self.linears = nn.ModuleList([nn.Linear(14, 42) for _ in range(12)])

        # Branch for the extra features: 5 values in, 15 values out.
        self.extras = nn.Sequential(
            nn.Linear(5, 15),
        )

        self.reseau = nn.Sequential(
            nn.Linear(14*12*3 + 15, 780),    # 519 = 12*42 + 15
            nn.PReLU(),
            nn.Linear(780, 520),
            nn.PReLU(),
            nn.Linear(520, 160),
            nn.Tanh(),
            nn.Linear(160, 12),
        )

    def forward(self, x1, x2):
        # Apply the i-th layer to the i-th slice of x1 only.
        outs = [layer(x1[:, i]) for i, layer in enumerate(self.linears)]
        outs = torch.cat(outs, dim=1)  # [batch_size, 504]

        x2 = self.extras(x2)
        x = torch.cat((outs, x2), dim=1)  # [batch_size, 519]
        x = self.reseau(x)
        return x
    
model = MyModel()

# Random batch: 12 slices of 14 features, plus 5 extra features per sample.
batch_size = 5
x1, x2 = torch.randn(batch_size, 12, 14), torch.randn(batch_size, 5)

out = model(x1, x2)

The code is bright and very elegant!
I am very happy with this exchange which taught me a lot. I will be able to continue my research.
Thanks again for this help :grinning:
Friendly
GP