If `a` is some tensor and `b` is a tensor or a number, is there a difference between `torch.eq(a, b)` and `a == b`? If not, why does the PyTorch API specify the `torch.eq` method? Should I avoid using the `==` operator and use `torch.eq` instead?
I think both approaches should yield the same result:

```python
a = torch.arange(10)
b = 2
print(a == b)
print(torch.eq(a, b))

b = torch.arange(10)
print(a == b)
print(torch.eq(a, b))
```
I guess:
- convenience, if you prefer to use the explicit `torch.*` methods, e.g. `torch.add` instead of `+`
- to allow users to pass the `out` tensor
- to map it easily to the internal `aten::eq` operator (but I'm unsure if adding `torch.eq` actually makes it easier)
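For instance, a minimal sketch of the `out=` use case mentioned above (the `==` operator has no equivalent way to write into a pre-allocated tensor):

```python
import torch

a = torch.arange(5)
b = torch.tensor(2)

# Pre-allocate a boolean result tensor once, then reuse it via out=.
result = torch.empty(5, dtype=torch.bool)
torch.eq(a, b, out=result)
print(result)  # a boolean mask that is True only where a equals 2
```
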
OK, I marked ptrblck's answer as the solution. Maybe a small hint about the equivalence of `==` and `torch.eq` would be worth adding to the documentation?
Actually, `==` (the equality operator) and `torch.eq` are different!

You can try:

```python
a = torch.arange(5)
b = 2
print(b == a)
print(torch.eq(b, a))
```

Then you will find *(in Python 3.11 and PyTorch 2.2)* that the equality operator runs fine, while `torch.eq` raises a `TypeError`.

However, this is not the end of the story. `b` is an int in the above example; if we pass a PyTorch tensor instead, both comparisons give the desired result:
```python
a = torch.arange(5)
b = torch.arange(2, 3)
print(b == a)
# the output is tensor([False, False, True, False, False])
print(torch.eq(b, a))
# the output is tensor([False, False, True, False, False])
```
One final note: you can treat `a == b` as trying to run both `torch.eq(a, b)` and `torch.eq(b, a)`.
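A minimal sketch of that fallback behavior (assuming the same versions as above): with an int on the left, the `==` operator still dispatches to the tensor operand, while `torch.eq` insists its first argument be a `Tensor`:

```python
import torch

a = torch.arange(5)

# The operator falls back to the tensor's __eq__, so an int on the
# left-hand side still produces an element-wise boolean mask:
print(2 == a)

# torch.eq requires its first positional argument to be a Tensor,
# so the same call in function form raises a TypeError:
try:
    torch.eq(2, a)
except TypeError as e:
    print("TypeError:", e)
```
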
No difference in normal usage. But there is a huge difference when you try to do model conversion with TorchScript (`torch.jit.script()`) or ONNX (`torch.onnx.export()`): using operators like `>`, `<`, `==` can result in model conversion failure for `torch.onnx.export()`, and also sometimes affects `torch.jit.script()`.
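As a hedged sketch of the workaround this suggests (the operator form often scripts fine too, but the explicit op maps directly onto `aten::eq`), a comparison written with `torch.eq` compiles under `torch.jit.script` without issue:

```python
import torch

# Using the explicit torch.eq instead of == inside a scripted function.
@torch.jit.script
def mask_equal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.eq(a, b)

print(mask_equal(torch.arange(5), torch.tensor(2)))
```
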
@ptrblck