Inconsistency between `(A==B).all()` and `A.sum()==B.sum()`

Hi, I’m using PyTorch 1.7.1 and observed the following:

In [9]: import torch
   ...: q = torch.randn(1, 1000, 128)
   ...: k = torch.randn(1, 1000, 128)
   ...: q1 = q[:, :500, :]
   ...: k1 = k[:, :500, :]
   ...: attn = torch.bmm(q, k.transpose(1,2))
   ...: attn1 = torch.bmm(q1, k1.transpose(1,2))
   ...:
   ...: A = attn[0,:500,:500]
   ...: B = attn1[0,:500,:500]
   ...:
   ...: print(A)
   ...: print(B)
   ...: print((A==B).all())
   ...: print(A.sum() == B.sum())
tensor([[ 13.7018,  -6.4425, -10.8135,  ..., -22.0151,  -7.4277,   1.3627],
        [ -1.8352,   1.0710,  -7.1797,  ...,  -6.5084,  -7.4944,   4.5227],
        [  9.8117, -20.0964,  -1.4925,  ...,   5.0464,  -4.3286,   5.9336],
        ...,
        [ -3.9975,  -3.4578,   2.1641,  ...,   0.7644,   9.4954,   3.1092],
        [-17.6577,  17.8311,  10.5954,  ...,  -0.4581,  10.4234,   1.1901],
        [-11.4489,  15.5743,   9.7824,  ...,   0.8799,  16.4985,  20.6857]])
tensor([[ 13.7018,  -6.4425, -10.8135,  ..., -22.0151,  -7.4277,   1.3627],
        [ -1.8352,   1.0710,  -7.1797,  ...,  -6.5084,  -7.4944,   4.5227],
        [  9.8117, -20.0964,  -1.4925,  ...,   5.0464,  -4.3286,   5.9336],
        ...,
        [ -3.9975,  -3.4578,   2.1641,  ...,   0.7644,   9.4954,   3.1092],
        [-17.6577,  17.8311,  10.5954,  ...,  -0.4581,  10.4234,   1.1901],
        [-11.4489,  15.5743,   9.7824,  ...,   0.8799,  16.4985,  20.6857]])
tensor(True)
tensor(False)

Can someone explain the discrepancy between these two results?

Many thanks.

You are most likely running into the expected floating point precision limitation.
On my system even the values in A and B don’t match bitwise: the underlying matrix multiplications receive inputs of different shapes and can therefore pick different algorithms, which might not yield bitwise-identical outputs (each run would still be deterministic when compared against itself).
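
If the goal is to treat such results as equal, an elementwise comparison with a tolerance is the usual tool. A minimal sketch of the idea, using torch.allclose (the rtol/atol values here are illustrative, not prescriptive):

import torch

q = torch.randn(1, 1000, 128)
k = torch.randn(1, 1000, 128)
attn = torch.bmm(q, k.transpose(1, 2))
attn1 = torch.bmm(q[:, :500, :], k[:, :500, :].transpose(1, 2))

A = attn[0, :500, :500]
B = attn1[0, :500, :500]

# Compare within a tolerance instead of requiring bitwise equality;
# tune rtol/atol to the precision your application needs.
print(torch.allclose(A, B, rtol=1e-4, atol=1e-5))
print(torch.allclose(A.sum(), B.sum(), rtol=1e-4, atol=1e-5))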

Can you clarify what the problem is? When PyTorch compares two tensors (say, of float32) using ==, doesn’t it compare exactly 32 bits (no more, no less)? Similarly, when it performs sum(), doesn’t it use only those 32 bits and nothing more?

Yes, all 32 bits are used, but the order of operations introduces small errors, because not every real number can be represented exactly in float32.
E.g. these two sums should mathematically produce the same value, but each introduces a small rounding error:

import torch

x = torch.randn(100, 100)
s1 = x.sum()          # reduce all elements in one pass
s2 = x.sum(0).sum(0)  # reduce dim 0 first, then the remaining dim
print((s1 - s2).abs().max())
# tensor(1.5259e-05)

as different algorithms (or, more generally, a different order of operations) can be used internally in the two approaches.
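
To see that the order of additions matters at all, the same effect shows up with plain Python floats, where merely regrouping a sum changes how the intermediate results are rounded:

# Floating point addition is not associative:
# regrouping changes the rounding of intermediate results.
a = (0.1 + 0.2) + 0.3   # 0.6000000000000001
b = 0.1 + (0.2 + 0.3)   # 0.6
print(a == b)           # False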

Thanks for your explanation. I tried the following, which confirms that you are right :grinning::

In [3]: C = A.clone()
   ...: D = B.clone()
   ...:
   ...: print((C==D).all())
   ...: print(C.sum() == D.sum())
   ...:
tensor(True)
tensor(True)
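
For completeness, a plausible reason clone() makes the sums agree (this is an assumption about the internals rather than something verified in the 1.7.1 sources): A is a non-contiguous view into the larger attn tensor, while B is contiguous, so sum() may walk memory, and therefore accumulate values, in a different order for each. clone() materializes both as contiguous tensors, after which the reduction order is the same:

# A is a strided, non-contiguous view into the (1000, 1000) attn matrix,
# while B views a tensor that is already exactly (500, 500).
print(A.is_contiguous())  # False
print(B.is_contiguous())  # True

# clone() copies the data into freshly allocated contiguous memory,
# so both sums can accumulate elements in the same order.
print(C.is_contiguous())  # True
print(D.is_contiguous())  # True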