Assume that there is a reference two-dimensional array `ref` and a given vector `x`. I would like to return the vector from `ref` that is closest to `x`, such that the operation is differentiable.

The solution I currently have, which is not differentiable, is like this:

```
import torch
def closest(ref, x):
    # Euclidean distance from x to each row of ref (nn.PairwiseDistance would also work)
    distances = torch.sqrt(torch.sum((ref - x) ** 2, dim=1))
    _, min_index = torch.min(distances, dim=0)  # dim is needed to get (value, index)
    return ref[min_index]
```

This solution is probably not differentiable because it relies on an `argmin` operation, which returns a discrete index. Is there a differentiable way of finding the closest vector?
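To make the question concrete, one relaxation that might fit is a "soft" argmin: weight every row of `ref` by a softmax over the negative distances and return the weighted average. This is only a sketch under assumptions, not a known-good answer. The function name `soft_closest` and the `temperature` parameter are made up for illustration, and the result is a blend of rows that only approaches the true nearest row as `temperature` shrinks.

```python
import torch

def soft_closest(ref, x, temperature=0.1):
    # Hypothetical helper: soft relaxation of nearest-row lookup.
    # Squared Euclidean distance from x to each row of ref, shape (n,)
    distances = torch.sum((ref - x) ** 2, dim=1)
    # Softmax over negative distances: the closest row gets the largest weight
    weights = torch.softmax(-distances / temperature, dim=0)
    # Weighted average of the rows, shape (d,)
    return weights @ ref

ref = torch.tensor([[0.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
x = torch.tensor([0.9, 1.2], requires_grad=True)

out = soft_closest(ref, x)
out.sum().backward()  # gradients reach x through the softmax weights
```

With a small `temperature` the output is close to the hard nearest row, but the gradients with respect to `x` can become very small; a larger `temperature` gives smoother gradients at the cost of a blurrier result.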