Consider the following:
import torch

x = torch.zeros((1,))
l = [torch.zeros((2,)), torch.zeros((2,))]
Now I want to check whether `x` is contained in `l`. The straightforward way would be:
if x in l:
    ...
This does not work, since `in` for lists uses `==` to check equality, but `==` is overridden in PyTorch to check for entry-wise equality, so the comparison yields a tensor rather than a single boolean. We can still do it with a “manual” loop and use `torch.equal()` or `is`, depending on what we want. But that doesn’t seem very elegant.
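For reference, the manual loop mentioned above could look roughly like the sketch below. The helper names `contains_equal` and `contains_identical` are made up for illustration and are not part of PyTorch; the sketch also assumes all tensors live on the same device and share a dtype.

```python
import torch


def contains_equal(tensors, x):
    # Value-based membership: torch.equal returns a single Python bool,
    # True only if both tensors have the same shape and the same elements.
    return any(torch.equal(x, t) for t in tensors)


def contains_identical(tensors, x):
    # Identity-based membership: True only if x is the very same object
    # as one of the list entries.
    return any(t is x for t in tensors)


x = torch.zeros((1,))
l = [torch.zeros((2,)), torch.zeros((2,))]
y = torch.zeros((2,))

in_by_value = contains_equal(l, x)            # shapes differ, so no match
y_in_by_value = contains_equal(l, y)          # same shape and values as l[0]
y_in_by_identity = contains_identical(l, y)   # y is a different object
first_in_by_identity = contains_identical(l, l[0])
```

Both helpers are just `any(...)` over a generator, so they stop at the first match, but they are still a linear scan over the list.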
So I’m wondering, what is the recommended way to check for containment (both in the sense of `==` and `is`) in PyTorch?