Check if tensor elements are in a value list

Hi, is there any way to check whether the elements of a tensor are in a list of values?
E.g., given a tensor a = torch.tensor([1, 2, 3, 4, 5]), I want to check whether each element of a is in the value list [1, 2, 3].
Here, elements 4 and 5 are not in the value list [1, 2, 3], while elements 1, 2, and 3 are.

This can be written as (a == 1) | (a == 2) | (a == 3), but I would like to do it without iterating over the values in the list [1, 2, 3]. Is there a function that does that?

I found this response from @ptrblck in 2018 that could be a solution: How to judge a Tensor is in a list?

You can achieve this with a for loop:

>>> sum(a==i for i in vals).bool()
tensor([[ True,  True, False],
        [ True,  True, False]])
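
For reference, a self-contained version of the loop above might look like the sketch below. The values of a and vals are assumptions picked so that the result matches the output shown; the post itself does not say how they were defined.

```python
import torch

# Assumed example inputs; the post above does not define a or vals.
a = torch.tensor([[1, 2, 4],
                  [3, 1, 5]])
vals = [1, 2, 3]

# Each (a == i) is a boolean mask. Summing the masks counts the matches per
# element, and .bool() turns any nonzero count into True.
mask = sum(a == i for i in vals).bool()
print(mask)
```

The loop still runs once per value in vals, so this stays cheap only while the list is short.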

You can also achieve it this way:

import torch

setensor = torch.randint(0, 10, size=(5, 20))
selist = [1, 2, 4, 6]
print(setensor)
# Build one boolean mask per list value, stack them, and OR across the stack.
semasks = torch.any(torch.stack([torch.eq(setensor, aelem) for aelem in selist], dim=0), dim=0)
print(semasks)

You helped me a lot!
Clever solution!

You can also achieve this without an explicit for loop:

import torch

t = torch.randint(0, 10, size=(10, 10))
print(t)
search = torch.tensor([0, 5, 9])
# Pairwise differences between every search value and every element of t; a zero marks a match.
mask = ((search.view(-1, 1) - t.view(-1)).transpose(-1, -2) == 0).sum(dim=-1).view(t.shape) != 0
print(mask.int())  # 1s and 0s for readability =)
print(t[mask].unique())  # the search values that actually occur in t

We subtract every value in t from every value in search; wherever the difference is 0, the two values are equal.
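
The same broadcasting idea can also be written as a direct equality comparison instead of a subtraction. Here is a minimal sketch that reuses the names t and search from the post above; the final check assumes PyTorch 1.10 or later, where torch.isin is available.

```python
import torch

t = torch.randint(0, 10, size=(10, 10))
search = torch.tensor([0, 5, 9])

# t.unsqueeze(-1) has shape (10, 10, 1). Comparing it with search (shape (3,))
# broadcasts to (10, 10, 3): one boolean per (element, search value) pair.
mask = (t.unsqueeze(-1) == search).any(dim=-1)

# torch.isin (assumed available, PyTorch >= 1.10) does the membership test directly.
builtin_mask = torch.isin(t, search)
print(torch.equal(mask, builtin_mask))
```

Note that the broadcast intermediate holds t.numel() * len(search) booleans, so for long value lists torch.isin is probably the lighter option.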
