Calculating index of nearest value to a nested tensor

Greetings,

I have a problem where, for every element of a nested tensor of arbitrary shape, I want to find the index of the nearest value in a flat (1-D) tensor.

For example:

G = tensor([1, 4, 9])
X = tensor([[2, 3, 7, 9], [8, 6, 7, 5]])

For every element in X, I want to find the index of the nearest element in G (minimum distance) as follows:

result = nearest_idx(X, G)
print(result) # should return: [[0, 1, 2, 2], [2, 1, 2, 1]] 

I am able to solve this problem using list comprehensions, but I want to be able to tensorize this operation (i.e. do it in terms of PyTorch operations for better performance).
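The list-comprehension code itself isn't shown in the post. A hypothetical reconstruction for a 2-D X might look like the following (the name nearest_idx_loop is illustrative only). It is handy as a reference to check a vectorized version against, and it also shows where the cost comes from: every element is compared with all len(G) candidates in plain Python.

```python
import torch


def nearest_idx_loop(X, G):
    """Pure-Python baseline for a 2-D X (hypothetical reconstruction)."""
    g = G.tolist()
    # For each x, pick the index j that minimises |x - g[j]|; min() keeps
    # the first such j on ties, which matches torch.argmin's behaviour.
    return [[min(range(len(g)), key=lambda j: abs(x - g[j])) for x in row]
            for row in X.tolist()]


G = torch.tensor([1, 4, 9])
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])
print(nearest_idx_loop(X, G))  # [[0, 1, 2, 2], [2, 1, 2, 1]]
```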

Any input is greatly appreciated!

What you can do is expand G to shape (n, m, p), where (n, m) is the shape of X ((2, 4) in your example) and p is the length of G (3 in your example), and give X a trailing dimension so that it has shape (n, m, 1).

Once this is done, taking the argmin of |X - G| along the last (p) dimension gives you the result.

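import torch

G = torch.tensor([1, 4, 9])  # (p,)
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])  # (n, m)

# Tile G to (n, m, p) and give X a trailing axis, (n, m, 1), so that every
# element of X is compared with every element of G.
# (In the example X.shape == (2, 4), so this is G.repeat(2, 4, 1).)
dist = (G.repeat(*X.shape, 1) - X.unsqueeze(-1)).abs()  # (n, m, p)

# Index of the smallest distance along the last (p) dimension.
result = dist.argmin(dim=-1)
print(result)
# tensor([[0, 1, 2, 2],
#         [2, 1, 2, 1]])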