# Calculating index of nearest value to a nested tensor

Greetings,

I have a problem where, for each element of a nested tensor of arbitrary dimensions, I want to find the index of the nearest value in a flat tensor.

For example:

```python
import torch

G = torch.tensor([1, 4, 9])
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])
```

For every element in `X`, I want to find the index of the nearest element in `G` (minimum distance), as follows:

```python
result = nearest_idx(X, G)
print(result)  # should return: [[0, 1, 2, 2], [2, 1, 2, 1]]
```

I am able to solve this problem using list comprehensions, but I want to be able to tensorize this operation (i.e. do it in terms of PyTorch operations for better performance).

Any input is greatly appreciated!

What you can do is expand `G` to shape `(n, m, p)`, where `(n, m)` is the shape of `X` (`(2, 4)` in your example) and `p` is the length of `G` (`3` in your example).

Once this is done, `argmin_along_the_p_dimension |X - G|` will give you the result.

```python
import torch

G = torch.tensor([1, 4, 9])  # shape (p,)
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])  # shape (n, m)

# Explicit-shape version, with n, m, p = 2, 4, 3:
# (G.repeat(n, m, 1) - X.unsqueeze(-1)).abs().argmin(dim=-1)

# Tile G to (n, m, p), view X as (n, m, 1), then take the argmin over p.
(G.repeat(*X.shape, 1) - X.unsqueeze(-1)).abs().argmin(dim=-1)

# tensor([[0, 1, 2, 2],
#         [2, 1, 2, 1]])
```
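As a follow-up sketch, the same idea can be wrapped in the `nearest_idx` function from the question. This version relies on broadcasting instead of `repeat`, so `G` is never tiled explicitly. It assumes `G` is a non-empty 1-D tensor and `X` has any shape; the function body is my own illustration, not something from the original post.

```python
import torch


def nearest_idx(X, G):
    """For each element of X, return the index of the closest value in G.

    Assumption: G is a non-empty 1-D tensor; X may have any shape.
    """
    # X.unsqueeze(-1) has shape (*X.shape, 1); subtracting G of shape (p,)
    # broadcasts to (*X.shape, p) without materialising a tiled copy of G.
    dist = (X.unsqueeze(-1) - G).abs()
    # The argmin over the last axis picks the nearest entry of G per element.
    return dist.argmin(dim=-1)


G = torch.tensor([1, 4, 9])
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])

result = nearest_idx(X, G)
print(result.tolist())  # [[0, 1, 2, 2], [2, 1, 2, 1]]
```

The intermediate distance tensor still has `X.numel() * p` elements, so memory grows with the length of `G`.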
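If `G` is large, the `(n, m, p)` distance tensor can get expensive. One possible alternative, which is a different technique from the one above, is a binary search with `torch.searchsorted`. This is only a sketch under the assumption that `G` is 1-D, non-empty, and sorted in ascending order (as it is in your example); `nearest_idx_sorted` is a hypothetical name.

```python
import torch


def nearest_idx_sorted(X, G):
    """Nearest-value indices via binary search.

    Assumption: G is a non-empty 1-D tensor sorted in ascending order.
    """
    # Index of the first entry of G that is >= x, clamped into range.
    hi = torch.searchsorted(G, X).clamp(max=G.numel() - 1)
    # The neighbour just below it (or the same index at the left edge).
    lo = (hi - 1).clamp(min=0)
    # Keep the upper neighbour only when it is strictly closer.
    pick_hi = (G[hi] - X).abs() < (X - G[lo]).abs()
    return torch.where(pick_hi, hi, lo)


G = torch.tensor([1, 4, 9])
X = torch.tensor([[2, 3, 7, 9], [8, 6, 7, 5]])

print(nearest_idx_sorted(X, G).tolist())  # [[0, 1, 2, 2], [2, 1, 2, 1]]
```

This only ever compares each element of `X` against two candidates, so the working memory stays proportional to the size of `X`.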