For every element in one tensor, find the index in another tensor

For every element in one tensor, I want to find the index in another tensor. Currently, I am using the following method:

import torch

x = torch.LongTensor([1,2,3,4])
y = torch.LongTensor([2,3,4,1])

# y is the query
def retrieve_indices(x, y):
    return torch.nonzero(y[:, None] == x[None, :])[:, 1]

print(retrieve_indices(x,y))  # tensor([1, 2, 3, 0])

For example, the first element of the result, 1, indicates that y[0] is found at index 1 in x.

This method is based on a NumPy implementation mentioned on Stack Overflow. However, it is memory intensive: the comparison y[:, None] == x[None, :] materializes a boolean matrix with len(y) * len(x) entries.
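To make the memory cost concrete, here is a small sketch that inspects the intermediate mask built by the broadcast comparison. The names mask and num_entries are only illustrative and are not part of the original function.

```python
import torch

x = torch.LongTensor([1, 2, 3, 4])
y = torch.LongTensor([2, 3, 4, 1])

# Broadcasting compares every element of y with every element of x.
mask = y[:, None] == x[None, :]
print(mask.shape)  # torch.Size([4, 4])

# One row per query and one column per candidate,
# so the number of entries grows as len(y) * len(x).
num_entries = mask.numel()

# Same result as retrieve_indices(x, y): the column index of each True entry.
indices = torch.nonzero(mask)[:, 1]
```

With a few thousand elements on each side the mask already holds millions of entries, which is where the memory goes.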

Another answer in the same Stack Overflow post has a better solution for NumPy:

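import numpy as np

x = np.array([1,2,3,4])
y = np.array([2,3,4,1])

index = np.argsort(x)                        # permutation that sorts x
sorted_x = x[index]                          # x in ascending order
sorted_index = np.searchsorted(sorted_x, y)  # position of each y in sorted_x
yindex = np.take(index, sorted_index)        # map back to positions in x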

However, I could not find a PyTorch equivalent of NumPy's searchsorted. Someone has implemented searchsorted for PyTorch, but my program is required to rely on official PyTorch and NumPy only.
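For later readers: more recent PyTorch releases (1.6 and later, as far as I know) do ship torch.searchsorted, so the NumPy recipe above can be ported directly. A minimal sketch, assuming such a version is installed and that every element of y occurs in x; the function name retrieve_indices_sorted is made up for this example.

```python
import torch

def retrieve_indices_sorted(x, y):
    # Sort x once, remembering where each sorted value came from.
    sorted_x, index = torch.sort(x)
    # Binary-search each query value in the sorted copy of x.
    sorted_index = torch.searchsorted(sorted_x, y)
    # Map positions in the sorted copy back to positions in x.
    return index[sorted_index]

x = torch.LongTensor([1, 5, 3, 4])
y = torch.LongTensor([5, 3, 4, 1])
result = retrieve_indices_sorted(x, y)
print(result)  # tensor([1, 2, 3, 0])
```

This avoids the len(y) * len(x) mask: the sort and the binary searches only need memory proportional to the input sizes.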

Is there a better way to achieve this task, or is the retrieve_indices method the best solution for PyTorch so far? (Note that retrieve_indices only handles a single pair of 1-D tensors.) Let's assume every element in y is in x. In addition, x and y have the same shape: (batch_size, number of elements).


@SimonW Thanks. Unfortunately, it does not work if the inputs are:

x = torch.LongTensor([1,5,3,4])
y = torch.LongTensor([5,3,4,1])

The correct output is tensor([1, 2, 3, 0]), but your method outputs tensor([3, 2, 0, 1]).

Oh sorry, I misread the requirement. It should be x.argsort()[y.argsort().argsort()]
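A quick check of that expression on the failing inputs, plus one possible batched variant for the (batch_size, number of elements) case from the question. This is only a sketch: it assumes each row of y is a permutation of the matching row of x with no repeated values, and both function names are hypothetical.

```python
import torch

def retrieve_indices_argsort(x, y):
    # Rank of each element of y among y's values (0 = smallest).
    ranks = y.argsort().argsort()
    # x.argsort()[k] is the position in x of the k-th smallest value.
    return x.argsort()[ranks]

def retrieve_indices_argsort_batched(x, y):
    # Same idea applied row by row for inputs of shape (batch_size, n).
    ranks = y.argsort(dim=1).argsort(dim=1)
    return x.argsort(dim=1).gather(1, ranks)

x = torch.LongTensor([1, 5, 3, 4])
y = torch.LongTensor([5, 3, 4, 1])
single = retrieve_indices_argsort(x, y)
print(single)  # tensor([1, 2, 3, 0])

xb = torch.stack([x, torch.LongTensor([7, 2, 9, 4])])
yb = torch.stack([y, torch.LongTensor([9, 4, 7, 2])])
batched = retrieve_indices_argsort_batched(xb, yb)
print(batched)  # tensor([[1, 2, 3, 0], [2, 3, 0, 1]])
```

It works because the k-th smallest value of y equals the k-th smallest value of x under the permutation assumption, so ranks in y can be used to index into the sorted order of x.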