(==) operator gives wrong result for the same matrices when using torch.bmm

Hi,
I need to detect equality in my tensors. I am using two equivalent ways to multiply batches of matrices: element-wise multiplication followed by a sum, and torch.bmm. The results look the same, but when I check them with the == operator, they don't seem to be.

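import torch
a = torch.rand(3, 2, 100)   # batch of 3 matrices, each 2 x 100
b = torch.randn(3, 4, 2)    # batch of 3 matrices, each 4 x 2
# Method 1: multiply element-wise (with broadcasting), then sum over the shared dimension
a_expanded = a.unsqueeze(1).expand(-1, b.shape[1], -1, -1)  # shape (3, 4, 2, 100)
mul1 = (b.unsqueeze(3) * a_expanded).sum(dim=2)              # shape (3, 4, 100)
# Method 2: batched matrix multiplication
mul2 = torch.bmm(b, a)                                       # shape (3, 4, 100)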

Now, if we evaluate mul1 == mul2, the result is:

tensor([[[1, 1, 0,  ..., 1, 1, 1],
         [0, 0, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[0, 1, 0,  ..., 1, 1, 1],
         [1, 0, 1,  ..., 1, 1, 1],
         [0, 0, 1,  ..., 1, 1, 1],
         [1, 0, 0,  ..., 1, 1, 1]],

        [[1, 1, 0,  ..., 1, 1, 1],
         [1, 0, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]], dtype=torch.uint8)

When I examine the elements where the comparison gave 0, they are the same. What is wrong?

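Due to floating-point precision, you might get small errors (usually around 1e-6). The two methods are mathematically equivalent, but they may perform the multiplications and additions in a different order or with different kernels, and because floating-point arithmetic rounds after every operation, the results can differ in the last few bits. This is also why the printed values look identical: by default PyTorch prints only 4 decimal places. I would suggest comparing the results with torch.allclose (specifying the tolerance manually, if necessary).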
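As a minimal illustration of why the order of operations matters, the sketch below uses plain Python floats rather than tensors, chosen only because the output is easy to verify. It does not reproduce what torch.bmm does internally; it only shows that summing the same numbers in two different orders can round to different results.

```python
# Floating-point addition is not associative: grouping the same
# three numbers differently changes where rounding happens.
x = (0.1 + 0.2) + 0.3
y = 0.1 + (0.2 + 0.3)

print(x, y)    # 0.6000000000000001 0.6
print(x == y)  # False
```

The difference is a single unit in the last place, which is exactly the kind of discrepancy that makes == fail even though both expressions are "the same" sum.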
Comparing floating-point numbers directly with == is often not a good idea.
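A minimal sketch of the tolerance-based comparison, reusing the two methods from the question. The fixed seed and the value atol=1e-6 are assumptions chosen for this example (the tolerance roughly matches the error size mentioned above); torch.isclose gives an element-wise boolean mask, which is the tolerant counterpart of mul1 == mul2.

```python
import torch

torch.manual_seed(0)  # assumed fixed seed so the run is reproducible

a = torch.rand(3, 2, 100)
b = torch.randn(3, 4, 2)

# Method 1: broadcast, multiply element-wise, then sum over the shared dim.
a_expanded = a.unsqueeze(1).expand(-1, b.shape[1], -1, -1)
mul1 = (b.unsqueeze(3) * a_expanded).sum(dim=2)

# Method 2: batched matrix multiplication.
mul2 = torch.bmm(b, a)

# Exact comparison: may report mismatches caused only by rounding.
exact_matches = (mul1 == mul2).sum().item()

# Largest absolute discrepancy between the two results.
max_abs_diff = (mul1 - mul2).abs().max().item()

# Tolerance-based comparisons (atol=1e-6 is an assumed tolerance).
elementwise = torch.isclose(mul1, mul2, rtol=1e-5, atol=1e-6)  # bool mask
close = torch.allclose(mul1, mul2, rtol=1e-5, atol=1e-6)       # single bool

print(f"exact matches: {exact_matches} / {mul1.numel()}")
print(f"max abs diff:  {max_abs_diff:.3e}")
print(f"isclose all:   {elementwise.all().item()}")
print(f"allclose:      {close}")
```

Passing atol explicitly matters here: the default atol of torch.allclose is 1e-8, which can be too strict for elements close to zero, where the relative term rtol * |other| contributes almost nothing.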