How to check whether a Tensor is in a list?

Here is an example of what I want to do:

import torch

a = torch.randn(3)
b = torch.randn(3)
c = [a]
if a in c:
    pass  # do something
# ...
if b in c:  # this line raises the error below
    pass  # do something else

But this code raises an error:

RuntimeError: bool value of non-empty torch.ByteTensor objects is ambiguous

So how can I check whether a Tensor is in a list?

Some quick testing on version 0.4.0a0+a3d08de …

>>> import torch
>>> a = torch.randn(3)
>>> b = torch.randn(3)
>>> a in [a]
True
>>> b in [a]
RuntimeError: bool value of Tensor with more than one value is ambiguous
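The difference comes from how Python evaluates `in` on a list: according to the Python language reference, `x in y` for a list is equivalent to `any(x is e or x == e for e in y)`. The identity check short-circuits for `a in [a]`, so `a == a` is never converted to a bool; for `b in [a]`, `b == a` returns an elementwise tensor whose truth value is ambiguous. The sketch below reproduces this without PyTorch, using hypothetical `FakeTensor` and `AmbiguousResult` classes that only mimic that behaviour.

```python
class AmbiguousResult:
    """Hypothetical stand-in for an elementwise comparison result with several values."""

    def __bool__(self):
        raise RuntimeError("bool value of Tensor with more than one value is ambiguous")


class FakeTensor:
    """Hypothetical stand-in for a multi-element tensor: == is elementwise."""

    def __eq__(self, other):
        return AmbiguousResult()


def list_contains(x, seq):
    # What `x in seq` does for a list: identity check first, then equality
    return any(x is e or x == e for e in seq)


a = FakeTensor()
b = FakeTensor()

print(a in [a])               # True: `a is a` short-circuits, == is never evaluated
print(list_contains(a, [a]))  # True, for the same reason

try:
    b in [a]  # `b is a` is False, so bool(b == a) is attempted and raises
except RuntimeError as err:
    print("RuntimeError:", err)
```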

You could try the following code:

a = torch.randn(3)
b = torch.randn(3)
c = [a, b]
d = [torch.randn(3), torch.randn(3)]

if any([(a == c_).all() for c_ in c]):
    print('a in c')
    
if any([(a == d_).all() for d_ in d]):
    print('a in d')

This code iterates over the entries of the lists c and d and compares each entry to the Tensor you would like to check.
Since (a == c_) returns an elementwise result (one boolean per value), we call .all() on it to reduce it to a single boolean, and finally use any() to check whether any entry of the list gives a positive result.
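To make the two reduction steps concrete, here is a small sketch with hand-picked values (the tensors `a`, `x`, and `y` are only illustrative):

```python
import torch

a = torch.tensor([1, 2, 3])
x = torch.tensor([1, 2, 4])
y = torch.tensor([1, 2, 3])

# Elementwise comparison: one boolean per position (True, True, False here)
print(a == x)
# .all() reduces that to a single 0-dim boolean tensor, which bool() accepts
print((a == x).all())  # tensor(False)
print((a == y).all())  # tensor(True)

# any() over the per-entry results says whether a matches some entry
print(any((a == e).all() for e in [x, y]))  # True
```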

EDIT: This approach needs Tensors of the same shape, which might not be useful in some use cases.

EDIT2: This might be a workaround:

d = [torch.randn(5), torch.randn(3)]
if any([(a == d_).all() for d_ in d if a.shape == d_.shape]):
    print('a in d')
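Another option, sketched below, is `torch.equal`, which returns `True` only if both tensors have the same size and the same elements. Entries with a different shape are then simply treated as non-matching instead of raising. The helper name `contains_equal` is hypothetical.

```python
import torch


def contains_equal(t, tensors):
    """Return True if some entry of `tensors` has the same shape and values as `t`.

    A generator is used so the search stops at the first match.
    """
    return any(torch.equal(t, other) for other in tensors)


a = torch.tensor([1.0, 2.0, 3.0])
d = [torch.zeros(5), a.clone(), torch.ones(3)]

print(contains_equal(a, d))                 # True: a.clone() has the same values
print(contains_equal(a, [torch.zeros(5)]))  # False: different shape, no error
```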

Hi,

If you want to check whether a tensor with the same content is in the list, then @ptrblck's solution looks right.
If you want to check whether the same tensor object is in the list, you can use:

import torch

a = torch.randn(5)
d = [a, torch.randn(3)]
if any(a is d_ for d_ in d):
    print('a in d')
# a will be in d here

d = [a.clone(), torch.randn(3)]
if any(a is d_ for d_ in d):
    print('a in d')
# a will NOT be in d here
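When many identity lookups are needed against the same list, one possible variation is to collect the `id()` of every entry into a set once, so that each later check is a constant-time set lookup. This is only a sketch; it relies on the list keeping its tensors alive, because `id()` values can be reused after an object is garbage-collected.

```python
import torch

a = torch.randn(5)
b = torch.randn(3)
d = [a, torch.randn(3)]

# Build the set of object identities once
ids_in_d = {id(t) for t in d}

print(id(a) in ids_in_d)          # True: a itself is stored in d
print(id(a.clone()) in ids_in_d)  # False: a clone is a different object
print(id(b) in ids_in_d)          # False
```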

Thanks for your replies. I know how to solve this with a for loop, but I wonder whether there is another, more efficient way that doesn't use a for loop.
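If all tensors in the list share the same shape, one way to avoid a Python-level loop over comparisons is to stack the list into a single tensor once and let broadcasting compare `a` against every entry at the same time. A minimal sketch, assuming equal shapes (`torch.stack` raises otherwise); the values are only illustrative:

```python
import torch

a = torch.tensor([0.5, -1.0, 2.0])
c = [torch.zeros(3), a.clone(), torch.ones(3)]

# Shape (len(c), 3); build it once and reuse it for many queries
stacked = torch.stack(c)

# Broadcast a against every row, require all elements of a row to match,
# then check whether any row matched
matches = (stacked == a).flatten(start_dim=1).all(dim=1)
print(matches)              # one boolean per list entry: False, True, False
print(bool(matches.any()))  # True: the clone has the same values
```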

While this is an ugly hack, and isn’t guaranteed to work, it might be quite efficient.

try:
    a in c
    # do stuff assuming a is in c
except RuntimeError:
    # do stuff assuming a is not in c
    pass

Besides, isn't it a bug that PyTorch throws an error only when a is not in c?
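This behaviour follows from Python's `in` rule (identity check first, then `==`) combined with multi-element tensors refusing to convert to a bool, and the error does not really mean "not in the list": the result depends on the order of the entries. A small sketch showing that `a in c` can raise even when `a` is in `c`, which also means the `try`/`except` hack above can misreport membership:

```python
import torch

a = torch.randn(3)
b = torch.randn(3)

print(a in [a, b])  # True: the first entry is a itself, so == is never called

try:
    print(a in [b, a])  # b is checked first with ==, whose result can't become a bool
except RuntimeError as err:
    print("RuntimeError even though a is in the list:", err)
```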

How about this approach? It checks, for every element of a tensor, whether its value is in a list:

import torch

setensor = torch.randint(0, 10, size=(5, 20))
selist = [1, 2, 4, 6]
print(setensor)
# For each value in selist, build a boolean mask of matching positions,
# stack the masks, and reduce with any() over the stacked dimension
semasks = torch.any(torch.stack([torch.eq(setensor, aelem) for aelem in selist], dim=0), dim=0)
print(semasks)
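Recent PyTorch releases (1.10 and later, as far as I know) also provide `torch.isin(elements, test_elements)`, which should compute the same kind of elementwise mask directly, without building one comparison per list value. A sketch, assuming such a version is installed:

```python
import torch

setensor = torch.randint(0, 10, size=(5, 20))
selist = [1, 2, 4, 6]

# True wherever an element of setensor equals one of the values in selist
mask = torch.isin(setensor, torch.tensor(selist))

# Should match the stack/eq/any approach
reference = torch.stack([setensor == v for v in selist]).any(dim=0)
print(torch.equal(mask, reference))  # True
```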

Not sure how you came up with this, but you saved me a lot of headache! Thanks.