Fast way to check the elements in a tensor

Let’s say I have a tensor, torch.tensor([1, 2, 3]), and a set {2, 3}. I want to check whether each element of the tensor is in the set and return a tensor like torch.tensor([0, 1, 1]), since 2 and 3 are in the set.

Is there a fast way to do this, or do I have to write a for loop?

If memory is not a constraint, you can use a broadcasting trick. It compares every element of x against every value in y, so it builds a len(y) × len(x) comparison matrix:

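import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 3])  # the set {2, 3} stored as a tensor

# (1, 3) vs (2, 1) broadcasts to a (2, 3) boolean matrix; summing over dim 0 gives tensor([0, 1, 1])
x.view(1, -1).eq(y.view(-1, 1)).sum(0)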
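A small caveat and an alternative, sketched below under a few assumptions. With .sum(0), a value that appears twice in y would be counted twice, so the result is only guaranteed to be 0/1 when y has no duplicates; .any(0) avoids that. Newer PyTorch releases (1.10 and later, as far as I know) also ship torch.isin, which does the membership test without you building the comparison matrix yourself. The function names in_set_broadcast and in_set_isin are hypothetical helpers for illustration, and converting the Python set with sorted() is just one convenient choice.

```python
import torch


def in_set_broadcast(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: (1, N) vs (M, 1) broadcasts to an (M, N) boolean
    # matrix, so memory grows with len(x) * len(values).
    matches = x.view(1, -1).eq(values.view(-1, 1))
    # any(dim=0) collapses the candidate axis; unlike sum(0) it stays 0/1
    # even when `values` contains duplicates.
    return matches.any(dim=0).long()


def in_set_isin(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: torch.isin returns a boolean mask of the same
    # shape as x; .long() turns it into 0/1 integers.
    return torch.isin(x, values).long()


x = torch.tensor([1, 2, 3])
allowed = {2, 3}
values = torch.tensor(sorted(allowed))  # Python set -> 1-D tensor

print(in_set_broadcast(x, values))  # tensor([0, 1, 1])
print(in_set_isin(x, values))       # tensor([0, 1, 1])
```

If you are on an older PyTorch without torch.isin, the broadcasting version with .any(0) is a drop-in replacement for the .sum(0) one-liner.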
