How to resolve - RuntimeError: size mismatch, m1: [100 x 228], m2: [152 x 36]

Hi!

I am trying to build a convolutional neural network and am facing a problem. The network has 3 conv layers, each batch normalized, ReLU activated, and max pooled, and the result is then flattened. Here’s the model code:

import torch
import torch.nn as nn
import torch.nn.functional as F

in_features = 152 # in_features for the flatten (linear) layer, calculated by hand

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 2, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(6, 8, 2, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=8)
        self.conv3 = nn.Conv2d(8, 12, 2, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=12)

        self.fc1   = nn.Linear(in_features, 36)
        self.fc2   = nn.Linear(36, 3)
        
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, 2)
        
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.max_pool2d(out, 2)
        
        out = F.relu(self.bn3(self.conv3(out)))
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.softmax(self.fc2(out))
        return out

pytorch_model = Net()

I am trying to do multiclass classification, so I am using the Adam optimizer and the CrossEntropyLoss criterion.

My input is of size (143256, 1, 150, 3) and I am trying to train using the following code:

batch_size = 100
epochs = 10

pytorch_model = pytorch_model.double()

for epoch in range(epochs):
    for i in tqdm(range(0, len(X_TRAIN), batch_size)):
        batch_x = X_TRAIN[i:i+batch_size]
        batch_y = Y_TRAIN[i:i+batch_size]
        
        optimizer.zero_grad() # clears all the gradients
        
        outputs = pytorch_model(batch_x.double())
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
    print(loss)

And when trying to train the model I get this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-189-e1854e4cd77a> in <module>
---> 15         outputs = pytorch_model(batch_x.double())
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
--> 541             result = self.forward(*input, **kwargs)
<ipython-input-184-d3ab227f7914> in forward(self, x)
---> 27         out = F.relu(self.fc1(out))
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
--> 541             result = self.forward(*input, **kwargs)
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
---> 87         return F.linear(input, self.weight, self.bias)
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1368     if input.dim() == 2 and bias is not None:
   1369         # fused op is marginally faster
-> 1370         ret = torch.addmm(bias, input, weight.t())
   1371     else:
   1372         output = input.matmul(weight.t())

RuntimeError: size mismatch, m1: [100 x 228], m2: [152 x 36] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:197

I am fairly new to using PyTorch, so I know there’s something I am doing wrong here. Can anyone please explain what it is? Any help would be really appreciated…
Thanks!

It seems you’ve defined in_features as 152, which does not match the flattened shape of the input tensor to self.fc1.
If you add a print statement right before the self.fc1 call, you can check the activation shape, which will be [batch_size, 228].
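
For example, a minimal debugging print just before the linear layer in your forward method:

        out = out.view(out.size(0), -1)
        print(out.shape)  # prints torch.Size([100, 228]) for your (100, 1, 150, 3) batches
        out = F.relu(self.fc1(out))
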
If you create the first linear layer as:

self.fc1 = nn.Linear(228, 36)

your code should work.

Also, nn.CrossEntropyLoss expects raw logits, so just remove the softmax at the end of your model and simply return the output of self.fc2.


Following your code:

layer           output_size
input           (N, 1, 150, 3)
conv1_bn_relu   (N, 6, 151, 4)
max_pool2d      (N, 6, 75, 2)
conv2_bn_relu   (N, 8, 76, 3)
max_pool2d      (N, 8, 38, 1)
conv3_bn_relu   (N, 12, 39, 2)
max_pool2d      (N, 12, 19, 1)
view            (N, 228)

So, in_features in self.fc1 must be 228 instead of 152.
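
If you want to double-check that number in code, a quick sketch (assuming your Net definition from above) is to push a dummy tensor with the input shape through the conv/pool stack and read off the flattened size:

import torch

dummy = torch.zeros(1, 1, 150, 3)  # one sample with the input shape

net = Net()
with torch.no_grad():
    out = F.max_pool2d(F.relu(net.bn1(net.conv1(dummy))), 2)
    out = F.max_pool2d(F.relu(net.bn2(net.conv2(out))), 2)
    out = F.max_pool2d(F.relu(net.bn3(net.conv3(out))), 2)

print(out.view(1, -1).size(1))  # prints 228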


Well, that really helped, and thanks for this…
BUT, I guess there’s something else too that I’ve probably done wrong…

Now, it seems like it’s getting the outputs fine but has a problem computing the loss, and it is expecting scalar type Long. Here’s the traceback for the RuntimeError:

RuntimeError                              Traceback (most recent call last)
<ipython-input-213-e1854e4cd77a> in <module>
 ---> 16         loss = criterion(outputs, batch_y)
 ~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
 --> 541             result = self.forward(*input, **kwargs)
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
--> 916                                ignore_index=self.ignore_index, reduction=self.reduction)
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
-> 2009     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
~/anaconda3/envs/LearningML/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1836                          .format(input.size(0), target.size(0)))
   1837     if dim == 2:
-> 1838         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1839     elif dim == 4:
   1840         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Double for argument #2 'target' in call to _thnn_nll_loss_forward

That is really cool!
How can I print out all the layers like this?

I just type them out manually so that you can read them more easily…
But if you are interested in how to print out the output shape of each layer automatically, here’s a simple tutorial about forward hooks.
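
For completeness, a minimal sketch of that idea (assuming the fixed Net with in_features = 228): register a forward hook on every submodule and print its output shape:

def print_shape(module, inputs, output):
    print(module.__class__.__name__, tuple(output.shape))

model = Net()
handles = [m.register_forward_hook(print_shape) for m in model.modules() if m is not model]

model(torch.zeros(1, 1, 150, 3))  # prints the output shape of every conv/bn/linear module
for h in handles:
    h.remove()  # remove the hooks when done

# note: F.relu and F.max_pool2d are functional calls, not modules, so they are not hooked here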


Oh right!
Thank you for the tutorial on forward hooks as well… Really appreciated…! :smile:

The error I replied with above:

RuntimeError: Expected object of scalar type Long but got scalar type Double for argument #2 'target' in call to _thnn_nll_loss_forward

was resolved by converting both of the arguments that the loss criterion needs, outputs and batch_y, to long tensors, i.e. by changing the code from:

 outputs = pytorch_model(batch_x.double())
 loss = criterion(outputs, batch_y)

to

 outputs = pytorch_model(batch_x.double()).long()
 loss = criterion(outputs, batch_y.long())

And now, it gives this error:

RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'

How can I resolve this?

Here’s the documentation about CrossEntropyLoss.

outputs: shape (N, 3), type Float or Double or Half, without softmax
batch_y: shape (N,), type Long, the labels.

CrossEntropyLoss is equivalent to LogSoftmax + NLLLoss
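
For your case it would look roughly like this (a sketch, not your exact code):

outputs = pytorch_model(batch_x.double())  # raw logits, shape (N, 3), floating point
targets = batch_y.long()                   # class indices in {0, 1, 2}, shape (N,)

loss = criterion(outputs, targets)         # do not cast outputs to long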


Thank you so much @Eta_C!!
It solved the problem… The model is now being trained…!

Since my problem is resolved, all thanks to PyTorch’s amazing discussion community, here’s my final working model class and training code for anyone who happens to read this post in the future with the same problem:

in_features = 228 #in_features needed for Flattening(linear/fully connected) layer

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 2, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.conv2 = nn.Conv2d(6, 8, 2, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=8)
        self.conv3 = nn.Conv2d(8, 12, 2, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=12)

        self.fc1   = nn.Linear(in_features, 36)
        self.fc2   = nn.Linear(36, 3)
        
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.max_pool2d(out, 2)
        
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.max_pool2d(out, 2)
        
        out = F.relu(self.bn3(self.conv3(out)))
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        print(out.size())  # helps show the flattened size fc1 needs
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        return out

pytorch_model = Net()

The in_features can be calculated by following the code (in my case: input size > conv1 (bn1, relu) > max_pool2d > conv2 (bn2, relu) etc.) and using the following formula to find the output size of each layer, and finally the input size the nn.Linear layer needs (which in my case came to 228 = 12 * 19 * 1):

O = ((W − K + 2P) / S) + 1
where
W = input height/width
K = filter (kernel) size
P = padding
S = stride
O = output size
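
A tiny helper applying this formula (with a floor division by 2 for each max_pool2d) reproduces the height above; this is just a sketch of the arithmetic:

def conv_out(w, k=2, p=1, s=1):
    # O = ((W - K + 2P) / S) + 1
    return (w - k + 2 * p) // s + 1

h = 150
for _ in range(3):  # conv1/conv2/conv3, each followed by max_pool2d(2)
    h = conv_out(h) // 2
print(h)  # 19; the width works out to 1, so 12 * 19 * 1 = 228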

The criterion and optimizer:

weights = torch.tensor([0.05, 0.15, 0.8]).double() # assigning weights to classes
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = optim.Adam(pytorch_model.parameters())

And finally to train the model:

batch_size = 100
epochs = 10

# Uncomment the line below if your data sets are of type double,
# or conversely convert the data sets to float and leave it commented.

# pytorch_model = pytorch_model.double()

for epoch in tqdm(range(epochs)):
    for i in range(0, len(X_TRAIN), batch_size):
        batch_x = X_TRAIN[i:i+batch_size]
        batch_y = Y_TRAIN[i:i+batch_size]
        
        optimizer.zero_grad() # clears all the gradients
        
        outputs = pytorch_model(batch_x.double())
        loss = criterion(outputs, batch_y.long())
        loss.backward()
        optimizer.step()
    print(loss)

Making sure to follow the suggestions from @ptrblck and @Eta_C, I managed to run the training successfully.
Happy Learning!

Emmmm, in most cases, float is enough…
pytorch_model.float() and batch_x.float()
Double uses more memory and makes training slower…
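
So the relevant lines would become something like this (a sketch based on the code above):

pytorch_model = Net()                      # keep the default float32 parameters
weights = torch.tensor([0.05, 0.15, 0.8])  # float by default, so it matches the model
criterion = nn.CrossEntropyLoss(weight=weights)

outputs = pytorch_model(batch_x.float())   # cast the input batches to float instead of double
loss = criterion(outputs, batch_y.long())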


That was exactly what was happening!
It was using a lot of memory and training a lot slower than usual; even 10 epochs were taking about 50 minutes. Thanks for pointing this out…!