Here is a small code snippet showing a case where torch.equal does not behave like the == operator:
```python
import torch

t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])

print(t1 == t2)             # tensor([[ True, False, False]])
print(torch.equal(t1, t2))  # False
```
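A quick check of what each call returns makes the difference concrete. This is a minimal sketch. It assumes t1 holds the single value 1 (written here as torch.tensor([1])), since the snippet's original argument to torch.tensor appears to have been lost. The == operator broadcasts and returns an elementwise bool tensor. torch.equal does not broadcast: it returns a single Python bool that is True only when both tensors have the same size and the same elements.

```python
import torch

# Assumption: t1 contains the single value 1 (shape (1,)).
t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])  # shape (1, 3)

# == broadcasts (1,) against (1, 3) and compares element by element.
elementwise = t1 == t2
print(type(elementwise), elementwise.dtype, elementwise.shape)
# <class 'torch.Tensor'> torch.bool torch.Size([1, 3])

# torch.equal compares the tensors as a whole; the shapes differ, so it is False.
whole = torch.equal(t1, t2)
print(type(whole), whole)
# <class 'bool'> False
```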
I was trying to use torch.equal because my IDE infers the result of == as a plain boolean, so it warns me whenever I perform tensor operations on that result.
Which PyTorch function can I use to mimic the == operator? Should I manually broadcast the tensors before calling torch.equal?
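One possible sketch of the two options, assuming the goal is the same elementwise result that == produces (again taking t1 = torch.tensor([1]) as an assumption). torch.eq is the function form of the comparison: it broadcasts and returns an elementwise bool tensor. Broadcasting first with torch.broadcast_tensors and then calling torch.equal still produces a single Python bool ("are all elements equal?"), so it does not reproduce the elementwise output. The explicit torch.Tensor annotation is included only because it may help an IDE's type checker; whether it silences the warning depends on the IDE.

```python
import torch

# Assumption: t1 contains the single value 1.
t1 = torch.tensor([1])
t2 = torch.tensor([[1, 2, 3]])

# Option 1: torch.eq broadcasts and returns an elementwise bool tensor, like ==.
via_eq: torch.Tensor = torch.eq(t1, t2)
print(via_eq)  # tensor([[ True, False, False]])

# Option 2: broadcast explicitly, then torch.equal.
# The result is one Python bool for the whole comparison, not a tensor.
b1, b2 = torch.broadcast_tensors(t1, t2)
all_equal = torch.equal(b1, b2)
print(all_equal)  # False

# Broadcast-then-equal matches "are all elements of t1 == t2 True?".
assert all_equal == bool((t1 == t2).all())
```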