Different numerical results between 4D batched matmul and equivalent 3D reshaped matmul

Hello,

I noticed that two mathematically equivalent matrix multiplication operations produce different numerical results, with torch.allclose() failing even with relatively loose tolerances. I’d like to understand why this happens and whether this is expected behavior.

import torch


# Setup
n_ens = 5
batch_size = 256
ens_k = 32
d_in = 84
d_out = 512

torch.manual_seed(42) 
x = torch.randn(n_ens, batch_size, ens_k, d_in)
w = torch.randn(n_ens, 1, d_in, d_out)


# Method 1: 4D batched matmul with broadcasting
# Shape: (5, 256, 32, 84) @ (5, 1, 84, 512) -> (5, 256, 32, 512)
result_4d = x @ w

# Method 2: Reshape to 3D, matmul, reshape back
x_3d = x.reshape(n_ens, batch_size * ens_k, d_in)  # (5, 8192, 84)
w_3d = w.squeeze(1)  # (5, 84, 512)
result_3d = (x_3d @ w_3d).reshape(n_ens, batch_size, ens_k, d_out)


# Method 3: einsum
result_einsum = torch.einsum(
    'nbki,nio->nbko', 
    x, 
    w.squeeze(1)
)
print(f"Shapes match: {result_4d.shape == result_3d.shape}")
print(f"4d vs 3d: torch.equal: {torch.equal(result_4d, result_3d)}")
print(f"4d vs 3d: torch.allclose (default): {torch.allclose(result_4d, result_3d)}")
print(f"4d vs 3d: torch.allclose (atol=1e-6): {torch.allclose(result_4d, result_3d, atol=1e-6)}")
print(f"einsum vs 3d: torch.allclose (default): {torch.allclose(result_einsum, result_3d)}")
print(f"einsum vs 3d: torch.allclose (atol=1e-6): {torch.allclose(result_einsum, result_3d, atol=1e-6)}")

In my run, result_einsum exactly equals result_3d (torch.equal returns True), but result_4d does not match result_3d, even with torch.allclose at atol=1e-6.

Thank you in advance!

Hi Duong!

This is due to expected numerical round-off error.

Your various methods end up performing the operations that make up the matmul()s in
different, but mathematically equivalent orders. These different orders lead to slightly
different results due to floating-point round-off errors. Note that in the context of the larger
matmul() computation, individual round-off errors can accumulate. So while atol = 1e-6
is larger than typical single-precision round-off error, it’s not larger than what one might
expect for the accumulated round-off error in the overall matmul() computation.
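A quick way to see that regrouping mathematically equivalent operations changes floating-point results (the same effect, on a tiny scale, as reordering the reductions inside a matmul):

```python
# Floating-point addition is not associative: regrouping the same terms
# can change the last bits of the result under IEEE-754 arithmetic.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c   # one grouping
right = a + (b + c)  # another grouping
print(left == right)  # False
print(left, right)    # 0.6000000000000001 vs 0.6
```

Each dot product inside a matmul is a long chain of such additions, so different kernels that reduce in different orders accumulate slightly different round-off.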

Try looking at both the element-wise means and maxima of the differences of your various
results. You should see them shrink by maybe seven orders of magnitude if you repeat the
computations in double precision.
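A sketch of that check, using a reduced copy of the original setup (batch size shrunk from 256 to 64 purely to keep it cheap; the exact sizes of the diffs depend on your backend):

```python
import torch

# Reduced copy of the original setup.
torch.manual_seed(42)
x = torch.randn(5, 64, 32, 84)
w = torch.randn(5, 1, 84, 512)

def max_abs_diff(x, w):
    r4 = x @ w                                               # 4D broadcasted matmul
    r3 = (x.flatten(1, 2) @ w.squeeze(1)).reshape(r4.shape)  # 3D reshape route
    return (r4 - r3).abs().max().item()

print("float32 max |diff|:", max_abs_diff(x, w))
print("float64 max |diff|:", max_abs_diff(x.double(), w.double()))
```

On a typical CPU backend the float64 difference comes out many orders of magnitude smaller than the float32 one (or both may be exactly zero if the backend happens to use the same reduction order for both methods).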

Best.

K. Frank


Thank you for the clarification.
In terms of runtime, is there any difference between Method 2 and Method 3?
I did a timeit test and there was no significant difference, but I’d like your confirmation.
Thank you in advance!

Hi Duong!

There should be no significant difference between Methods 1, 2, and 3, as they are all
doing the same work. If this part of your computation is critical to your overall performance,
you should time the alternatives (using realistic dimensions) and use the fastest. If this isn’t
critical to your performance, I would use the version that seems stylistically better – easier
to read, more consistent with other parts of your code, etc.
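A minimal timing sketch along those lines, assuming CPU tensors (on a GPU you would need to synchronize before reading each timing, e.g. with torch.cuda.synchronize, or use torch.utils.benchmark which handles this):

```python
import timeit
import torch

torch.manual_seed(0)
x = torch.randn(5, 256, 32, 84)
w = torch.randn(5, 1, 84, 512)
w3 = w.squeeze(1)

def method_4d():
    return x @ w

def method_3d():
    return (x.reshape(5, 256 * 32, 84) @ w3).reshape(5, 256, 32, 512)

def method_einsum():
    return torch.einsum('nbki,nio->nbko', x, w3)

for name, fn in [("4d", method_4d), ("3d", method_3d), ("einsum", method_einsum)]:
    t = timeit.timeit(fn, number=20)
    print(f"{name}: {t / 20 * 1e3:.2f} ms/iter")
```

As noted above, all three should land within noise of each other on current PyTorch versions; if this path is performance-critical, rerun the comparison with your real dimensions and dtype.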

I do recall, many pytorch versions ago, that there was a case where some sort of einsum()
significantly outperformed some sort of matmul(). I attributed this to a performance bug in
matmul() that I recall got fixed.

I also recall a case where some sort of matmul() dramatically outperformed einsum(). In
this case I believe that einsum() wasn’t organizing the computation it was being asked to
do in an efficient manner. I’m pretty sure that einsum() has gotten significantly smarter
since then.

Best.

K. Frank

Hello, thank you very much for your reply. Much appreciated!