I am trying to debug a network and noticed that for some reason, the outputs of a linear layer are slightly different depending on the batch size of the input tensor.
Minimal working example:
import torch
from torch import nn
torch.manual_seed(72)
ll = nn.Linear(4, 8)
data = torch.rand((16, 4))
print(ll(data)[0].mean().item())
print(ll(data[:8])[0].mean().item())
print(ll(data[:4])[0].mean().item())
print(ll(data[:2])[0].mean().item())
print(ll(data[:1])[0].mean().item())
I know the difference is really small numerically, but it is strange to me that when the batch size is 1 (in the last line the input size is [1, 4], whereas in the top line it is [16, 4]), the representation of the same example seems to be different. Why is this happening? Could this actually affect model performance? Specifically, I'm seeing weird discrepancies in my model between training (batch size 16) and evaluation (batch size 1), so I am worried that this is the cause.
From the point of view of floating point arithmetic, these two numbers are actually the same.
We do some extra optimizations when the batch size is 1, which leads to a different order for some accumulations. And since floating point addition is not associative, you see these kinds of artifacts.
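As a self-contained illustration of that non-associativity (just a sketch of the rounding effect, not what the linear kernel literally does internally):
import torch

torch.manual_seed(72)
x = torch.rand(768)

# Strict left-to-right accumulation in float32.
loop_sum = torch.tensor(0.0)
for v in x:
    loop_sum = loop_sum + v

# torch.sum uses a blocked/vectorized accumulation order instead.
vec_sum = x.sum()

print(loop_sum.item())
print(vec_sum.item())
print((loop_sum - vec_sum).item())  # usually a tiny non-zero value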
Hope this helps.
I also ran into this issue recently. In my code, the input and output feature sizes are much larger, and the results differ more. I also print out the first 5 elements of each output.
import torch
from torch import nn
torch.manual_seed(72)
net = nn.Linear(768, 768)
data = torch.rand((128, 768))
print(net(data)[0].mean().item())
print(net(data[:8])[0].mean().item())
print(net(data[:4])[0].mean().item())
print(net(data[:2])[0].mean().item())
print(net(data[:1])[0].mean().item())
print(net(data)[0][:5].tolist())
print(net(data[:8])[0][:5].tolist())
print(net(data[:4])[0][:5].tolist())
print(net(data[:2])[0][:5].tolist())
print(net(data[:1])[0][:5].tolist())
Here the output for the same example differs across batch sizes. And if we look at the individual elements of the output, the values may differ starting from the 6th decimal digit. In my application, some differences start from the 5th decimal digit. How can such a difference be explained?
These errors are caused by the limited precision of floating point numbers.
If you need more precision, you could use float64, at a performance penalty.
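For example, something along these lines (a minimal sketch, reusing the setup from the first post but in double precision):
import torch
from torch import nn

torch.manual_seed(72)
ll = nn.Linear(4, 8).double()                 # weights and bias in float64
data = torch.rand((16, 4), dtype=torch.float64)

# The batch-size dependence is still there in principle, but the
# discrepancy now sits around the 15th significant digit rather than the 7th.
print(ll(data)[0].mean().item())
print(ll(data[:1])[0].mean().item())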
What is the preferred way of testing whether the different outputs are "close enough" in this case (i.e. the model does some optimizations internally, but the computational graph is actually the same)? Should I test it via torch.allclose() with the default tolerance parameters?
allclose will work for an op that is element-wise or for very small inputs.
But operations that contain reductions (sum, mm, conv, etc.) tend to amplify the error: the larger the reduced dimension, the larger the error.
So you can actually end up with very significant differences, arbitrarily large ones even, if you can make the inputs as big as you want.
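To get a feel for how the tolerance should scale with the reduced dimension, here is a sketch comparing a float32 matmul against a float64 reference (the exact numbers will vary):
import torch

torch.manual_seed(72)

for k in (8, 128, 2048, 32768):
    a = torch.rand(16, k)
    b = torch.rand(k, 16)
    ref = (a.double() @ b.double()).float()   # float64 reference, cast back
    out = a @ b                               # plain float32 matmul
    print(k, (out - ref).abs().max().item())  # error tends to grow with k
With torch.allclose's default rtol=1e-05 and atol=1e-08 the smaller cases should pass, but the larger reductions may not, so in practice the tolerances often need to be loosened to match the size of the reduction.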
There is no option to disable such an optimization related to the singleton batch dimension, correct?
This is not an optimization you can disable, it is a limitation of the floating point standard. Since you have limited precision, you accumulate errors in your computations, and the more computations you perform, the bigger the error gets.
Oh right. The use of multiple cores on CPU can be switched off with torch.set_num_threads(1), but that will significantly slow down your program.
And even that is not guaranteed to remove all the differences, as we never implement the batched version as just a for loop over the batch. That would be very, very slow.
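For completeness, restricting PyTorch to a single CPU thread would look like this (a sketch; as said above, it removes only one source of reordering and slows things down considerably):
import torch
from torch import nn

torch.set_num_threads(1)   # single-threaded intra-op execution on CPU

torch.manual_seed(72)
ll = nn.Linear(768, 768)
data = torch.rand((128, 768))

# Outputs for different batch sizes can still differ slightly, since the
# batched kernels are not implemented as a plain loop over the batch.
print(ll(data)[0].mean().item())
print(ll(data[:1])[0].mean().item())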