How to check if a tensor is in a collection of tensors?

Consider the following:

import torch

x = torch.zeros((1,))
l = [torch.zeros((2,)), torch.zeros((2,))]

Now I want to check whether x is contained in l. The straightforward way would be:

if x in l:
    ...

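This does not work, since the in operator on a list compares elements with == (after first checking whether they are the same object), but == is overloaded in PyTorch to compute entry-wise equality and returns a tensor instead of a single bool. Here x == l[0] broadcasts to a two-element tensor, and converting that to a bool raises a RuntimeError. We can still do it with a “manual” loop, using torch.equal() or is depending on what we want, but that doesn’t seem very elegant.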
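For reference, here is a minimal sketch of both the failure and the “manual” loop written compactly with any(). The helper names contains_equal and contains_identical are hypothetical, made up for this example; torch.equal() is the real PyTorch function, and the block simply recreates the x and l from above.

```python
import torch

x = torch.zeros((1,))
l = [torch.zeros((2,)), torch.zeros((2,))]

# `x in l` is not the same object as any element, so the list falls back to
# `==`, which broadcasts to a 2-element bool tensor; bool() of that raises.
try:
    x in l
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# The list's identity shortcut still succeeds for an element that really is in l.
print(l[0] in l)  # True


# Hypothetical helper: value containment (same shape and same values).
def contains_equal(t, tensors):
    return any(torch.equal(t, other) for other in tensors)


# Hypothetical helper: identity containment (same Python object).
def contains_identical(t, tensors):
    return any(t is other for other in tensors)


print(contains_equal(x, l))                   # False: shapes (1,) and (2,) differ
print(contains_equal(torch.zeros(2), l))      # True: same shape and values
print(contains_identical(l[1], l))            # True: same object
print(contains_identical(torch.zeros(2), l))  # False: equal values, new object
```

Note that torch.equal() does not broadcast: tensors of different shapes just compare as unequal instead of raising.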

So I’m wondering: what is the recommended way in PyTorch to check for containment, both in the sense of value equality (==) and of identity (is)?

Hi,

No, I don’t think there is a built-in way.
Mostly because, as you mentioned, torch.equal() and is check two distinct things (equal values vs. the same object).
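A small sketch of how far apart the two notions can be (the tensors are just examples): two tensors can hold equal values without being the same object, and even a view of a tensor, which shares its storage, is a separate Python object as far as is is concerned.

```python
import torch

a = torch.zeros(2)
b = torch.zeros(2)  # separate tensor with the same values
v = a.view(2)       # new Tensor object that shares a's storage

print(torch.equal(a, b), a is b)  # True False
print(torch.equal(a, v), a is v)  # True False
print(torch.equal(a, a), a is a)  # True True
```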

Thank you! For the object-identity case it seems we can use

if x in set(l):
    ...

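as torch.Tensor hashes by id(), so the set lookup effectively compares by identity (contrary to lists, which fall back to .__eq__() for every element that is not the same object).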
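A quick sketch to check this, assuming torch.Tensor.__hash__ is identity-based, as it is in current PyTorch releases (it returns id(self)). The last lines show an equivalent that keys on id() directly, so it does not depend on how Tensor implements __hash__.

```python
import torch

a = torch.zeros(2)
l = [a, torch.zeros(2)]

# Identity-based hashing: distinct tensors never get compared with `==`.
s = set(l)
print(len(s))               # 2: equal values, but two distinct objects
print(a in s)               # True: same object
print(torch.zeros(2) in s)  # False: equal values, different object

# Equivalent identity check that does not rely on Tensor.__hash__.
ids = {id(t) for t in l}
print(id(a) in ids)               # True
print(id(torch.zeros(2)) in ids)  # False
```

Building set(l) costs O(len(l)), so when testing many tensors against the same collection it makes sense to build the set once. Also, id() values are only unique among live objects, which is fine here because l keeps its tensors alive.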
