Among a set of reference vectors, how do I return the one closest to a given vector?

Assume that there is a two-dimensional reference array reference (one vector per row) and a given vector x. I would like to return the vector in reference that is closest to x, such that the operation is differentiable.

The solution I currently have, which is not differentiable, is like this:

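import torch

def closest(reference, x):
    # Euclidean distance from x to every row of reference
    # (something like nn.PairwiseDistance could be used instead)
    distances = torch.sqrt(torch.sum((reference - x) ** 2, dim=1))
    # torch.min returns (values, indices) only when dim is given
    _, min_index = torch.min(distances, dim=0)
    return reference[min_index]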

This solution is probably not differentiable because it relies on argmin (the index returned by torch.min). Is there a differentiable way of finding the closest vector?

Nearest-neighbor selection is not differentiable.
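To make that concrete, here is a small check of where gradients actually go with a hard argmin. The tensors are made-up toy values, not anything from the question. As far as I can tell, the selected row of reference still receives a gradient, but the index itself carries none, so nothing flows back to x.

```python
import torch

# Toy data (hypothetical values, for illustration only)
reference = torch.tensor([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]], requires_grad=True)
x = torch.tensor([0.9, 1.2], requires_grad=True)

# Euclidean distance from x to each row of reference
distances = torch.sqrt(torch.sum((reference - x) ** 2, dim=1))
# argmin returns an integer index, which has no gradient
min_index = torch.argmin(distances)
closest = reference[min_index]

closest.sum().backward()

print(reference.grad)  # nonzero only in the selected row
print(x.grad)          # None: x has no differentiable path to the output
```

So the lookup is differentiable with respect to the chosen reference vector, but the choice of which vector to pick is not.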

If I understand this NIPS 2017 paper (Multiscale Quantization for Fast Similarity Search) correctly, it has some sort of nearest neighbor search (Eq. 2) and claims to train the model using SGD.
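One common workaround, which is not necessarily what that paper does, is to relax the hard selection into a softmin-weighted average of the reference vectors. The sketch below assumes this relaxation is acceptable for your use case. The function name soft_nearest and the temperature parameter are my own, and the toy tensors are made up.

```python
import torch

def soft_nearest(reference: torch.Tensor, x: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Softmin-weighted average of the rows of reference (a relaxation, not an exact lookup)."""
    # Squared Euclidean distance from x to each reference row, shape (N,)
    sq_dist = torch.sum((reference - x) ** 2, dim=1)
    # Smaller distance gives a larger weight; weights are positive and sum to 1
    weights = torch.softmax(-sq_dist / temperature, dim=0)
    # Convex combination of the reference rows, shape (D,)
    return weights @ reference

# Toy data (hypothetical values, for illustration only)
reference = torch.tensor([[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]], requires_grad=True)
x = torch.tensor([0.9, 1.2], requires_grad=True)

out = soft_nearest(reference, x)
out.sum().backward()

print(out)     # close to, but not exactly, the nearest row [1, 1]
print(x.grad)  # now populated, since the weights depend smoothly on x
```

As the temperature approaches zero the output approaches the true nearest vector, but the gradients with respect to x shrink toward zero, so the temperature trades off exactness against a useful training signal.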