Nearest elements

Hi,

How can I find, for every value in one tensor, the nearest element in another tensor?
For example, if
a = [2,8,4,1] and b = [3,12,5,7],
then the result should be [3,7,5,3].

Thank you


To find, for each value in a, the nearest element of b in PyTorch, compute the absolute difference between every element of a and every element of b, then pick, for each element of a, the element of b with the smallest difference:

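```python
import torch

a = torch.tensor([2, 8, 4, 1])
b = torch.tensor([3, 12, 5, 7])

# Absolute differences between every element of a and every element of b.
# a.unsqueeze(1) has shape (4, 1), so broadcasting it against b (shape (4,))
# gives a (4, 4) matrix whose row i holds |a[i] - b|.
diff = torch.abs(a.unsqueeze(1) - b)

# For each row (each element of a), the index of the closest element of b
nearest_indices = torch.argmin(diff, dim=1)

# Get the nearest elements from b
nearest_elements = b[nearest_indices]
```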
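The same idea can be wrapped in a small helper that accepts tensors of any shape: flatten both, build the difference matrix, and reshape the result back to the shape of a. This is a sketch; `nearest_from` is an illustrative name, not a PyTorch function. Note that torch.argmin returns the index of the first minimum when there are ties, so the value 4 (equally far from 3 and 5) maps to 3 here, not to the 5 listed in the question.

```python
import torch

def nearest_from(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """For every element of `a`, return the element of `b` closest to it.

    `a` and `b` may have any shapes; the result has the shape of `a`.
    Builds an (a.numel(), b.numel()) difference matrix, so memory grows
    with the product of the two sizes.
    """
    a_col = a.reshape(-1, 1)                     # one row per element of a
    b_row = b.reshape(1, -1)                     # one column per element of b
    diff = torch.abs(a_col - b_row)              # shape (a.numel(), b.numel())
    nearest_indices = torch.argmin(diff, dim=1)  # first minimum wins on ties
    return b.reshape(-1)[nearest_indices].reshape(a.shape)

a = torch.tensor([2, 8, 4, 1])
b = torch.tensor([3, 12, 5, 7])
print(nearest_from(a, b))  # tensor([3, 7, 3, 3])

a2 = torch.tensor([[2, 8], [16, 25]])
b2 = torch.tensor([[23, 9], [1, 13]])
print(nearest_from(a2, b2))  # tensor([[ 1,  9], [13, 23]])
```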

Hi Smth,

Thank you for the response. The answer is very clear.

What if a and b are multidimensional (say, of shape [2, 2])? How should I go about it?

I have tried the following:

```python
import torch

a = torch.tensor([[2, 8], [16, 25]])
b = torch.tensor([[23, 9], [1, 13]])

# Calculate the absolute differences.
# b.reshape(-1, 1) is a column, so broadcasting it against a.flatten()
# gives a matrix whose row i holds |b[i] - a| over all elements of a
diff = torch.abs(b.reshape(-1, 1) - a.flatten())

# Find the indices of the nearest elements
nearest_indices = torch.argmin(diff, dim=1)

# Get the nearest elements from a (one for each element of b)
nearest_elements = a.flatten()[nearest_indices]

nearest_elements = nearest_elements.reshape(b.shape)
```

But this only works for small tensors. I want to apply it to big tensors, for which the memory requirement is too high.

Could you please suggest a feasible approach?

Thank you,
Prachi
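For large tensors, it is the (a.numel() × b.numel()) difference matrix that exhausts memory. Below is a minimal sketch of two alternatives, assuming both tensors share a real-valued dtype and that absolute difference is the distance. `nearest_sorted` sorts b once and uses torch.searchsorted to find, for each value of a, its two neighbours in the sorted b, so the extra memory is linear in the input sizes. `nearest_chunked` keeps the brute-force computation but processes a in slices of `chunk_size` elements, so peak memory is roughly chunk_size × b.numel(). Both function names and the default chunk_size are hypothetical, not part of PyTorch.

```python
import torch

def nearest_sorted(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """For every element of `a`, the closest element of `b` (ties -> smaller value)."""
    b_sorted, _ = torch.sort(b.reshape(-1))
    a_flat = a.reshape(-1).contiguous()
    # Index of the first element of b_sorted that is >= each value of a
    right = torch.searchsorted(b_sorted, a_flat)
    # Clamp so both neighbours are valid indices at the edges of b_sorted
    right = right.clamp(max=b_sorted.numel() - 1)
    left = (right - 1).clamp(min=0)
    left_vals = b_sorted[left]
    right_vals = b_sorted[right]
    # Keep whichever neighbour is closer; "<=" sends ties to the smaller value
    use_left = (a_flat - left_vals).abs() <= (right_vals - a_flat).abs()
    return torch.where(use_left, left_vals, right_vals).reshape(a.shape)

def nearest_chunked(a: torch.Tensor, b: torch.Tensor, chunk_size: int = 4096) -> torch.Tensor:
    """Brute-force nearest search over `chunk_size` elements of `a` at a time."""
    a_flat = a.reshape(-1)
    b_flat = b.reshape(-1)
    out = torch.empty_like(a_flat)
    for start in range(0, a_flat.numel(), chunk_size):
        chunk = a_flat[start:start + chunk_size]
        diff = torch.abs(chunk.unsqueeze(1) - b_flat)  # (len(chunk), b.numel())
        out[start:start + chunk_size] = b_flat[torch.argmin(diff, dim=1)]
    return out.reshape(a.shape)

a = torch.tensor([[2, 8], [16, 25]])
b = torch.tensor([[23, 9], [1, 13]])
print(nearest_sorted(a, b))   # tensor([[ 1,  9], [13, 23]])
print(nearest_chunked(a, b))  # tensor([[ 1,  9], [13, 23]])
```

On ties, `nearest_sorted` returns the smaller of the two candidates, while `nearest_chunked` (like the argmin version) returns whichever candidate appears first in b, so the two can disagree when a value lies exactly halfway between two elements of b.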