Find indices of one tensor in another

Hello,
I have two large tensors A and B, both of type LongTensor and dimension 1. Both tensors are the results of torch.unique, so their elements are unique, and B is a subset of A (every element of B appears in A). Is there an efficient way to find the indices of tensor B's elements in tensor A?

If the tensors are not large or if you have enough memory, you could probably use this broadcasting approach:

import torch

a = torch.arange(10)
b = torch.arange(2, 7)[torch.randperm(5)]
print((a.unsqueeze(1) == b).nonzero())
> tensor([[2, 1],
        [3, 3],
        [4, 4],
        [5, 0],
        [6, 2]])
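If you need the result ordered like B (one index into A per element of B), one possible sketch with the same broadcasting approach is to take the argmax over each column of the comparison matrix. It assumes every element of b occurs exactly once in a (true here, since both come from torch.unique); a column without any match would silently return 0. The fixed b below is just the permutation from the output above, used instead of randperm so the result is reproducible.

```python
import torch

a = torch.arange(10)
b = torch.tensor([5, 2, 6, 3, 4])  # fixed permutation instead of randperm, for reproducibility

# Equality matrix of shape (len(a), len(b)); column j has a single True where a == b[j]
match = a.unsqueeze(1) == b

# argmax over dim 0 gives, for each column, the row index of its (only) True entry
idx = match.int().argmax(dim=0)
print(idx)  # tensor([5, 2, 6, 3, 4])
```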

Otherwise (if the full comparison matrix doesn't fit in memory), I think you might need to use a loop (or a list comprehension):

for b_ in b:
    print((a == b_).nonzero())
> tensor([[5]])
tensor([[2]])
tensor([[6]])
tensor([[3]])
tensor([[4]])
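Since torch.unique returns sorted values by default (sorted=True), A is already sorted, so a binary search should also work and avoids both the loop and the large comparison matrix. A minimal sketch using torch.searchsorted (available in recent PyTorch releases, 1.6 and later as far as I know); the example tensors are made up for illustration:

```python
import torch

# torch.unique sorts its output by default, so A is in ascending order
A = torch.unique(torch.tensor([9, 2, 7, 2, 4, 11, 7]))  # values [2, 4, 7, 9, 11]
B = torch.unique(torch.tensor([11, 4, 7]))              # values [4, 7, 11]

# Binary search of each B element in A: no len(A) x len(B) matrix is created
idx = torch.searchsorted(A, B)
print(idx)  # tensor([1, 2, 4])
```

This relies on B being a subset of A; otherwise searchsorted just returns insertion points, which you would need to validate (e.g. clamp to len(A) - 1 and compare A[idx] with B).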

Thank you very much for your response. I have another question similar to this one (relaxing the assumption that the tensors are outputs of torch.unique). Imagine I have two tensors as follows:
x = torch.LongTensor([5, 3, 6, 8])
y = torch.LongTensor([2, 3, 4, 6, 8, 5, 7, 9])

I want a function that returns, for each element of x, its index in y. So the output of this function should be:
tensor([5, 1, 3, 4])
I tried to use the solution in this link, but it doesn't seem to work when the tensors have different lengths.
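For what it's worth, the broadcasting approach itself doesn't require equal lengths: comparing x.unsqueeze(1) with y.unsqueeze(0) gives a (len(x), len(y)) matrix. A sketch, assuming every element of x appears exactly once in y:

```python
import torch

x = torch.LongTensor([5, 3, 6, 8])
y = torch.LongTensor([2, 3, 4, 6, 8, 5, 7, 9])

# Compare every x element against every y element: shape (len(x), len(y)) = (4, 8)
match = x.unsqueeze(1) == y.unsqueeze(0)

# nonzero lists (row, col) pairs in row-major order, i.e. in the order of x;
# column 1 holds the position in y
idx = match.nonzero(as_tuple=False)[:, 1]
print(idx)  # tensor([5, 1, 3, 4])
```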

The broadcasting solution you gave above works perfectly; I am just looking for a more memory-efficient approach (with the same performance).
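One possibly more memory-efficient option (a sort plus binary search rather than broadcasting) is to sort y once, search for the x values with torch.searchsorted, and map the positions back through the sort permutation. This needs memory proportional to len(x) + len(y) instead of a len(x) x len(y) matrix. A sketch, assuming every element of x appears in y and y has no duplicates:

```python
import torch

x = torch.LongTensor([5, 3, 6, 8])
y = torch.LongTensor([2, 3, 4, 6, 8, 5, 7, 9])

# Sort y once; perm maps sorted positions back to original positions in y
y_sorted, perm = torch.sort(y)

# Binary-search each x element in the sorted copy of y
pos = torch.searchsorted(y_sorted, x)

# Translate positions in y_sorted back to positions in the original y
idx = perm[pos]
print(idx)  # tensor([5, 1, 3, 4])
```

If some x values may be missing from y, clamp pos to len(y) - 1 and check y_sorted[pos] == x before trusting idx.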

The nonzero() call in the broadcasting solution above gives me a warning in PyTorch 1.6.0:

UserWarning: This overload of nonzero is deprecated:
        nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
        nonzero(Tensor input, *, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)

Is there any better approach without loops?

To get rid of the warning, just pass the as_tuple argument as suggested:

print((a.unsqueeze(1) == b).nonzero(as_tuple=False))
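For reference, a small sketch of what the two as_tuple modes return, using the same a and a fixed b (the permutation from the first output) instead of randperm:

```python
import torch

a = torch.arange(10)
b = torch.tensor([5, 2, 6, 3, 4])
match = a.unsqueeze(1) == b

# as_tuple=False: a single (N, 2) tensor of [row, col] pairs, as in the original output
pairs = match.nonzero(as_tuple=False)

# as_tuple=True: a tuple with one 1-D index tensor per dimension
rows, cols = match.nonzero(as_tuple=True)

print(pairs.shape)  # torch.Size([5, 2])
print(rows)         # tensor([2, 3, 4, 5, 6])
print(cols)         # tensor([1, 3, 4, 0, 2])
```

As far as I know, torch.where(match) with a single argument returns the same tuple as as_tuple=True.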