I am trying to understand a discrepancy that appears when performing batched matrix multiplication. To summarize, I am trying to do the following:
- matmul(matrix, tensor) → slice the output → sliced output
- slice the input tensor → matmul(matrix, sliced_tensor) → sliced output
I expect the results to match exactly, but that is not the case. In detail:
In the first step I perform the matmul between a tensor and a matrix and slice the item at index 2 of the output (in fact it could be any item). Here is my code:
torch.manual_seed(0)
batch_size = 16
input_size = 64
seqlen = 50
hidden_size = 32
tensor = torch.randn(batch_size, input_size, seqlen)
matrix = torch.randn(hidden_size, input_size)
output = torch.matmul(matrix, tensor)
The output is of shape (batch_size, hidden_size, seqlen), as expected. I slice the item at index 2 of the output batch and print the result below:
print(output[2])
tensor([[ -7.3307, 6.8155, -17.1523, ..., 0.8284, -8.5264, -6.9085],
[ -7.0253, 0.8953, -3.8731, ..., -0.3406, -0.8881, 14.2574],
[ 4.6008, -2.4289, 2.7909, ..., -13.3013, -3.2574, -5.8462],
...,
[ -7.6053, -3.6847, -1.0073, ..., -5.0458, -16.5095, 6.8555],
[ 5.7580, 8.9886, -3.3833, ..., 6.3746, -4.7638, -0.6257],
[ 7.8220, -24.1010, 8.7975, ..., 0.4481, -4.6842, -19.7411]])
In the next step, I slice the item at index 2 of the input batch and repeat the same operation. I expect the result to match what I got above from the batched matrix multiplication:
output_2 = torch.matmul(matrix, tensor[2])
The output of the above operation is:
output_2
tensor([[ -7.3307, 6.8155, -17.1523, ..., 0.8284, -8.5264, -6.9085],
[ -7.0253, 0.8953, -3.8731, ..., -0.3406, -0.8881, 14.2574],
[ 4.6008, -2.4289, 2.7909, ..., -13.3013, -3.2574, -5.8462],
...,
[ -7.6053, -3.6847, -1.0073, ..., -5.0458, -16.5095, 6.8555],
[ 5.7580, 8.9886, -3.3833, ..., 6.3746, -4.7638, -0.6257],
[ 7.8220, -24.1010, 8.7975, ..., 0.4481, -4.6842, -19.7411]])
Although the outputs look identical when printed, there is a slight discrepancy between them:
matching = ((output[2] == output_2).sum() / output_2.numel()).item()
print(f'Percentage of output matching: {matching}')
Percentage of output matching: 0.9662500023841858
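To see how large the mismatch actually is, I also measured the maximum absolute difference between the two results (self-contained reproduction of the setup above):

```python
import torch

# Reproduce the setup above
torch.manual_seed(0)
tensor = torch.randn(16, 64, 50)   # (batch_size, input_size, seqlen)
matrix = torch.randn(32, 64)       # (hidden_size, input_size)

output = torch.matmul(matrix, tensor)       # batched matmul, shape (16, 32, 50)
output_2 = torch.matmul(matrix, tensor[2])  # matmul on the sliced item, shape (32, 50)

# The mismatching elements differ only very slightly
max_diff = (output[2] - output_2).abs().max().item()
print(f'max abs diff: {max_diff}')
```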
I checked and found that torch.allclose(output[2], output_2) returns True. Superficially I understand that this is due to some kind of floating-point error. I have the following questions:
- Is it normal to have a slight discrepancy like this?
- Can someone explain why this discrepancy happens, with some under-the-hood detail, or point me to resources to understand this in depth?
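For reference, here is a tiny demo of the kind of floating-point behavior I suspect is at play (plain Python floats, nothing PyTorch-specific): summing the same numbers in a different grouping can change the last bits of the result, and batched vs. non-batched matmuls may accumulate their dot products in different orders.

```python
# Floating-point addition is not associative:
# the grouping of operations changes the rounding, hence the result.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a == b)  # False
print(a, b)    # 0.6000000000000001 0.6
```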
Thank you