Greetings,
I have a problem where, for every element of a tensor with arbitrary dimensions, I want to find the index of the nearest element in a flat (1-D) tensor.
For example:
G = tensor([1, 4, 9])
X = tensor([[2, 3, 7, 9], [8, 6, 7, 5]])
For every element in X, I want to find the index of the nearest element in G (minimum absolute distance), as follows:
result = nearest_idx(X, G)
print(result) # should return: [[0, 1, 2, 2], [2, 1, 2, 1]]
I am able to solve this problem using list comprehensions, but I want to tensorize this operation (i.e. express it purely in terms of PyTorch tensor operations) for better performance.
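For reference, here is a minimal sketch of both versions, assuming X and G are numeric tensors and that ties go to the lower index in G. `nearest_idx_loop` is a hypothetical name for the list-comprehension baseline. The tensorized `nearest_idx` broadcasts X against G to build a distance tensor of shape `X.shape + (len(G),)` and then takes `argmin` over the last dimension, so it works for X of any shape:

```python
import torch

def nearest_idx_loop(X: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    # Hypothetical baseline: for each scalar in X, scan G and keep the closest index
    # (min() returns the first index on ties).
    flat = [min(range(len(G)), key=lambda j: abs(x - G[j].item()))
            for x in X.flatten().tolist()]
    return torch.tensor(flat, dtype=torch.long).reshape(X.shape)

def nearest_idx(X: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
    # Broadcast X (shape [...]) against G (shape [K]): distances have shape [..., K].
    dist = (X.unsqueeze(-1) - G).abs()
    # argmin over the last dim gives, per element of X, the index into G (same shape as X).
    return dist.argmin(dim=-1)

G = torch.tensor([1, 4, 9])
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])
print(nearest_idx(X, G).tolist())  # [[0, 1, 2, 2], [2, 1, 2, 1]]
```

Note that the broadcast materializes `X.numel() * len(G)` distances. If G is large and sorted, `torch.searchsorted` followed by a comparison of the two neighbouring candidates may be a lighter-weight alternative, though I have not benchmarked it.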
Any input is greatly appreciated!