For every element of one tensor, I want to find its index in another tensor. Currently, I am using the following method:
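```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 3, 4, 1])  # y is the query

def retrieve_indices(x, y):
    # mask[i, j] is True where y[i] == x[j]; the column of each hit is the index in x
    return torch.nonzero(y[:, None] == x[None, :])[:, 1]

print(retrieve_indices(x, y))  # tensor([1, 2, 3, 0])
```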
For example, the first element of the result, 1, means that y[0] is found at index 1 in x.
This method is based on a NumPy implementation from a Stack Overflow answer. However, it is memory intensive: the comparison y[:, None] == x[None, :] builds a len(y) × len(x) boolean matrix.
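To make the memory cost concrete, here is a small sketch that inspects the intermediate mask. The helper `comparison_mask_bytes` is a hypothetical name used only for illustration; it relies on `torch.bool` using one byte per element.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 3, 4, 1])

# mask[i, j] is True exactly when y[i] == x[j]
mask = y[:, None] == x[None, :]
print(mask.shape)  # torch.Size([4, 4])

def comparison_mask_bytes(n: int) -> int:
    """Hypothetical helper: bytes used by the (n, n) boolean mask for n elements."""
    # element_size() reports bytes per element; torch.bool stores 1 byte each
    return n * n * torch.empty(0, dtype=torch.bool).element_size()

print(comparison_mask_bytes(10_000))  # 100000000, i.e. about 100 MB
```

With a (batch_size, n) input, the batched comparison grows to batch_size × n × n entries, so the quadratic term dominates quickly.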
Another answer in the same Stack Overflow post gives a better solution for NumPy:
```python
import numpy as np

x = np.array([1, 2, 3, 4])
y = np.array([2, 3, 4, 1])
index = np.argsort(x)
sorted_x = x[index]
sorted_index = np.searchsorted(sorted_x, y)
yindex = np.take(index, sorted_index)
print(yindex)  # [1 2 3 0]
```
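The argsort step matters when x is not already sorted. The sketch below reruns the same NumPy recipe on a deliberately shuffled x (an input chosen for illustration, not from the original post) and annotates each intermediate array:

```python
import numpy as np

# x is deliberately unsorted to show why argsort is needed
x = np.array([4, 1, 3, 2])
y = np.array([2, 3, 4, 1])

index = np.argsort(x)                        # [1 3 2 0]: positions of x in ascending order
sorted_x = x[index]                          # [1 2 3 4]
sorted_index = np.searchsorted(sorted_x, y)  # [1 2 3 0]: position of each y in sorted_x
yindex = np.take(index, sorted_index)        # [3 2 0 1]: mapped back to positions in x
print(yindex)  # [3 2 0 1]

# every query is recovered: x[yindex[i]] == y[i]
assert (x[yindex] == y).all()
```

This costs O(n log n) time and O(n) extra memory, instead of the O(n²) mask.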
But I could not find a PyTorch equivalent of the NumPy method searchsorted. Someone has implemented searchsorted for PyTorch, but my program must rely on official PyTorch and NumPy only.
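Assuming PyTorch 1.6 or later, which includes `torch.searchsorted` in the official package, the NumPy recipe can be ported almost line for line. A minimal sketch; `retrieve_indices_sorted` is a hypothetical name, and it assumes every element of y occurs in x:

```python
import torch

def retrieve_indices_sorted(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Index of each y[i] in 1-D x, assuming every element of y occurs in x."""
    # torch.sort returns both the sorted values and the argsort permutation
    sorted_x, index = torch.sort(x)
    # position of each query in sorted_x (requires PyTorch >= 1.6)
    sorted_index = torch.searchsorted(sorted_x, y)
    # equivalent of np.take(index, sorted_index)
    return index[sorted_index]

x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([2, 3, 4, 1])
print(retrieve_indices_sorted(x, y))  # tensor([1, 2, 3, 0])
```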
Is there a better way to achieve this? Or is retrieve_indices the best PyTorch solution so far? (Note that retrieve_indices only compares two 1-D tensors, not a batch.) Let's assume every element of y is in x. In addition, x and y have the same shape: (batch_size, number of elements).
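For the batched (batch_size, n) case, two sketches follow; both function names are hypothetical. The first assumes PyTorch 1.6 or later and uses `torch.searchsorted`, which accepts an N-D sorted sequence whose leading dimensions match those of the queries. The second is a swapped-in technique that avoids searchsorted entirely by inverting sort permutations. It additionally assumes each row of x holds distinct values, so that, given the stated assumptions, each row of y is a rearrangement of the same row of x.

```python
import torch

def retrieve_indices_batched(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """x, y: (batch_size, n). Returns idx with x[b, idx[b, i]] == y[b, i]."""
    sorted_x, index = torch.sort(x, dim=-1)
    # row-wise search of each query in its own sorted row (PyTorch >= 1.6)
    sorted_index = torch.searchsorted(sorted_x, y)
    # batched equivalent of index[sorted_index]
    return torch.gather(index, -1, sorted_index)

def retrieve_indices_permutation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Assumes each row of y is a permutation of the distinct values in that row of x."""
    # x_order[b, k]: position in x of the k-th smallest value of row b
    _, x_order = torch.sort(x, dim=-1)
    # y_order[b, k]: position in y of the same k-th smallest value
    _, y_order = torch.sort(y, dim=-1)
    # so y[b, y_order[b, k]] is found at x[b, x_order[b, k]]
    result = torch.empty_like(x_order)
    result.scatter_(-1, y_order, x_order)
    return result

x = torch.tensor([[1, 2, 3, 4], [4, 1, 3, 2]])
y = torch.tensor([[2, 3, 4, 1], [2, 3, 4, 1]])
print(retrieve_indices_batched(x, y).tolist())      # [[1, 2, 3, 0], [3, 2, 0, 1]]
print(retrieve_indices_permutation(x, y).tolist())  # [[1, 2, 3, 0], [3, 2, 0, 1]]
```

Both use O(batch_size × n) extra memory rather than the O(batch_size × n²) boolean mask. The permutation variant gives wrong answers if a row of x contains duplicates, so the searchsorted version is the safer default when it is available.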