Fast method for checking for inclusion

Hey,

I want to check which elements of one tensor are contained in another tensor.

I have come up with the following two ways to do this. First, reshape the tensors so that comparing them broadcasts over every pair of elements:

x = x.view(1, -1)
y = y.view(-1, 1)

Then one way would be to just do this:

return (x == y).nonzero().size(0)

However, if the y tensor is large, it is more cache-efficient to chunk it like this:

chunks = y.chunk(num_chunks)

And then loop over the chunks:

count = 0
for chunk in chunks:
    count += (x == chunk).sum()
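For reference, here are the two approaches above assembled into self-contained functions. This is a minimal sketch assuming x and y are 1-D tensors on the same device; the function names and the default num_chunks value are illustrative, not from the original post.

```python
import torch

def count_matches_broadcast(x: torch.Tensor, y: torch.Tensor) -> int:
    # Row vector vs. column vector: the comparison broadcasts to a
    # (len(y), len(x)) boolean matrix, one entry per (y[j], x[i]) pair.
    eq = x.view(1, -1) == y.view(-1, 1)
    return eq.nonzero().size(0)

def count_matches_chunked(x: torch.Tensor, y: torch.Tensor, num_chunks: int = 8) -> int:
    x = x.view(1, -1)
    y = y.view(-1, 1)
    count = 0
    # Each iteration only materializes a (chunk_len, len(x)) matrix.
    for chunk in y.chunk(num_chunks):
        count += int((x == chunk).sum())
    return count

x = torch.tensor([1, 2, 2, 5])
y = torch.tensor([2, 2, 3, 1, 7])
print(count_matches_broadcast(x, y), count_matches_chunked(x, y, num_chunks=2))  # prints 5 5
```

Both count every matching (i, j) pair, so a value that appears twice in x and twice in y contributes four matches.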

Both of these methods are too slow for what I want to do. For 1000 to 100000 elements, the faster of the two takes about 0.14 s, and I need it to be under 0.01 s. Is there another way to do this? I suspect the == operator is inefficient because it has to materialize a large intermediate array.
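One alternative not raised in this thread, offered only as a sketch: sort y once, then use torch.searchsorted (available in PyTorch 1.6 and later) to find, for each value of x, the width of its equal range in the sorted array. This avoids materializing the len(y) × len(x) boolean matrix and still counts pairs with multiplicity, like the == approach. It assumes x and y share a dtype and contain no NaNs; the function name is hypothetical, and whether it meets the 0.01 s target has not been measured here.

```python
import torch

def count_matches_sorted(x: torch.Tensor, y: torch.Tensor) -> int:
    """Count pairs (i, j) with x[i] == y[j] without building a |y| x |x| matrix.

    Assumes x and y share a dtype and contain no NaNs.
    """
    y_sorted, _ = torch.sort(y.flatten())
    xf = x.flatten()
    # Binary search on y_sorted: left is the first index with y_sorted >= v,
    # right is the first index with y_sorted > v, so right - left is the
    # number of entries of y equal to v.
    left = torch.searchsorted(y_sorted, xf)
    right = torch.searchsorted(y_sorted, xf, right=True)
    return int((right - left).sum())

x = torch.tensor([1, 2, 2, 5])
y = torch.tensor([2, 2, 3, 1, 7])
print(count_matches_sorted(x, y))  # prints 5
```

The cost is one sort of y plus two binary searches per element of x, instead of one comparison per (x, y) pair.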

Thank you very much!

Resort to NumPy, maybe? (intersect1d)

That would make me lose the GPU advantage 🙂

What advantage? AFAIK, PyTorch does GPU binary search with ops like sort, topk, unique, and median, which are not helpful for your task. And you say that brute force is too slow.

Well, unique() can tell you the overlap size: (num_unique(a) + num_unique(b)) - num_unique(cat(a, b)). The sorting this requires may still be too heavy, though.
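A minimal sketch of that unique() formula, assuming 1-D tensors a and b of the same dtype; num_unique and overlap_size are hypothetical helper names. Note that it counts distinct shared values, not matches with multiplicity.

```python
import torch

def num_unique(t: torch.Tensor) -> int:
    # torch.unique returns the distinct values of t.
    return torch.unique(t).numel()

def overlap_size(a: torch.Tensor, b: torch.Tensor) -> int:
    # Inclusion-exclusion: a value present in both tensors is counted once in
    # num_unique(a) and once in num_unique(b), but only once in the union.
    return num_unique(a) + num_unique(b) - num_unique(torch.cat([a, b]))

a = torch.tensor([1, 2, 2, 3])
b = torch.tensor([2, 3, 3, 4])
print(overlap_size(a, b))  # prints 2 (the shared values are 2 and 3)
```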

If I use NumPy, then I won’t be able to run my code on the GPU; that’s what I meant. Also, intersect1d is not exactly what I want: when an element has multiple matches, I want all of them returned, not just one.
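To illustrate the distinction: np.intersect1d returns the sorted unique values common to both arrays, whereas np.isin returns a boolean mask over its first argument, so indexing with that mask keeps every occurrence in x. Whether this is exactly the behavior wanted here is an assumption.

```python
import numpy as np

x = np.array([1, 2, 2, 5, 2])
y = np.array([2, 3, 1, 7])

common = np.intersect1d(x, y)  # sorted unique values present in both arrays
mask = np.isin(x, y)           # True wherever x[i] occurs somewhere in y
print(common)   # prints [1 2]
print(x[mask])  # prints [1 2 2 2]
```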

You’d just have to temporarily copy the tensor(s) to the CPU, something like:

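import torch

npx = x.detach().cpu().numpy()                 # copy to host memory as a NumPy array
npy = numpy_func(npx)                          # numpy_func: placeholder for the NumPy routine you need
y = torch.from_numpy(npy).to(device=x.device)  # move the result back to x's device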

The overhead of this may be acceptable; YMMV.

Check NumPy's other set and sorting functions as well.
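A minimal sketch of that CPU round trip, using np.isin as the NumPy function; the helper name isin_via_numpy is hypothetical. The mask is moved back to the device of x, so the rest of the code can stay on the GPU, and the two host/device copies are the overhead in question.

```python
import numpy as np
import torch

def isin_via_numpy(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Copy both tensors to host memory as NumPy arrays.
    npx = x.detach().cpu().numpy()
    npy = y.detach().cpu().numpy()
    # np.isin gives a boolean mask shaped like x: True where x[i] occurs in y.
    mask = np.isin(npx, npy)
    # Convert back to a torch.bool tensor on the device x came from.
    return torch.from_numpy(mask).to(device=x.device)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([1, 2, 2, 5, 2], device=device)
y = torch.tensor([2, 3, 1, 7], device=device)
mask = isin_via_numpy(x, y)
print(x[mask].tolist())  # prints [1, 2, 2, 2]
```

In PyTorch 1.10 and later, torch.isin computes the same mask without leaving the device.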