Discrepancy in matmul while batching

I am trying to understand a discrepancy that appears when I perform a matrix multiplication in batch.
To summarize, I am comparing the following two procedures:

  1. matmul(matrix, tensor) → slice the output → sliced output
  2. slice the input tensor → matmul(matrix, sliced_tensor) → sliced output

I expect the results to match exactly, but that is not the case. In detail:

In the first step, I perform the matmul between the matrix and the tensor and slice out the item at index 2 of the output (the third item; in fact it could be any item). Here is my code:

import torch

torch.manual_seed(0)
batch_size = 16
input_size = 64
seqlen = 50
hidden_size = 32

tensor = torch.randn(batch_size, input_size, seqlen)
matrix = torch.randn(hidden_size, input_size)
output = torch.matmul(matrix, tensor)

The output has shape (batch_size, hidden_size, seqlen), as expected. I slice the item at index 2 of the output batch and print the result below:

print(output[2])
tensor([[ -7.3307,   6.8155, -17.1523,  ...,   0.8284,  -8.5264,  -6.9085],
        [ -7.0253,   0.8953,  -3.8731,  ...,  -0.3406,  -0.8881,  14.2574],
        [  4.6008,  -2.4289,   2.7909,  ..., -13.3013,  -3.2574,  -5.8462],
        ...,
        [ -7.6053,  -3.6847,  -1.0073,  ...,  -5.0458, -16.5095,   6.8555],
        [  5.7580,   8.9886,  -3.3833,  ...,   6.3746,  -4.7638,  -0.6257],
        [  7.8220, -24.1010,   8.7975,  ...,   0.4481,  -4.6842, -19.7411]])

In the next step, I slice the item at index 2 of the input batch and repeat the same operation. I expect the result to match what I got above when performing the matrix multiplication in batch:

output_2 = torch.matmul(matrix, tensor[2])

The output of the above operation is:

output_2
tensor([[ -7.3307,   6.8155, -17.1523,  ...,   0.8284,  -8.5264,  -6.9085],
        [ -7.0253,   0.8953,  -3.8731,  ...,  -0.3406,  -0.8881,  14.2574],
        [  4.6008,  -2.4289,   2.7909,  ..., -13.3013,  -3.2574,  -5.8462],
        ...,
        [ -7.6053,  -3.6847,  -1.0073,  ...,  -5.0458, -16.5095,   6.8555],
        [  5.7580,   8.9886,  -3.3833,  ...,   6.3746,  -4.7638,  -0.6257],
        [  7.8220, -24.1010,   8.7975,  ...,   0.4481,  -4.6842, -19.7411]])

Although the outputs look identical, there is a slight discrepancy between them:

matching = ((output[2] == output_2).sum() / output_2.numel()).item()
print(f'Fraction of outputs matching exactly: {matching}')
Fraction of outputs matching exactly: 0.9662500023841858
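Counting exact matches does not show how large the mismatches are. A minimal sketch, reusing the setup from the question, that reports the largest absolute difference and checks agreement within a tolerance instead of bit-for-bit equality (the exact numbers printed will depend on hardware and PyTorch version):

```python
import torch

torch.manual_seed(0)
batch_size, input_size, seqlen, hidden_size = 16, 64, 50, 32

tensor = torch.randn(batch_size, input_size, seqlen)
matrix = torch.randn(hidden_size, input_size)

output = torch.matmul(matrix, tensor)       # batched: (16, 32, 50)
output_2 = torch.matmul(matrix, tensor[2])  # single item: (32, 50)

# Size of the mismatch, not just whether one exists.
diff = (output[2] - output_2).abs()
print(f"max abs difference:     {diff.max().item():.3e}")
print(f"fraction exactly equal: {(diff == 0).float().mean().item():.4f}")

# Compare within a tolerance (defaults: rtol=1e-05, atol=1e-08).
print(torch.allclose(output[2], output_2))
```

In float32 the maximum difference is expected to be on the order of a few units in the last place of the larger output values, which is why torch.allclose reports True.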

I checked, and torch.allclose(output[2], output_2) returns True. Superficially, I understand that this is due to floating-point error or something similar. I have the following questions:

  1. Is this normal to have a slight discrepancy?
  2. Can someone please explain why this discrepancy happens, with an explanation of what goes on under the hood, or point me to resources to understand this in depth?

Thank you


Hi Vignesh!

Yes, such a discrepancy is to be expected due to floating-point
round-off error.
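One way to see that neither result is "the wrong one" is to recompute the same product in float64 and measure how far each float32 result is from it. A sketch, assuming the same shapes and seed as in the question:

```python
import torch

torch.manual_seed(0)
tensor = torch.randn(16, 64, 50)  # (batch_size, input_size, seqlen)
matrix = torch.randn(32, 64)      # (hidden_size, input_size)

batched = torch.matmul(matrix, tensor)[2]  # slice after the batched matmul
single = torch.matmul(matrix, tensor[2])   # matmul on the sliced input

# Higher-precision reference for the same product.
reference = torch.matmul(matrix.double(), tensor[2].double())

err_batched = (batched.double() - reference).abs().max().item()
err_single = (single.double() - reference).abs().max().item()
print(f"batched vs float64: {err_batched:.2e}")
print(f"single  vs float64: {err_single:.2e}")
```

Both errors are expected to be of similar, small (float32 round-off) magnitude: the two float32 results are equally valid approximations that happen to round differently.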

In short, when you use floating-point arithmetic to perform operations
in different orders that should be mathematically equivalent, you, in
general, expect to get (slightly) different results.
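Each output entry here is a 64-term dot product, and a batched kernel may accumulate those terms in a different order, or with a different blocking scheme, than the unbatched one. The sketch below illustrates the effect by summing the same float32 products in three orders; the row and column are hypothetical random stand-ins, and the loop is an illustration, not how PyTorch's matmul kernels are written:

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins for one row of `matrix` and one column of `tensor[2]`.
row = torch.randn(64)
col = torch.randn(64)
products = row * col  # the 64 float32 terms of a single dot product


def sequential_sum(terms):
    """Accumulate left to right in float32, one term at a time."""
    acc = torch.zeros((), dtype=torch.float32)
    for t in terms:
        acc = acc + t
    return acc


forward = sequential_sum(products)
backward = sequential_sum(products.flip(0))
blocked = products.view(8, 8).sum(dim=1).sum()  # 8 partial sums, then combined
reference = products.double().sum()  # float64 sum of the same float32 terms

for name, value in [("forward", forward), ("backward", backward), ("blocked", blocked)]:
    err = abs(value.item() - reference.item())
    print(f"{name:8s} {value.item():.7f}  |error vs float64| = {err:.1e}")
```

The three float32 sums typically differ in their last digit or two while all staying within round-off distance of the float64 value.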

In particular, the associative law no longer holds, in that with
floating-point arithmetic, (a + b) + c != a + (b + c). (The
two expressions can be equal for some values of a, b, and c,
but, in general, they will differ.)
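A quick check of the associativity point, first with Python's double-precision floats and then with float32 tensors (the values are illustrative, not taken from the question):

```python
import torch

# Python floats are IEEE 754 double precision.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# The effect is larger in float32 when magnitudes differ widely.
x = torch.tensor(1e8, dtype=torch.float32)
y = torch.tensor(-1e8, dtype=torch.float32)
z = torch.tensor(1.0, dtype=torch.float32)
print(((x + y) + z).item())  # 1.0
print((x + (y + z)).item())  # 0.0, because -1e8 + 1 rounds back to -1e8 in float32
```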

For a start, you can look at the Wikipedia article on round-off error,
and dig into it more deeply, if you like, with Goldberg’s classic paper,
“What Every Computer Scientist Should Know About Floating-Point Arithmetic.”

Best.

K. Frank


Thank you very much for your clear explanation, Frank!