Find indexes of elements from one tensor that match in another tensor

Hi!

I have an unordered tensor A of shape (batch, N) whose values are unique within each row, and another tensor B of shape (batch, n), also unique within each row, where N >= n.

I am trying to find, for each batch element (i.e. row-wise), the indexes in A of the values that also appear in B.

Example

a = torch.tensor([[0,1,2,3,4,5,6,7,8,9],[6,11,1,3,14,15,9,17,18,19]])

b = torch.tensor([[2,4,6,7],[3,9,15,19]])

# weird black magic

>>  tensor([[2,4,6,7],[3,6,5,9]])

I am trying to avoid for loops because I need this to be computationally efficient (aka fast).

Thanks a lot!!!

It’s possible, but you would need to:

  • compare each value of a against b via broadcasting, which could be faster than a loop but has to store the intermediate (batch, N, n) result and so needs more memory
  • make sure that the number of matches for each “row” is equal, so that a single output tensor can be created
  • undo the lexicographical sort created by nonzero
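
To make the first two points concrete, here is a minimal sketch using the tensors from the question. The names mask, mask_shape and matches_per_row are only illustrative and not part of any API.

```python
import torch

a = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                  [6, 11, 1, 3, 14, 15, 9, 17, 18, 19]])
b = torch.tensor([[2, 4, 6, 7], [3, 9, 15, 19]])

# (batch, N, 1) == (batch, 1, n) broadcasts to a (batch, N, n) boolean mask;
# this mask is the intermediate result that costs the extra memory
mask = a.unsqueeze(2) == b.unsqueeze(1)
mask_shape = tuple(mask.shape)

# number of matches in each row; all rows have to agree before the
# indices can be packed into one rectangular output tensor
matches_per_row = mask.sum(dim=(1, 2))

print(mask_shape)       # (2, 10, 4)
print(matches_per_row)  # tensor([4, 4])
```

The mask grows with batch * N * n, so for large inputs it may be worth processing the batch in chunks.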

Here is a code snippet showing the approach.
Note the comments starting with !! which show why you would need to undo the sort:

a = torch.tensor([[0,1,2,3,4,5,6,7,8,9],[6,11,1,3,14,15,9,17,18,19]])
b = torch.tensor([[2,4,6,7],[3,9,15,19]])

idx = a.unsqueeze(2) == b.unsqueeze(1)
idx = idx.nonzero() # will sort lexicographically!
print(idx)
# tensor([[0, 2, 0],
#         [0, 4, 1],
#         [0, 6, 2],
#         [0, 7, 3],
#         [1, 3, 0],
#         [1, 5, 2], # !! due to sort !!
#         [1, 6, 1], # !! due to sort !!
#         [1, 9, 3]])

idx_ = idx[:, :2]

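# check if number of matches is equal for each "row"
matches_len = idx[:,0].unique(return_counts=True)[1]
if not (matches_len == matches_len[0]).all():
    # without this, output would be undefined below
    raise ValueError("each row needs the same number of matches")
output = idx[:, 1].contiguous().view(-1, matches_len[0].item())

print(output)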
# tensor([[2, 4, 6, 7],
#         [3, 5, 6, 9]]) # !! 5 and 6 are sorted !!

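# undo sort from nonzero: idx[:, 2] holds each match's position in b, so
# scatter the values back there (indexing with idx[:, 2] directly only
# works when that permutation is its own inverse, as in this example)
order = idx[:, 2].view_as(output)
output = torch.empty_like(output).scatter_(1, order, output)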
print(output)
# tensor([[2, 4, 6, 7],
#         [3, 6, 5, 9]])
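
For checking a vectorized result on small inputs, a plain loop can serve as a reference. This is only a sketch: match_indices_loop is a hypothetical helper name, and it assumes every value of b occurs in the same row of a and that the values in each row of a are unique.

```python
import torch

def match_indices_loop(a, b):
    # hypothetical reference helper: for every row, look up the position
    # in a of each value of b, keeping the order of b
    rows = []
    for a_row, b_row in zip(a.tolist(), b.tolist()):
        position = {value: i for i, value in enumerate(a_row)}
        rows.append([position[value] for value in b_row])
    return torch.tensor(rows)

a = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                  [6, 11, 1, 3, 14, 15, 9, 17, 18, 19]])
b = torch.tensor([[2, 4, 6, 7], [3, 9, 15, 19]])

expected = match_indices_loop(a, b)
print(expected)
# tensor([[2, 4, 6, 7],
#         [3, 6, 5, 9]])
```

It is slow for large tensors, which is exactly what the question wants to avoid, but it is handy for comparing against the broadcasting version.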

Hi @ptrblck

Just a quick note:

Unsqueezing “in the other direction,” so to speak, avoids the need to
undo the sort:

>>> (a.unsqueeze (1) == b.unsqueeze (2)).nonzero()
tensor([[0, 0, 2],
        [0, 1, 4],
        [0, 2, 6],
        [0, 3, 7],
        [1, 0, 3],
        [1, 1, 6],
        [1, 2, 5],
        [1, 3, 9]])

Best.

K. Frank

Oh, that’s neat! I was just thinking about how to get rid of it, and now you’ve solved it. Thanks for sharing :)

Thanks a lot, guys. It worked exactly as I expected. But this unsqueeze trick still feels like black magic to me :)

In the end, I implemented the solution of @ptrblck with the suggestion of @KFrank, and it works like a charm.
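
For anyone reading later, the combination might look roughly like the sketch below. It is not necessarily the exact code used here: match_indices is a hypothetical helper name, and it assumes every value of b occurs exactly once in the same row of a. The comments also spell out what the unsqueeze calls do to the shapes.

```python
import torch

def match_indices(a, b):
    # hypothetical helper, assuming each value of b occurs exactly once
    # in the corresponding row of a
    # a.unsqueeze(1): (batch, 1, N), b.unsqueeze(2): (batch, n, 1)
    # broadcasting compares every b value with every a value: (batch, n, N)
    mask = a.unsqueeze(1) == b.unsqueeze(2)
    # nonzero lists the True entries sorted by (batch, position in b,
    # position in a), so the last column already follows the order of b
    idx = mask.nonzero()
    if idx.size(0) != b.numel():
        raise ValueError("expected exactly one match in a for every value of b")
    return idx[:, 2].reshape(b.shape)

a = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                  [6, 11, 1, 3, 14, 15, 9, 17, 18, 19]])
b = torch.tensor([[2, 4, 6, 7], [3, 9, 15, 19]])

result = match_indices(a, b)
print(result)
# tensor([[2, 4, 6, 7],
#         [3, 6, 5, 9]])
```

Because the position in b is the second sort key, no reordering step is needed afterwards.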

See ya,