Floating point precision differences when using torch.matmul on matrix transpose

Hi,

Can someone please explain why there might be differences in floating point precision for matrix multiplication when using the transpose of a matrix versus not using the transpose?

For example:

import torch

A = torch.tensor([[11., 25.,  3.,  5., 15.,  6., 12., 80.,  1.]], dtype=torch.float32)
B = torch.tensor([[-0.204977914690971375,  0.077419161796569824,  0.189963847398757935,
        -0.241283535957336426, -0.303015649318695068,  0.186268478631973267,
        -0.329424560070037842, -0.248794123530387878,  0.144485712051391602],
      [-0.111751005053520203,  0.235284119844436646, -0.072023868560791016,
        0.053652375936508179, -0.184143990278244019,  0.065914064645767212,
        -0.307873427867889404, -0.253490984439849854, -0.256091356277465820]], dtype=torch.float32)
C = torch.tensor([[-0.204977914690971375, -0.111751005053520203],
    [ 0.077419161796569824,  0.235284119844436646],
    [ 0.189963847398757935, -0.072023868560791016],
    [-0.241283535957336426,  0.053652375936508179],
    [-0.303015649318695068, -0.184143990278244019],
    [ 0.186268478631973267,  0.065914064645767212],
    [-0.329424560070037842, -0.307873427867889404],
    [-0.248794123530387878, -0.253490984439849854],
    [ 0.144485712051391602, -0.256091356277465820]], dtype=torch.float32)

print(torch.matmul(A, B.T))
print(torch.matmul(A, C))

Result:

tensor([[-28.095565795898, -21.891492843628]])
tensor([[-28.095567703247, -21.891494750977]])

Here, B and C contain the same values; they just have different shapes. We need to take the transpose of B in order to multiply it with A, whereas we don’t need to for C.
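(For reference, a quick check confirms the two tensors hold identical float32 values, so only the memory layout differs:)

print(torch.equal(B.T, C))  # expected: True – same values, only the layout differs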

I’m testing this out on Google Colab, and have tried both CPU and GPU.
Thanks!

Maybe different kernels are used based on the layout, and you could profile these with e.g. Nsight Systems or the native PyTorch profiler.
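In case it helps, here is a minimal sketch of how the two calls could be compared with the native torch.profiler (assuming A, B and C are defined as in the snippet above; Nsight Systems would give the same kind of kernel-level view on the GPU):

import torch
from torch.profiler import profile, ProfilerActivity

# add ProfilerActivity.CUDA to the list when profiling on a GPU runtime
with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.matmul(A, B.T)
    torch.matmul(A, C)

# the table lists which low-level ops/kernels each matmul dispatched to
print(prof.key_averages().table(sort_by="cpu_time_total"))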

Because the results of floating point operations depend on the order of the operands, i.e. floating point addition is, strictly speaking, non-associative:

>>> x = (0.1 + 0.2) + 0.3
>>> y = 0.1 + (0.2 + 0.3)
>>> x-y
1.1102230246251565e-16

torch.matmul can perform the operations in a different order depending on the input layout, which leads to these small discrepancies.
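As an illustration (not what the matmul kernels literally do internally), accumulating the same float32 products in two different orders can already change the last bits of the result, using the A and C tensors from the question:

row = A[0]          # shape (9,)
col = C[:, 0]       # shape (9,)

# same multiply-accumulate, forward vs. reverse summation order
acc_fwd = torch.zeros((), dtype=torch.float32)
for i in range(row.numel()):
    acc_fwd = acc_fwd + row[i] * col[i]

acc_rev = torch.zeros((), dtype=torch.float32)
for i in reversed(range(row.numel())):
    acc_rev = acc_rev + row[i] * col[i]

# the two accumulators can disagree in the least significant bits
print(acc_fwd.item(), acc_rev.item(), (acc_fwd - acc_rev).item())

The actual kernels use blocked and vectorized accumulation, so the exact order depends on the layout and backend, but the mechanism behind the discrepancy is the same.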