It’s possible, but you would need to:
- compare each value of `a` against `b` via broadcasting, which could be faster than looping, but it materializes an intermediate boolean tensor of shape `[batch, len(a), len(b)]`, so it needs more memory
- make sure that the number of matches is equal for each “row”, so that the results fit into a single output tensor
- undo the lexicographic sort applied by `nonzero`
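To put a number on the memory cost from the first point, here is a minimal sketch (the helper name `comparison_mask_bytes` and the tensor sizes are hypothetical) that builds the broadcast comparison and reports its size. A `torch.bool` element takes one byte, so the mask costs roughly `batch * len(a) * len(b)` bytes:

```python
import torch

def comparison_mask_bytes(a: torch.Tensor, b: torch.Tensor) -> tuple:
    """Hypothetical helper: build the broadcast equality mask, return its shape and size in bytes."""
    # a: [batch, len_a] -> [batch, len_a, 1]; b: [batch, len_b] -> [batch, 1, len_b]
    mask = a.unsqueeze(2) == b.unsqueeze(1)
    # The result is a dense [batch, len_a, len_b] bool tensor.
    return tuple(mask.shape), mask.numel() * mask.element_size()

a = torch.randint(0, 100, (32, 1000))
b = torch.randint(0, 100, (32, 50))
shape, nbytes = comparison_mask_bytes(a, b)
print(shape, nbytes)  # (32, 1000, 50) 1600000
```

The mask grows with the product of both lengths, which is the trade-off against a loop that only ever holds one comparison at a time.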
Here is a code snippet showing the approach. Note the comments starting with `!!`, which show why you would need to undo the sort:
```python
import torch

a = torch.tensor([[0,1,2,3,4,5,6,7,8,9],[6,11,1,3,14,15,9,17,18,19]])
b = torch.tensor([[2,4,6,7],[3,9,15,19]])

# [batch, len_a, 1] == [batch, 1, len_b] -> bool mask of shape [batch, len_a, len_b]
idx = a.unsqueeze(2) == b.unsqueeze(1)
# each row of idx is [batch index, index in a, index in b]
idx = idx.nonzero()  # will sort lexicographically!
print(idx)
# tensor([[0, 2, 0],
#         [0, 4, 1],
#         [0, 6, 2],
#         [0, 7, 3],
#         [1, 3, 0],
#         [1, 5, 2],  # !! due to sort !!
#         [1, 6, 1],  # !! due to sort !!
#         [1, 9, 3]])

# check if number of matches is equal for each "row"
matches_len = idx[:, 0].unique(return_counts=True)[1]
if (matches_len == matches_len[0]).all():
    output = idx[:, 1].contiguous().view(-1, int(matches_len[0]))
    print(output)
    # tensor([[2, 4, 6, 7],
    #         [3, 5, 6, 9]])  # !! 5 and 6 are sorted !!

    # undo sort from nonzero: idx[:, 2] holds the position in b of each entry,
    # so gathering with its argsort (the inverse permutation) restores b's order
    output = output.gather(1, idx[:, 2].view_as(output).argsort(dim=1))
    print(output)
    # tensor([[2, 4, 6, 7],
    #         [3, 6, 5, 9]])
```
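If every value of `b` is guaranteed to occur in the matching row of `a` (an extra assumption the snippet above does not need), you could skip `nonzero` and the undo step entirely. This is a swapped-in alternative, not the approach above: `argmax` over the `a` dimension of the mask returns, for each value of `b`, the index of its first match in `a`, already in `b`'s order. The helper name `match_indices` is hypothetical:

```python
import torch

def match_indices(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index in `a` of the first occurrence of each value of `b`, row by row."""
    mask = a.unsqueeze(2) == b.unsqueeze(1)  # [batch, len_a, len_b]
    # argmax would return 0 for an all-False column, so reject missing values explicitly.
    if not mask.any(dim=1).all():
        raise ValueError("some value of b does not occur in the corresponding row of a")
    # Cast to float because argmax does not accept bool; ties resolve to the first maximal index.
    return mask.float().argmax(dim=1)  # [batch, len_b], ordered like b

a = torch.tensor([[0,1,2,3,4,5,6,7,8,9],[6,11,1,3,14,15,9,17,18,19]])
b = torch.tensor([[2,4,6,7],[3,9,15,19]])
print(match_indices(a, b))
# tensor([[2, 4, 6, 7],
#         [3, 6, 5, 9]])
```

Unlike the `nonzero` version, this always yields exactly one index per value of `b`, so the equal-count check is not needed, but duplicates in `a` are silently reduced to their first occurrence.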