Calculate closest tensor in list to another tensor

Hey I’m kinda a newbie…
How can I find the tensor in a list that is closest to another tensor?

Example:

Say I have:

[tensor([1,1,1]),tensor([5,5,5]),tensor([10,10,10])]

and:

tensor([2,2,2])

I would want:

tensor([1,1,1])

How can I achieve this? Thank you

You could stack the list into a single tensor, calculate the Euclidean distance from each row to the target tensor using torch.norm, and use argmin to get the index corresponding to the smallest distance:

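import torch

# Stack the candidates into one (3, 3) float tensor
a = torch.stack([torch.tensor([1,1,1]), torch.tensor([5,5,5]), torch.tensor([10,10,10])]).float()
b = torch.tensor([2,2,2]).float()
# Distance from b to each row of a, then the index of the smallest one
min_idx = torch.norm(a - b.unsqueeze(0), dim=1).argmin()
print(a[min_idx])
> tensor([1., 1., 1.])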

Thank you! I compared this with my old solution, and this is like 10X faster!
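
If you ever need to do this for several query tensors at once, the same idea can probably be batched with torch.cdist, which computes all pairwise distances in one call. This is only a minimal sketch: the names candidates and queries and the extra query values are made up for illustration and are not from the question above.

```python
import torch

# Candidates from the question, stacked into one (3, 3) float tensor.
candidates = torch.tensor([[1, 1, 1], [5, 5, 5], [10, 10, 10]], dtype=torch.float32)

# Hypothetical batch of queries, one per row (illustrative values only).
queries = torch.tensor([[2, 2, 2], [7, 7, 7], [9, 9, 9]], dtype=torch.float32)

# dists[i, j] is the Euclidean distance between queries[i] and candidates[j].
dists = torch.cdist(queries, candidates)

# For each query, the index of its closest candidate.
nearest_idx = dists.argmin(dim=1)

# Gather the closest candidate for each query.
nearest = candidates[nearest_idx]

print(nearest_idx)
print(nearest)
```

For a single query the torch.norm version above is simpler. The batched version mainly helps when you would otherwise loop over many queries in Python.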