Numerical error between batch and single instance computation

I found that there is a small numerical error between the results obtained from batch-mode computation and those obtained by passing parts of the batch through the layer separately. For example:

import torch

batch_size = 32
input_size = 128
output_size = 256
linear = torch.nn.Linear(input_size,output_size)
x = torch.rand(batch_size, input_size)
output = linear(x)
output1 = linear(x[:1])
output2 = linear(x[:(batch_size//2)])

print(10e6*torch.max(torch.abs(output[:1] - output1)))
print(10e6*torch.max(torch.abs(output[:(batch_size//2)] - output2)))

The above code may print out something like:

tensor(1.4901, grad_fn=<MulBackward>) # a non-zero value, i.e., there is some numerical error
tensor(0., grad_fn=<MulBackward>) # a zero value, i.e., there is no error

Where does the above error come from? Any help is highly appreciated.

Hi,

The precision of a 32-bit floating point number is around 1e-6, so anything smaller than that is going to be noise.
Also, floating point operations are not really commutative or associative, and doing them in a different order will create very small errors.
So this is expected behavior, I’m afraid. You can use double-precision numbers if you require more precision.
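
A minimal sketch of both points (assuming a recent PyTorch with torch.finfo; the values shown are approximate):

import torch

# Machine epsilon gives the relative round-off to expect from each dtype.
print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.float64).eps)   # ~2.22e-16

# A ~1e-7 difference between two float32 results is therefore round-off noise.
# If you need tighter agreement, cast the layer and the input to double:
linear = torch.nn.Linear(128, 256).double()
x = torch.rand(32, 128, dtype=torch.float64)
print(torch.max(torch.abs(linear(x)[:1] - linear(x[:1]))))  # expect 0, or something vanishingly small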

Thanks for the great answer.

Hi Tuan Anh (and Alban)!

I do find this a bit odd.

It is true, as Alban points out, you can’t really complain and call
this an error. The results of the two versions of the computation
do agree to within 32-bit floating-point round-off error.

But floating-point computations don’t just give you random
differences (and they’re not allowed to). Again, as Alban notes,
doing a floating-point computation in a different (but mathematically
equivalent) order can lead to a round-off error difference in the result.

(A minor quibble: Floating-point operations are not associative,
but are commutative – at least in any sensible world, e.g., IEEE.)
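
To make the non-associativity point concrete, here is a tiny sketch (the values are chosen purely for illustration):

import torch

a = torch.tensor(1.0e8, dtype=torch.float32)
b = torch.tensor(-1.0e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

print((a + b) + c)   # tensor(1.)
print(a + (b + c))   # tensor(0.)  (b + c rounds back to -1e8, so the 1.0 is lost)

Reordering a mathematically equivalent sum changes the rounded result.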

My problem is that I can’t cook up any good reason why the two versions
of the computation should be done in different orders.

For a little more fun, I tweaked your sample code to run on my decrepit
pytorch 0.3.0 installation, and added a few more tests to it.

The two highlights:

When done with 64-bit doubles, no difference arises.

When done with 32-bit floats, a difference only arises when batches
of size 1, 2, and 3 are passed into linear.

(When run on the gpu instead of the cpu I get very similar results,
except that the non-zero difference only shows up for a batch of
size 1.)

My complete script and output for these tests appear below.

Now it is true that our complaint will not stand up in a court of
floating-point law, but I am curious where this behavior comes
from and if there is really any good reason for it. Perhaps some
experts could chime in with more insight.

There must be a sensible reason, no?

Surely pytorch is not – dare I say it? – random …

Best.

K. Frank

Script:

import torch
print (torch.__version__)

torch.manual_seed (2019)

gpu = False
# gpu = True

print ('gpu =', gpu)

batch_size = 32
input_size = 128
output_size = 256
linear = torch.nn.Linear(input_size,output_size)
x = torch.autograd.Variable (torch.rand(batch_size, input_size))
if gpu:
    linear.cuda()
    x = x.cuda()
output = linear(x)
output1 = linear(x[:1])
output2 = linear(x[:(batch_size//2)])

print(10e6*torch.max(torch.abs(output[:1] - output1)))
print(10e6*torch.max(torch.abs(output[:(batch_size//2)] - output2)))

# compare the first-row output across several small batch sizes
for  n in range (1, 5):
    outputn = linear (x[:n])
    print ('n =', n, ', diff =',10e6*torch.max(torch.abs(output[:1] - outputn[:1])))
    print ('n =', n, ', diff =',10e6*torch.max(torch.abs(output1 - outputn[:1])))

# count how many batch sizes give a non-zero first-row difference (float32)
diffcount = 0
for  n in range (1, batch_size):
    outputn = linear (x[:n])
    maxdiff = torch.max(torch.abs(output[:1] - outputn[:1])).data[0]
    if  maxdiff != 0.0:
        print  ('n =', n, ', maxdiff =', maxdiff)
        diffcount += 1
print ('diffcount =', diffcount)

# repeat the comparison in double precision
dlinear = torch.nn.Linear(input_size,output_size)
dlinear.weight = torch.nn.parameter.Parameter (linear.weight.data.double())
dlinear.bias = torch.nn.parameter.Parameter (linear.bias.data.double())
dx = x.double()
doutput = dlinear(dx)
doutput1 = dlinear(dx[:1])
doutput2 = dlinear(dx[:(batch_size//2)])

print(10e6*torch.max(torch.abs(doutput[:1] - doutput1)))
print(10e6*torch.max(torch.abs(doutput[:(batch_size//2)] - doutput2)))

ddiffcount = 0
for  n in range (1, batch_size):
    doutputn = dlinear (dx[:n])
    maxdiff = torch.max(torch.abs(doutput[:1] - doutputn[:1])).data[0]
    if  maxdiff != 0.0:
        print  ('n =', n, ', maxdiff =', maxdiff)
        ddiffcount += 1
print ('ddiffcount =', ddiffcount)

Output (set to run on cpu):

0.3.0b0+591e73e
gpu = False
Variable containing:
 2.9802
[torch.FloatTensor of size 1]

Variable containing:
 0
[torch.FloatTensor of size 1]

n = 1 , diff = Variable containing:
 2.9802
[torch.FloatTensor of size 1]

n = 1 , diff = Variable containing:
 0
[torch.FloatTensor of size 1]

n = 2 , diff = Variable containing:
 2.6822
[torch.FloatTensor of size 1]

n = 2 , diff = Variable containing:
 2.3842
[torch.FloatTensor of size 1]

n = 3 , diff = Variable containing:
 2.3842
[torch.FloatTensor of size 1]

n = 3 , diff = Variable containing:
 2.0117
[torch.FloatTensor of size 1]

n = 4 , diff = Variable containing:
 0
[torch.FloatTensor of size 1]

n = 4 , diff = Variable containing:
 2.9802
[torch.FloatTensor of size 1]

n = 1 , maxdiff = 2.980232238769531e-07
n = 2 , maxdiff = 2.682209014892578e-07
n = 3 , maxdiff = 2.384185791015625e-07
diffcount = 3
Variable containing:
 0
[torch.DoubleTensor of size 1]

Variable containing:
 0
[torch.DoubleTensor of size 1]

ddiffcount = 0

My explanation, based on the original code, is that depending on the amount of computation to be done, different algorithms can be used.
If we focus only on the cpu, there are flags like OMP_THRESHOLD that choose between a single-threaded implementation and a multithreaded (OpenMP) one.
I think this is why, in the original code, you see a difference between the full batch and the single-sample run (which would use the single-threaded algorithm), but not between the full batch and the half batch (both of which would use the multithreaded algorithm).

Of course these different algorithms will give rise to different rounding.
A similar argument can be made on GPU where the grid/block sizes are decided based on the input size.
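
One way to poke at this hypothesis on the cpu is to pin the computation to a single thread and re-run the comparison. A sketch (whether anything changes depends on the thresholds compiled into your build):

import torch

# Force single-threaded CPU execution and repeat the original comparison.
# If the difference comes from switching between single- and multi-threaded
# code paths, it may disappear here.
torch.manual_seed(2019)
torch.set_num_threads(1)

linear = torch.nn.Linear(128, 256)
x = torch.rand(32, 128)
print(torch.max(torch.abs(linear(x)[:1] - linear(x[:1]))))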

Is that a more satisfying explanation? :slight_smile:

Hello Alban!

Yes, I’ll buy that. Certainly bumping over to an OpenMP algorithm
(or some other size-dependent change of algorithm) would be
expected to change the details of the floating-point result.

(One specific detail doesn’t seem to add up: your quoted
#define OMP_THRESHOLD 100000 seems to be too large to trigger
OpenMP for Tuan Anh’s example, at least for my naive estimates of
what totalElements might be; the output here has only 32 * 256 = 8192
entries, and even the weight matrix has only 128 * 256 = 32768.)

Thanks.

K. Frank

Yes, this threshold is most likely not even hit in this example. It was just an example of a place where the underlying implementation depends on the input size :slight_smile:
Also, the underlying libraries we use, like MKL and OpenBLAS, have their own thresholds, and the CUDA libraries have their own set of thresholds (you can actually play with cuDNN algorithm selection for convolutions by setting torch.backends.cudnn.benchmark and torch.backends.cudnn.deterministic).
I’m afraid it’s beyond my knowledge which thresholds are hit in this particular case.
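
For reference, setting those two cuDNN knobs looks like this (they affect convolution algorithm selection on the GPU rather than Linear, so this is only to show where they live):

import torch

torch.backends.cudnn.benchmark = False      # don't auto-tune the algorithm per input size
torch.backends.cudnn.deterministic = True   # prefer deterministic algorithms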

If you really want to know, you can trace down the call stack to see exactly which function is used and what the conditions for choosing it are. I would be interested to know the answer if you do that! :slight_smile:
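
A rough starting point for that digging (just a sketch: it shows the high-level op that Linear dispatches to, not the BLAS-level branch, which would need a native debugger or a read through the ATen source) is the autograd profiler:

import torch

linear = torch.nn.Linear(128, 256)
x = torch.rand(1, 128)

# Profile a single forward pass to see which op is dispatched
# (for a 2-D input this should show addmm).
with torch.autograd.profiler.profile() as prof:
    linear(x)
print(prof.key_averages().table(sort_by="cpu_time_total"))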