Hi,
How can I find, for every value in one tensor, the nearest element in another tensor?
For example,
a = [2,8,4,1] and b = [3,12,5,7]
then the result should be [3,7,5,3]
Thank you
To find the nearest elements for all values in a tensor a from another tensor b in PyTorch, you can calculate the absolute difference between each element in a and all elements in b, and then find, for each element in a, the element in b that has the minimum difference.
import torch
a = torch.tensor([2, 8, 4, 1])
b = torch.tensor([3, 12, 5, 7])
# Calculate the absolute differences.
# a.unsqueeze(1) has shape (4, 1), so it broadcasts against b (shape (4,))
# to give diff[i, j] = |a[i] - b[j]| with shape (4, 4)
diff = torch.abs(a.unsqueeze(1) - b)
# Find, for each a[i], the index in b of the nearest element
nearest_indices = torch.argmin(diff, dim=1)
# Get the nearest elements from b
nearest_elements = b[nearest_indices]
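Note that this snippet returns [3, 7, 3, 3] rather than the [3, 7, 5, 3] from the question: 4 is equally far from 3 and 5, and torch.argmin returns the index of the first minimal value, which points at the 3. If ties should go to the larger candidate (an assumption read off the expected output in the question, not something the question states), one simple sketch is to sort b in descending order first, so the first minimum argmin finds is always the larger value:

```python
import torch

a = torch.tensor([2, 8, 4, 1])
b = torch.tensor([3, 12, 5, 7])

# Sort candidates from largest to smallest, so that among equally close
# candidates the first one (the one argmin returns) is the larger value
b_desc = torch.sort(b, descending=True).values

# diff[i, j] = |a[i] - b_desc[j]|
diff = torch.abs(a.unsqueeze(1) - b_desc)
nearest_elements = b_desc[torch.argmin(diff, dim=1)]
print(nearest_elements.tolist())  # [3, 7, 5, 3]
```

The sort only reorders b, so the memory cost is the same as the original approach.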
Hi Smth,
Thank you for the response. The answer is very clear.
In the case where a and b are multidimensional (say, shape [2, 2]), how should I go about it?
I have tried it in the following way:
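import torch
a = torch.tensor([[2, 8], [16, 25]])
b = torch.tensor([[23, 9], [1, 13]])
# For each element of b, find the nearest element of a
diff = torch.abs(b.reshape(-1, 1) - a.flatten())  # shape (b.numel(), a.numel())
nearest_indices = torch.argmin(diff, dim=1)
nearest_elements = a.flatten()[nearest_indices]
# There is one result per element of b, so restore b's shape
nearest_elements = nearest_elements.reshape(b.shape)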
But this only works for small tensors. I want to apply it to large tensors, where the full difference matrix (one entry for every pair of elements) needs too much memory.
Could you please suggest a feasible approach?
Thank you,
Prachi
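Two lower-memory options, sketched under stated assumptions rather than benchmarked at scale. The first, nearest_chunked, keeps the same brute-force |b - a| comparison but processes the query tensor in chunks, so only a (chunk_size, a.numel()) slice of the difference matrix exists at a time. The second, nearest_sorted, swaps in a different technique: sort the pool once, use torch.searchsorted to find each query's insertion point, and compare only the two neighboring candidates, which needs memory roughly proportional to a.numel() + b.numel(). Both function names and the chunk_size default are hypothetical. nearest_sorted assumes both tensors share a dtype (searchsorted expects matching dtypes) and breaks ties toward the smaller value, which can differ from argmin's first-index rule; the distance to the chosen element is the same either way.

```python
import torch


def nearest_chunked(queries: torch.Tensor, pool: torch.Tensor, chunk_size: int = 4096) -> torch.Tensor:
    """Brute force, but only a (chunk_size, pool.numel()) block of differences at a time."""
    pool_flat = pool.flatten()
    q_flat = queries.flatten()
    out = torch.empty(q_flat.numel(), dtype=pool_flat.dtype, device=pool_flat.device)
    for start in range(0, q_flat.numel(), chunk_size):
        q = q_flat[start:start + chunk_size]
        diff = torch.abs(q.unsqueeze(1) - pool_flat)  # (len(q), pool.numel())
        out[start:start + chunk_size] = pool_flat[torch.argmin(diff, dim=1)]
    # One result per query element, so restore the query shape
    return out.reshape(queries.shape)


def nearest_sorted(queries: torch.Tensor, pool: torch.Tensor) -> torch.Tensor:
    """Sort the pool once, then compare each query with its two sorted neighbors."""
    pool_sorted = torch.sort(pool.flatten()).values
    q = queries.flatten().contiguous()
    # idx[i] is the first position with pool_sorted[idx[i]] >= q[i]
    idx = torch.searchsorted(pool_sorted, q)
    # Neighbors just below and at/above the insertion point, clamped to valid indices
    below = pool_sorted[(idx - 1).clamp(min=0)]
    above = pool_sorted[idx.clamp(max=pool_sorted.numel() - 1)]
    # Take the closer neighbor; on a tie keep the smaller value (`below`)
    take_above = torch.abs(above - q) < torch.abs(q - below)
    return torch.where(take_above, above, below).reshape(queries.shape)


a = torch.tensor([[2, 8], [16, 25]])
b = torch.tensor([[23, 9], [1, 13]])

# For every element of b, the nearest element of a
print(nearest_chunked(b, a, chunk_size=2).tolist())  # [[25, 8], [2, 16]]
print(nearest_sorted(b, a).tolist())  # [[25, 8], [2, 16]]
```

The chunked version trades speed for a memory bound you control through chunk_size; the sorted version avoids the pairwise matrix entirely and is usually the better fit when both tensors are large.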