Comparison through torch.eq

Hi, all.

While checking the values of a torch tensor, I found a weird boolean result from torch.eq. The values of tensors a and b in the captured photo clearly all look the same, but the booleans are not all True. Why does this result occur?

import torch
import torch.nn as nn

h_src = torch.Tensor(10,3,5).uniform_(0,1)
h_t_tgt = torch.Tensor(10,1,5).uniform_(0,1)

model = nn.Linear(5, 5, bias=False)  # bias is a bool flag; False disables the bias term
a = model(h_t_tgt)
b = torch.einsum('lk,ijk->ijl', [model.weight, h_t_tgt])

print(torch.eq(a.data, b.data))
print(a)
print(b)

I understand this problem is related to floating-point arithmetic. Does that mean torch.eq cannot be relied on for this kind of comparison?

To compare floating point numbers, which might have small errors due to the limited floating point precision, you should use torch.isclose or torch.allclose instead of a direct comparison.
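As a rough sketch of that advice, the snippet below reruns the comparison from the question with a tolerance. torch.eq itself is exact; the likely cause of the mismatch is that nn.Linear and torch.einsum may accumulate the products in a different order, so the results can differ in the last few bits. The seed and the rtol/atol values here are assumptions chosen for illustration, not recommended settings, so tune them to your data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # assumed seed, only to make the run repeatable

h_t_tgt = torch.rand(10, 1, 5)
model = nn.Linear(5, 5, bias=False)

with torch.no_grad():
    a = model(h_t_tgt)
    b = torch.einsum("lk,ijk->ijl", model.weight, h_t_tgt)

# Exact elementwise comparison: may contain False where a and b
# differ only by floating-point rounding.
exact = torch.eq(a, b)

# Tolerance-based comparison: True where |a - b| <= atol + rtol * |b|.
# rtol and atol below are illustrative assumptions.
close = torch.isclose(a, b, rtol=1e-5, atol=1e-6)
all_close = torch.allclose(a, b, rtol=1e-5, atol=1e-6)

# The largest absolute difference shows how small the mismatch really is.
max_abs_diff = (a - b).abs().max().item()

print("exactly equal everywhere:", bool(exact.all()))
print("close everywhere:", all_close)
print("max abs difference:", max_abs_diff)
```

torch.isclose returns a boolean tensor like torch.eq does, while torch.allclose reduces it to a single Python bool, which is usually what you want in an assert or a unit test.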
