torch.bmm()'s results are a little bit strange to me

Hi all,

Here is my experiment:

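import torch

input = torch.rand(3, 10, 6)

# method 1: one batched matmul over the whole (3, 10, 6) tensor -> (3, 10, 10)
as_matrix = torch.bmm(input, input.transpose(1, 2))

# method 2: multiply one row (query) at a time, then concatenate along dim 1
one_by_one = []
for i in range(input.size(1)):
    query = input[:, i:i+1, :]                       # shape (3, 1, 6)
    temp = torch.bmm(query, input.transpose(1, 2))   # shape (3, 1, 10)
    one_by_one.append(temp)
one_by_one = torch.cat(one_by_one, dim=1)            # shape (3, 10, 10)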

In my understanding, the only difference between method1 and method2 is how they interpret input. Method1 considers input as a single batch of matrices, while method2 considers it as a list of row slices concatenated together along the second dimension.
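One way to confirm that the two methods really are mathematically equivalent is to feed them inputs for which float32 arithmetic is exact. The sketch below assumes small integer-valued entries (my choice, not part of the original experiment): every product and partial sum stays far below 2**24, so no rounding can occur and the order of the additions cannot matter.

```python
import torch

# Integer-valued float32 input (assumption for this check): all products
# (at most 81) and sums of 6 products (at most 486) are exactly representable.
x = torch.randint(0, 10, (3, 10, 6)).float()

# method 1: one batched matmul
as_matrix = torch.bmm(x, x.transpose(1, 2))

# method 2: row by row, then concatenate along dim 1
one_by_one = torch.cat(
    [torch.bmm(x[:, i:i+1, :], x.transpose(1, 2)) for i in range(x.size(1))],
    dim=1,
)

print(torch.equal(as_matrix, one_by_one))  # True: bit-identical results
```

With exact arithmetic the outputs match bit for bit, so any difference seen with torch.rand inputs must come from rounding.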

I expected the outputs to be exactly the same.

But in fact they are slightly different. Can someone help me understand why they are different? Or is this difference so small that I do not need to worry about it?
Thanks

as_matrix - one_by_one

# output: 

tensor([[[ 0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00, -1.1921e-07,
           0.0000e+00,  5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00, -2.3842e-07,  0.0000e+00,  0.0000e+00,  1.1921e-07],
         [ 5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1921e-07,
          -5.9605e-08,  0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1921e-07,
          -5.9605e-08,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00],
         [-1.1921e-07,  0.0000e+00,  1.1921e-07, -1.1921e-07,  2.3842e-07,
           5.9605e-08,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -5.9605e-08, -5.9605e-08,  5.9605e-08,
           5.9605e-08,  0.0000e+00,  0.0000e+00,  5.9605e-08,  1.1921e-07],
         [ 5.9605e-08, -2.3842e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1921e-07,  1.1921e-07,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  5.9605e-08,  0.0000e+00,  0.0000e+00,
           5.9605e-08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           5.9605e-08,  0.0000e+00, -1.1921e-07,  0.0000e+00, -1.1921e-07],
         [ 0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00,  5.9605e-08,
           1.1921e-07,  0.0000e+00,  0.0000e+00,  2.3842e-07,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1921e-07, -5.9605e-08,
           1.1921e-07, -1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00,
           1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  5.9605e-08, -5.9605e-08,  0.0000e+00, -5.9605e-08,
           0.0000e+00,  0.0000e+00, -1.1921e-07,  0.0000e+00,  0.0000e+00],
         [ 5.9605e-08,  1.1921e-07,  1.1921e-07,  1.1921e-07,  0.0000e+00,
           0.0000e+00, -1.1921e-07, -1.1921e-07,  1.1921e-07,  1.1921e-07],
         [ 0.0000e+00,  0.0000e+00, -1.1921e-07,  0.0000e+00,  0.0000e+00,
          -1.1921e-07,  2.3842e-07, -2.3842e-07,  0.0000e+00,  0.0000e+00],
         [-1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1921e-07,
          -1.1921e-07, -2.3842e-07,  0.0000e+00,  2.3842e-07,  0.0000e+00],
         [ 0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.1921e-07,  0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00],
         [-1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-2.3842e-07,  0.0000e+00, -2.3842e-07,  2.3842e-07, -1.1921e-07,
           0.0000e+00,  1.1921e-07,  0.0000e+00, -1.1921e-07,  2.3842e-07],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1921e-07,  0.0000e+00],
         [-2.3842e-07,  0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  2.3842e-07,  0.0000e+00,  0.0000e+00],
         [ 2.3842e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          -1.1921e-07,  0.0000e+00,  1.1921e-07,  0.0000e+00,  0.0000e+00],
         [-1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1921e-07,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1921e-07],
         [ 1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00, -1.1921e-07,  1.1921e-07,  0.0000e+00, -1.1921e-07],
         [ 0.0000e+00,  0.0000e+00,  2.3842e-07,  1.1921e-07,  0.0000e+00,
           0.0000e+00,  1.1921e-07,  0.0000e+00,  1.1921e-07, -1.1921e-07],
         [-1.1921e-07, -1.1921e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           0.0000e+00,  0.0000e+00,  1.1921e-07,  1.1921e-07,  5.9605e-08],
         [ 2.3842e-07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.1921e-07, -1.1921e-07, -1.1921e-07,  5.9605e-08,  1.1921e-07]]])

You shouldn’t worry about the small difference, as the errors are within the limits of float32 precision: every entry above is a multiple of 2^-24 ≈ 5.96e-08, i.e. on the order of float32 machine epsilon (≈ 1.19e-07).
These small absolute errors can be caused by, e.g., a different order of operations, as seen here:


import torch

x = torch.randn(10, 10, 10)
s1 = x.sum()                    # one reduction over all elements
s2 = x.sum(0).sum(0).sum(0)     # three reductions, one dimension at a time
print((s1 - s2).abs().max())
> tensor(2.3842e-06)
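Since such differences are expected, a unit test should compare the two results with a tolerance rather than exact equality. A minimal sketch using torch.allclose and torch.testing.assert_close (available in recent PyTorch versions); the fixed seed is only there to make the run reproducible:

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed for reproducibility
input = torch.rand(3, 10, 6)

as_matrix = torch.bmm(input, input.transpose(1, 2))
one_by_one = torch.cat(
    [torch.bmm(input[:, i:i+1, :], input.transpose(1, 2)) for i in range(input.size(1))],
    dim=1,
)

# Largest absolute difference: on the scale of float32 rounding error
# (it may even be 0 on some backends).
print((as_matrix - one_by_one).abs().max())
print(torch.finfo(torch.float32).eps)  # 1.1920928955078125e-07

# Compare with a tolerance instead of exact equality.
print(torch.allclose(as_matrix, one_by_one))  # True
torch.testing.assert_close(as_matrix, one_by_one)  # raises AssertionError on mismatch
```

assert_close picks default tolerances based on the dtype, which is usually what you want in a test suite.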

Thank you soooo much, ptrblck. Now I am more comfortable with my unit test results.
Just out of curiosity, why does the order of operations cause this difference? Is it something to do with cuDNN’s matrix multiplication algorithm? I am just curious; I have zero experience with cuDNN coding.

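This issue is not specific to cuDNN; it is inherent to floating point arithmetic. Every addition rounds its result to the nearest representable value, so floating point addition is not associative: (a + b) + c can differ from a + (b + c) in the last bits.
The linked Wikipedia article describes some cases where precision is lost.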
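A small sketch of that non-associativity, first with Python floats (float64) and then with float32 tensors. The values 0.1, 0.2, 0.3 and 1e8 are arbitrary picks chosen to make the rounding visible:

```python
import torch

# Python floats are float64: the grouping of the additions changes the result.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)  # 0.6000000000000001
print(a + (b + c))  # 0.6

# float32: near 1e8 the spacing between representable values is 8,
# so adding 1.0 first is rounded away entirely.
big = torch.tensor(1e8, dtype=torch.float32)
one = torch.tensor(1.0, dtype=torch.float32)
print((big + one) - big)  # tensor(0.)
print((big - big) + one)  # tensor(1.)
```

A batched matmul and a row-by-row loop may accumulate their dot products in different orders (blocking, vectorization), which is exactly this effect on a smaller scale.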

Basically, in my small example neither result might be the “true” value, but both results come close to the mathematically correct answer.
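To make “both come close to the correct answer” concrete, here is a sketch that compares both float32 results against a float64 computation of the same products. It assumes float64 is accurate enough to act as the reference for these tiny dot products, which is reasonable but still an assumption:

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed for reproducibility
x = torch.rand(3, 10, 6)

# Two float32 results that differ only in how the work is split up.
as_matrix = torch.bmm(x, x.transpose(1, 2))
one_by_one = torch.cat(
    [torch.bmm(x[:, i:i+1, :], x.transpose(1, 2)) for i in range(x.size(1))],
    dim=1,
)

# Reference computed in float64 from the same float32 inputs.
x64 = x.double()
reference = torch.bmm(x64, x64.transpose(1, 2))

err1 = (as_matrix.double() - reference).abs().max()
err2 = (one_by_one.double() - reference).abs().max()
print(err1, err2)  # both tiny, on the scale of float32 rounding error
```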


Thanks ptrblck. You are the best!!!
