Check if tensor elements are in a value list

Hi, is there any way to see if elements in a tensor are in a value list?
For example, given a = torch.tensor([1, 2, 3, 4, 5]), I want to check whether each element of a is in the value list [1, 2, 3].
Here elements 4 and 5 are not in [1, 2, 3], while elements 1, 2 and 3 are.

This can be written as (a == 1) | (a == 2) | (a == 3) (Python's `or` does not work element-wise on tensors, but `|` does). However, I don't want to write out or iterate over every value in the list [1, 2, 3]. Is there a function that does this directly?
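Since PyTorch 1.10 there is a built-in for exactly this: `torch.isin(elements, test_elements)`. A minimal sketch, assuming a PyTorch version that includes it; `values` is just a name chosen here for the value list, passed as a tensor:

```python
import torch

a = torch.tensor([1, 2, 3, 4, 5])
values = torch.tensor([1, 2, 3])

# torch.isin returns a bool tensor with the same shape as `a`:
# True where the element of `a` appears anywhere in `values`
mask = torch.isin(a, values)
print(mask)  # tensor([ True,  True,  True, False, False])

# invert=True gives the complement (elements NOT in `values`)
not_in = torch.isin(a, values, invert=True)
print(not_in)  # tensor([False, False, False,  True,  True])
```

The mask can be used directly for indexing, e.g. `a[mask]` keeps only the matching elements.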


I found this 2018 response from @ptrblck that could be a solution: "How to judge a Tensor is in a list?"
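Independent of what exactly that thread proposes, one loop-free approach that also works on PyTorch versions without `torch.isin` is broadcasting: compare every element of `a` with every value at once, then reduce with `any`. A sketch, assuming `a` and `values` are both tensors on the same device (`values` is a name chosen here for illustration):

```python
import torch

a = torch.tensor([1, 2, 3, 4, 5])
values = torch.tensor([1, 2, 3])

# a.unsqueeze(-1) has shape (5, 1); comparing it with values of shape (3,)
# broadcasts to a (5, 3) table where entry [i, j] is a[i] == values[j]
table = a.unsqueeze(-1) == values

# An element is "in the list" if any entry in its row is True
mask = table.any(dim=-1)
print(mask)  # tensor([ True,  True,  True, False, False])
```

Because `unsqueeze(-1)` only adds a trailing dimension, the same code works for `a` of any shape. The trade-off is memory: the intermediate table has `a.numel() * len(values)` entries.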


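You can do this in one line with a generator expression (it still loops over the values in Python):

>>> import torch
>>> a = torch.tensor([1, 2, 3, 4, 5])
>>> vals = [1, 2, 3]
>>> sum(a == i for i in vals).bool()
tensor([ True,  True,  True, False, False])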
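The `sum` trick works because adding bool tensors promotes them to integers, which `.bool()` then converts back. A variant that stays boolean throughout is to OR the masks together with `functools.reduce` and `torch.logical_or`; this is a swapped-in alternative, not the poster's code, and it assumes `vals` is non-empty (`reduce` without an initial value raises `TypeError` on an empty sequence):

```python
import functools
import torch

a = torch.tensor([1, 2, 3, 4, 5])
vals = [1, 2, 3]

# Build one bool mask per value, then OR them pairwise:
# ((a == 1) | (a == 2)) | (a == 3)
mask = functools.reduce(torch.logical_or, (a == v for v in vals))
print(mask)  # tensor([ True,  True,  True, False, False])
```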

You can also do it this way:

import torch

setensor = torch.randint(0, 10, size=(5, 20))
selist = [1, 2, 4, 6]
print(setensor)
# One bool mask per value, stacked along a new dim 0, then OR-ed together with any()
semasks = torch.any(torch.stack([torch.eq(setensor, aelem) for aelem in selist], dim=0), dim=0)
print(semasks)

You helped me a lot!
Clever solution!