How to check whether a Tensor is in a list?

You could try the following code:

import torch

a = torch.randn(3)
b = torch.randn(3)
c = [a, b]                            # contains a itself
d = [torch.randn(3), torch.randn(3)]  # unrelated random tensors

if any((a == c_).all() for c_ in c):
    print('a in c')

if any((a == d_).all() for d_ in d):
    print('a in d')

This code iterates over the entries of the lists c and d and compares each entry to the tensor you would like to look for.
Since (a == c_) returns an element-wise result (one boolean per value), we call .all() on it to check that every value matches, and finally use any() to check whether at least one entry of the list matched.
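The same check can be wrapped in a small helper. This is only a sketch: the name tensor_in_list is hypothetical, and it assumes every tensor in the list has the same shape as the one being searched for. Fixed values are used instead of torch.randn so the result is predictable.

```python
import torch

def tensor_in_list(t, tensors):
    # Hypothetical helper: True if any tensor in `tensors` equals `t` element-wise.
    # Assumes every tensor in `tensors` has the same shape as `t`.
    # A generator (not a list) lets any() stop at the first match.
    return any(bool((t == other).all()) for other in tensors)

a = torch.tensor([1.0, 2.0, 3.0])
c = [a, torch.tensor([4.0, 5.0, 6.0])]
d = [torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 2.0, 4.0])]

print(tensor_in_list(a, c))  # True
print(tensor_in_list(a, d))  # False
```

Note that this compares values, not object identity, so a clone of a (e.g. a.clone()) also counts as a match, while a tensor containing NaN never matches because NaN != NaN.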

EDIT: This approach requires tensors of the same shape: comparing a with a tensor of a different (non-broadcastable) shape raises a RuntimeError, so it might not be useful in some use cases.

EDIT2: This might be a workaround, which only compares entries that have the same shape as a:

d = [torch.randn(5), torch.randn(3)]
if any([(a == d_).all() for d_ in d if a.shape == d_.shape]):
    print('a in d')
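Alternatively, torch.equal(input, other) returns True only if both tensors have the same size and the same elements, and simply returns False on a shape mismatch instead of raising, so the explicit shape check is not needed. A minimal sketch, again with a hypothetical helper name:

```python
import torch

def tensor_in_list_any_shape(t, tensors):
    # Hypothetical helper: torch.equal returns False for tensors of different
    # sizes, so entries with a different shape are skipped without an error.
    return any(torch.equal(t, other) for other in tensors)

a = torch.tensor([1.0, 2.0, 3.0])
d = [torch.randn(5), torch.tensor([1.0, 2.0, 3.0])]

print(tensor_in_list_any_shape(a, d))  # True
```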