Hi, is there any way to see if elements in a tensor are in a value list?
For example, given a tensor a = tensor([1, 2, 3, 4, 5]), I want to check whether each element of a is in the value list [1, 2, 3].
Here, elements 1, 2, and 3 are in the value list, while elements 4 and 5 are not.
This can be written as (a == 1) | (a == 2) | (a == 3) (element-wise |, since Python's `or` raises an error on multi-element tensors). However, sometimes I don't want to iterate over the values in [1, 2, 3]. Is there a function that does this without iterating over the value list?
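A minimal sketch, assuming PyTorch 1.10 or newer, where torch.isin is available: it tests every element of a against the value list in a single call, with no Python loop over the values. The value list is passed as a tensor, and invert=True returns the complement (elements not in the list).

```python
import torch

a = torch.tensor([1, 2, 3, 4, 5])
vals = torch.tensor([1, 2, 3])

# One call checks membership of every element of `a` in `vals` (PyTorch >= 1.10).
mask = torch.isin(a, vals)
print(mask.tolist())  # [True, True, True, False, False]

# invert=True flips the test: True where the element is NOT in `vals`.
not_in = torch.isin(a, vals, invert=True)
print(not_in.tolist())  # [False, False, False, True, True]
```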
I found this response from @ptrblck in 2018 that could be a solution: How to judge a Tensor is in a list?
You can achieve this with a for loop:
>>> sum(a == i for i in vals).bool()
tensor([[ True,  True, False],
        [ True,  True, False]])
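The sum-of-comparisons trick above still performs one comparison per value in a Python loop. A loop-free alternative, sketched here under the assumption that memory for one extra axis of size len(vals) is acceptable, uses broadcasting: a.unsqueeze(-1) == vals compares every element with every value at once, and .any(dim=-1) collapses that axis. The 2-D input below is made up for illustration so that the result has the same shape as the output shown above.

```python
import torch

# Hypothetical 2-D input, chosen so the mask matches the shape of the output above.
a = torch.tensor([[1, 2, 5],
                  [3, 1, 7]])
vals = torch.tensor([1, 2, 3])

# Loop form from the reply: one comparison per value, summed, then cast to bool.
loop_mask = sum(a == v for v in vals).bool()

# Loop-free form: add a trailing axis so `a` broadcasts against `vals`,
# giving shape (*a.shape, len(vals)), then reduce over that last axis.
bcast_mask = (a.unsqueeze(-1) == vals).any(dim=-1)

print(bcast_mask.tolist())  # [[True, True, False], [True, True, False]]
```

Broadcasting allocates a temporary of shape (*a.shape, len(vals)), so for a long value list torch.isin may be the leaner choice.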
You can also achieve this by stacking one mask per value and reducing with torch.any:
import torch
setensor = torch.randint(0, 10, size=(5, 20))
selist = [1, 2, 4, 6]
# One boolean mask per value, stacked along a new dim 0, then OR-ed together over that dim
semasks = torch.any(torch.stack([torch.eq(setensor, aelem) for aelem in selist], dim=0), dim=0)
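When the tensor's values are known to be small non-negative integers, as with torch.randint(0, 10, ...) here, a boolean lookup table is another option. This is a swapped-in technique, not one from the thread: the sketch assumes every value lies in [0, 10), marks the entries listed in selist as True, and indexes the table with setensor so each element needs a single lookup. The fixed seed is only for reproducibility, and the final assert checks the result against the stack-based mask.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the example is reproducible
setensor = torch.randint(0, 10, size=(5, 20))
selist = [1, 2, 4, 6]

# Stack-based mask from the reply: shape (len(selist), 5, 20), then any over dim 0.
semasks = torch.any(torch.stack([torch.eq(setensor, v) for v in selist], dim=0), dim=0)

# Assumption: all values lie in [0, 10), so a length-10 boolean table covers them.
lut = torch.zeros(10, dtype=torch.bool)
lut[torch.tensor(selist)] = True  # mark the values we are looking for
lut_masks = lut[setensor]         # advanced indexing: one lookup per element, shape (5, 20)

assert torch.equal(semasks, lut_masks)
```

The table only works when the value range is bounded and small; for arbitrary values, the stacked, broadcast, or torch.isin approaches apply.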
You helped me a lot!