Matmul instability

Hey folks, I have been haunted by this for some time and now have an isolated repro. c1 and c2 below should be exactly equal, but that’s not the case. Can anyone please explain why? The behavior is the same on CPU and GPU.

import torch

d1 = 10
d2 = 2
d3 = 50
d4 = 150

torch.set_printoptions(precision=20)

e = torch.rand((d1, d2, d3))
f = torch.rand((d3, d4))

# r1 multiplies all d1*d2 rows of e by f; r2 multiplies only the last slice along dim 1.
r1 = torch.matmul(e.view(-1, d3), f).view(d1, d2, d4)
r2 = torch.matmul(e[:, -1, :].unsqueeze(1).view(-1, d3), f).view(d1, 1, d4)

# Both slices are products of the same rows of e with f, so they "should" match.
c1 = r1[:, -1, :]
c2 = r2[:, -1, :]
print(torch.equal(c1, c2))
print(torch.max(torch.abs(c1 - c2)))
print(torch.min(torch.abs(c1 - c2)))

That 1e-6-ish difference is the numerical precision of floating point. There isn’t much you can do about it (unless you use double precision or “enhanced precision techniques” for particularly critical things, which brings it down to 1e-14-ish).
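For illustration, here is a minimal sketch of the same repro carried out in float64 (same shapes and names assumed from the repro above); the discrepancy doesn’t vanish, it just drops to roughly the 1e-15 range:

import torch

d1, d2, d3, d4 = 10, 2, 50, 150
e = torch.rand((d1, d2, d3), dtype=torch.float64)
f = torch.rand((d3, d4), dtype=torch.float64)

r1 = torch.matmul(e.view(-1, d3), f).view(d1, d2, d4)
r2 = torch.matmul(e[:, -1, :].unsqueeze(1).view(-1, d3), f).view(d1, 1, d4)

# Same comparison as before; the difference is typically on the order of 1e-15 instead of 1e-6.
print(torch.max(torch.abs(r1[:, -1, :] - r2[:, -1, :])))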

Best regards

Thomas

Thanks Thomas. So are you saying that if we do a series of flops in the same order, we’d always get the same result (down to all 20 printed digits)? But if we change the order of flops, while keeping the process logically identical, the final result could differ?

You lose precision in matmul and could get different results even if you run your operations in the exact same order; here’s an explanation: https://www.nag.com/content/wandering-precision

You could use the following routine to find the largest possible error of a matmul (see this thread for the derivation):

import numpy as np
import torch

def float32_matmul_error(a, b):
    # Upper bound on the rounding error of a float32 matmul, using the
    # standard bound gamma_n = n*u / (1 - n*u), with u the float32 machine epsilon.
    a = a.type(torch.float64)
    b = b.type(torch.float64)
    n = max(a.shape + b.shape)
    u = np.finfo(np.dtype('float32')).eps
    gamma = (n * u) / (1 - n * u)
    error = torch.norm(a.flatten()) * torch.norm(b.flatten()) * gamma
    return error
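For example (a hypothetical usage, reusing the shapes from the repro at the top of the thread, with float32_matmul_error and the imports above in scope), the observed ~1e-6 discrepancy should sit well below the bound this returns:

a = torch.rand((20, 50))   # same shape as e.view(-1, d3) in the repro
b = torch.rand((50, 150))  # same shape as f in the repro
print(float32_matmul_error(a, b))  # worst-case rounding-error bound for a float32 matmul of these shapes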

Hi Amity!

Yes, this is true (for any reasonable definition of “do a series of flops in the same order,” and excluding perverse hardware implementations).

Consider this:

Mathematically, addition is commutative, i.e. x + y = y + x. It is also associative: x + (y + z) = (x + y) + z.

Floating-point addition can be (and should be) commutative. But there is no reasonable way to make floating-point addition associative.

1.0 + ((-1.0) + 1.e-20) = 0.0 with floating-point arithmetic, while (1.0 + (-1.0)) + 1.e-20 = 1.e-20. This is because (-1.0) + 1.e-20 = (-1.0), as neither floats nor doubles have enough precision to represent (-1.0) + 1.e-20 as a number different from -1.0. There’s really no way around this.

So if you change the order of operations in x + y + z, you get the same answer mathematically, but not necessarily with floating-point arithmetic.
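To see this concretely, here is a minimal sketch of the example above using plain Python floats (which are doubles):

# -1.0 + 1e-20 rounds back to -1.0, so the two groupings disagree.
print(1.0 + ((-1.0) + 1.0e-20))   # prints 0.0
print((1.0 + (-1.0)) + 1.0e-20)   # prints 1e-20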

Hello Yaroslav!

This isn’t correct (at least for what one would normally mean by “run instructions in the exact same order”), and your NAG reference doesn’t support this. Quoting from your NAG reference:

The problem is that by messing with the order of the accumulations, you are quite possibly changing the final result, simply due to rounding differences when working with finite precision computer arithmetic.

They talk about logic at a higher level (in this case the compiler) choosing to change the order of the floating-point operations, causing the result to change (at the level of floating-point precision).

Best.

K. Frank

Sure, at some point there’s a change in what the processor actually does. The NAG example is interesting because most people assume that same inputs + same binary + single thread = same arithmetic result, but in fact the result depends on memory alignment, which can change between invocations of the binary. It’s even harder to guarantee the same result once your op starts using multiple cores.

Thanks a lot folks for your detailed replies.