Ethan_P
(Ethan)
Hey, I’m kinda a newbie…
How can I find the tensor in a list that is closest to another tensor?
Example:
Say I have:
[tensor([1,1,1]),tensor([5,5,5]),tensor([10,10,10])]
and:
tensor([2,2,2])
I would want to get:
tensor([1,1,1])
How can I achieve this? Thank you!
ptrblck
You could calculate the distance between the tensors using torch.norm and use argmin to get the index corresponding to the smallest distance:
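import torch

# Stack the list into one (N, 3) tensor; .float() is needed because norm requires a floating-point dtype
a = torch.stack([torch.tensor([1,1,1]), torch.tensor([5,5,5]), torch.tensor([10,10,10])]).float()
b = torch.tensor([2,2,2]).float()
# Euclidean distance from b to each row of a, then the index of the smallest distance
min_idx = torch.norm(a - b.unsqueeze(0), dim=1).argmin()
print(a[min_idx])
# > tensor([1., 1., 1.])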
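The same idea can be wrapped into small helpers. Below is a minimal sketch, assuming a reasonably recent PyTorch (1.10+) where torch.linalg.vector_norm is available. The helper names closest_tensor and closest_indices are hypothetical, chosen only for illustration. The first helper takes the Python list directly and returns the original (unconverted) tensor. The second uses torch.cdist to handle many query tensors at once by computing all pairwise distances in one call.

```python
import torch

def closest_tensor(candidates, query):
    """Hypothetical helper: return (index, tensor) of the candidate nearest to `query`.

    `candidates` is a list of same-shaped 1-D tensors. Integer inputs are cast
    to float because norm computations need a floating-point dtype.
    """
    stacked = torch.stack(candidates).float()              # shape: (N, D)
    q = query.float()                                      # shape: (D,)
    # Broadcasting subtracts q from every row; vector_norm reduces over D (L2 by default)
    dists = torch.linalg.vector_norm(stacked - q, dim=1)   # shape: (N,)
    idx = int(dists.argmin())
    return idx, candidates[idx]                            # original tensor, original dtype

def closest_indices(candidates, queries):
    """Hypothetical helper: nearest-candidate index for each query tensor."""
    stacked = torch.stack(candidates).float()              # shape: (N, D)
    q = torch.stack(queries).float()                       # shape: (M, D)
    dists = torch.cdist(q, stacked)                        # shape: (M, N), pairwise L2 distances
    return dists.argmin(dim=1)                             # shape: (M,)

candidates = [torch.tensor([1, 1, 1]), torch.tensor([5, 5, 5]), torch.tensor([10, 10, 10])]

idx, nearest = closest_tensor(candidates, torch.tensor([2, 2, 2]))
print(idx, nearest)  # 0 tensor([1, 1, 1])

print(closest_indices(candidates, [torch.tensor([2, 2, 2]), torch.tensor([9, 9, 9])]))
# tensor([0, 2])
```

Returning candidates[idx] instead of the float row keeps the original integer dtype. The cdist variant is mainly useful when you have many queries, since it avoids a Python loop over them.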
Ethan_P
(Ethan)
Thank you! I compared this with my old solution, and this is like 10X faster!