Why is the output of a linear layer different when the batch size is 1?


I am trying to debug a network and noticed that for some reason, the outputs of a linear layer are slightly different depending on the batch size of the input tensor.

Minimal working example:

import torch
from torch import nn

ll = nn.Linear(4, 8)
data = torch.rand((16, 4))


This should give an output of


I know the difference is really small numerically, but it is strange to me that when the batch size is 1 (in the last line, the size of the input is [1, 4] whereas in the top line it is [16, 4]), the representation seems to be different. Why is this happening? Is it possible that this could actually affect the model performance? Specifically, I’m seeing weird discrepancies in my model during training (when the batch size is 16) vs. evaluation (when the batch size is 1), so I am worried that this is causing it.



From the point of view of floating point arithmetic, these two numbers are actually the same.
We are doing some extra optimizations when the batch size is 1, which leads to a different order for some accumulations. And since floating point addition is not associative, you see these kinds of artifacts.
Hope this helps.
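
To make the non-associativity concrete, here is a minimal sketch in plain Python (float64, no PyTorch involved). The values are chosen purely for illustration; the point is only that the same three numbers summed in a different grouping give a different result, which is the effect a different accumulation order has inside a linear layer.

```python
# Three float64 values chosen so that the grouping changes the rounding.
a = 1e16
b = -1e16
c = 1.0

# a and b cancel exactly first, then c is added.
left = (a + b) + c

# c is absorbed when added to b: near 1e16 the spacing between
# adjacent float64 values is 2, so b + c rounds back to b.
right = a + (b + c)

print(left, right)    # → 1.0 0.0
print(left == right)  # → False
```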


Thanks for the quick response! I suspected this was the case. The problem I am getting in my model must be coming from somewhere else then.


I also had this issue recently. In my code, the input and output feature sizes are much larger, and the results differ as well. I print out the first 5 elements of each output.

import numpy as np
import torch
from torch import nn


net = nn.Linear(768, 768)
data = torch.rand((128, 768))



The output is:


[-0.23123368620872498, 0.982639491558075, -0.5261492133140564, 0.32300880551338196, 0.12303455173969269]
[-0.23123352229595184, 0.9826395511627197, -0.5261490345001221, 0.32300886511802673, 0.1230345219373703]
[-0.23123352229595184, 0.9826395511627197, -0.5261490345001221, 0.32300886511802673, 0.1230345219373703]
[-0.23123367130756378, 0.9826394319534302, -0.5261489748954773, 0.32300883531570435, 0.12303447723388672]
[-0.23123368620872498, 0.9826392531394958, -0.5261490941047668, 0.32300883531570435, 0.1230345070362091]

Here the output is different for the same example with different batch sizes. If we look at the individual elements of the output, the values differ starting from about the 6th decimal digit. In my application, some differences start at the 5th decimal digit. How can this difference be explained?


These errors are caused by the limited precision of floating point numbers.
If you need more precision, you could use float64, at the cost of a performance penalty.
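
As a rough sketch of the float64 suggestion, the snippet below measures the largest difference between the first row of a batched forward pass and the same row passed alone, once in float32 and once in float64. The helper name `max_batch_discrepancy` is made up for this example, and the exact numbers depend on hardware and backend, so none are shown.

```python
import torch
from torch import nn

torch.manual_seed(0)

def max_batch_discrepancy(dtype):
    """Largest abs difference between row 0 computed in a batch and alone."""
    net = nn.Linear(768, 768).to(dtype)
    data = torch.rand(128, 768, dtype=dtype)
    with torch.no_grad():
        batched = net(data)[:1]   # first row, computed as part of the batch
        single = net(data[:1])    # same row, computed with batch size 1
    return (batched - single).abs().max().item()

err32 = max_batch_discrepancy(torch.float32)
err64 = max_batch_discrepancy(torch.float64)
print("float32 max abs diff:", err32)
print("float64 max abs diff:", err64)
```

On a typical setup the float64 difference should be many orders of magnitude smaller, but it is not guaranteed to be exactly zero.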

What is the preferred way of testing whether the different outputs are “close enough” in this case (i.e. the model does some optimizations internally, but the computational graph is actually the same)? Do I test it via torch.allclose() with default tolerance parameters?

There is no magic way to do this I’m afraid.

allclose works well for an op that is element-wise or that has very small inputs.
But operations that contain reductions (sum, mm, conv, etc.) tend to accumulate more error the larger the reduced dimension is.
So you can actually end up with very significant differences, I would even say arbitrarily large differences if you can make the inputs as big as you want.
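
One possible way to check "close enough" in practice is sketched below: compare with `torch.allclose` using explicitly chosen tolerances, and use a float64 computation with the same weights as a reference. The tolerance values `rtol=1e-4, atol=1e-5` are an assumption picked for this layer size, not a recommended default; larger reduced dimensions may need looser values.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(768, 768)
data = torch.rand(128, 768)

with torch.no_grad():
    batched = net(data)[:1]   # row 0 computed inside the full batch
    single = net(data[:1])    # row 0 computed with batch size 1
    # float64 reference using the same weights and bias
    ref = nn.functional.linear(
        data[:1].double(), net.weight.double(), net.bias.double()
    )

print("max abs diff (batched vs. single):", (batched - single).abs().max().item())
print("max abs diff (single vs. float64 ref):", (single.double() - ref).abs().max().item())

# Default tolerances of torch.allclose are rtol=1e-05, atol=1e-08.
print("allclose, default tolerances:", torch.allclose(batched, single))
# Looser tolerances, chosen by hand for this problem size (an assumption).
print("allclose, rtol=1e-4, atol=1e-5:", torch.allclose(batched, single, rtol=1e-4, atol=1e-5))
```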


Thank you!
Just to make it clear: there is no option to disable such optimization related to singleton batch dimension, correct?

there is no option to disable such optimization related to singleton batch dimension, correct?

This is not an optimization, it is a limit of the floating point standard. Since you have limited precision, you have errors in your computations, and the more computations you perform, the bigger the error gets.

This is the “optimization” I was referring to. And it seems it cannot be switched off?

Oh, right. The use of multiple cores on CPU can be switched off with torch.set_num_threads(1), but that will significantly slow down your program.
And even that is not guaranteed to solve all the issues, as the batched version is never implemented as just a for loop over the batch. That would be very, very slow.
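
A small sketch of that experiment, assuming a CPU-only run: limit PyTorch to one intra-op thread, then compare the batched output against an explicit Python loop over the rows. Whether the difference comes out as exactly zero depends on the backend kernels, so no expected value is shown.

```python
import torch
from torch import nn

# Limit CPU intra-op parallelism to a single thread (slower, but removes
# multi-threaded accumulation as a source of differences).
torch.set_num_threads(1)

torch.manual_seed(0)
net = nn.Linear(768, 768)
data = torch.rand(128, 768)

with torch.no_grad():
    batched = net(data)
    # Explicit loop over the batch, one row at a time.
    looped = torch.cat([net(row.unsqueeze(0)) for row in data], dim=0)

max_diff = (batched - looped).abs().max().item()
print("threads:", torch.get_num_threads())
print("max abs diff, batched vs. looped:", max_diff)
```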
