Hello,
I noticed that two mathematically equivalent matrix multiplication operations produce different numerical results, with torch.allclose() failing even with relatively loose tolerances. I’d like to understand why this happens and whether this is expected behavior.
import torch
# Setup
n_ens = 5
batch_size = 256
ens_k = 32
d_in = 84
d_out = 512
torch.manual_seed(42)
x = torch.randn(n_ens, batch_size, ens_k, d_in)
w = torch.randn(n_ens, 1, d_in, d_out)
# Method 1: 4D batched matmul with broadcasting
# Shape: (5, 256, 32, 84) @ (5, 1, 84, 512) -> (5, 256, 32, 512)
result_4d = x @ w
# Method 2: Reshape to 3D, matmul, reshape back
x_3d = x.reshape(n_ens, batch_size * ens_k, d_in) # (5, 8192, 84)
w_3d = w.squeeze(1) # (5, 84, 512)
result_3d = (x_3d @ w_3d).reshape(n_ens, batch_size, ens_k, d_out)
# Method 3: einsum
result_einsum = torch.einsum('nbki,nio->nbko', x, w.squeeze(1))
print(f"Shapes match: {result_4d.shape == result_3d.shape}")
print(f"4d vs 3d: torch.equal: {torch.equal(result_4d, result_3d)}")
print(f"4d vs 3d: torch.allclose (default): {torch.allclose(result_4d, result_3d)}")
print(f"4d vs 3d: torch.allclose (atol=1e-6): {torch.allclose(result_4d, result_3d, atol=1e-6)}")
print(f"einsum vs 3d: torch.allclose (default): {torch.allclose(result_einsum, result_3d)}")
print(f"einsum vs 3d: torch.allclose (atol=1e-6): {torch.allclose(result_einsum, result_3d, atol=1e-6)}")
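For reference, here is a self-contained sketch (same shapes and seed as above) that quantifies the discrepancy between the two methods; the exact magnitudes will depend on hardware and the BLAS backend, so the numbers below are illustrative only:

```python
import torch

torch.manual_seed(42)
x = torch.randn(5, 256, 32, 84)
w = torch.randn(5, 1, 84, 512)

# Broadcasted 4D batched matmul vs. reshape-to-3D matmul
r4 = x @ w
r3 = (x.reshape(5, 256 * 32, 84) @ w.squeeze(1)).reshape(5, 256, 32, 512)

# Measure how far apart the two results actually are
diff = (r4 - r3).abs()
print(f"max abs diff: {diff.max().item():.3e}")
print(f"max rel diff: {(diff / r3.abs().clamp_min(1e-12)).max().item():.3e}")
```

In my runs the maximum absolute difference is on the order of float32 rounding error for dot products of length 84, i.e. tiny relative to the values themselves.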
On my machine, result_einsum is exactly equal to result_3d (torch.equal returns True), but result_4d differs from result_3d: torch.equal fails, and torch.allclose fails both with the default tolerances and with atol=1e-6.
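My current guess is that this is floating-point non-associativity: if the broadcasted 4D matmul dispatches to a kernel that accumulates the length-84 dot products in a different order than the 3D batched matmul, the rounded float32 results can legitimately differ. A minimal demonstration that summation order matters in float32 (values chosen only to make the effect obvious):

```python
import torch

# In float32 the ulp at 1e8 is 8, so adding 1.0 to -1e8 rounds away entirely,
# while adding it after the large terms cancel keeps it intact.
a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

print(((a + b) + c).item())  # 1.0
print((a + (b + c)).item())  # 0.0
```

Is it correct that this kind of reordering between kernels is considered expected behavior, or should the two matmul paths be bitwise reproducible?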
Thank you in advance!