torch.equal does not behave like ==

Here is a small snippet showing a case where torch.equal does not behave like the == operator:

import torch
t1 = torch.tensor([1])
t2 = torch.tensor([[1,2,3]])
print(t1==t2)
# tensor([[ True, False, False]])
print(torch.equal(t1,t2))
# False

I was trying to use torch.equal because my IDE infers the result of the == operator as a plain bool, so it warns me whenever I apply tensor operations to that result.

Which PyTorch function can I use to mimic the == operator? Should I manually broadcast the tensors before calling torch.equal?

torch.eq(t1, t2) should work.
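On the question of broadcasting manually first: that would not help on its own. torch.broadcast_tensors makes the shapes match, but torch.equal still collapses the whole comparison into one bool. A minimal sketch using the same t1 and t2:

```python
import torch

t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])

# Expand both tensors to their common broadcast shape, (1, 3)
b1, b2 = torch.broadcast_tensors(t1, t2)
print(b1)  # tensor([[1, 1, 1]])

# Even with matching shapes, torch.equal returns a single Python bool
print(torch.equal(b1, b2))  # False

# The elementwise comparison is what == actually computes
print(torch.eq(b1, b2))  # tensor([[ True, False, False]])
```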


After some digging in the documentation, it turns out torch.equal is indeed meant to return a single bool: True if the two tensors have the same size and the same elements, False otherwise. For example:

print(torch.equal(t2,t2))
# True
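To make that rule concrete, here is a rough approximation of what torch.equal checks, written with torch.eq. It is only a sketch: equal_like is a hypothetical helper, not a PyTorch function, and it assumes both tensors share the same dtype and device (mixed-dtype behavior is not covered).

```python
import torch

def equal_like(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper approximating torch.equal for same-dtype tensors."""
    # Different shapes never count as equal, even if they could broadcast
    if a.shape != b.shape:
        return False
    # Same shape: every element must match
    return bool(torch.eq(a, b).all())

t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])
print(equal_like(t1, t2), torch.equal(t1, t2))  # False False
print(equal_like(t2, t2), torch.equal(t2, t2))  # True True
```

The shape check comes first because torch.eq would otherwise broadcast t1 against t2, which is exactly the behavior torch.equal does not have.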

Instead, I should have used torch.eq:

print(torch.eq(t1,t2))
# tensor([[ True, False, False]])
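Because torch.eq returns an ordinary tensor of dtype torch.bool, its result can feed directly into further tensor operations. A short sketch of some common follow-ups (the name mask is just for illustration):

```python
import torch

t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])

mask = torch.eq(t1, t2)              # same result as t1 == t2
print(torch.equal(mask, t1 == t2))   # True
print(mask.dtype)                    # torch.bool
print(mask.sum())                    # tensor(1): number of matching elements
print(mask.any(), mask.all())        # tensor(True) tensor(False)
print(t2[mask])                      # tensor([1]): elements of t2 equal to t1
print(mask.float().mean())           # fraction of matches, roughly 0.3333
```

The method form t1.eq(t2) gives the same result, if that reads better in a chain of tensor calls.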

Simple as the fix is, I find the naming of torch.equal counter-intuitive, since similarly named functions such as torch.greater_equal, torch.less_equal and torch.not_equal all compare elementwise and return a tensor.
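A quick sketch of that contrast with the same t1 and t2. torch.greater_equal, torch.less_equal and torch.not_equal are documented aliases of torch.ge, torch.le and torch.ne, and each broadcasts just like the corresponding operator:

```python
import torch

t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])

# Elementwise, broadcasting comparisons: each returns a bool tensor of shape (1, 3)
print(torch.greater_equal(t1, t2))  # same as t1 >= t2
print(torch.less_equal(t1, t2))     # same as t1 <= t2
print(torch.not_equal(t1, t2))      # same as t1 != t2

# Whole-tensor check: a single Python bool
print(torch.equal(t1, t2))          # False
```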