Assume that there is a reference two-dimensional tensor ref (one candidate vector per row) and a given vector x. I would like to return the row of ref that is closest to x, in a way that is differentiable.
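Written as equations (the row notation $r_1, \dots, r_n$ for the rows of ref is introduced here only for illustration), the lookup is the following. Because $i^\star$ is an integer, it stays constant while x moves inside a region where the same row is nearest, so the derivative of the output with respect to x is zero wherever it exists:

$$
i^\star = \arg\min_{i \in \{1, \dots, n\}} \lVert r_i - x \rVert_2,
\qquad
\mathrm{out}(x) = r_{i^\star},
\qquad
\frac{\partial\, \mathrm{out}}{\partial x} = 0 \ \text{almost everywhere.}
$$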
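The solution I currently have, which is not differentiable, looks like this:

    import torch

    def closest_vector(ref, x):
        # Euclidean distance from x to each row of ref
        # (something like nn.PairwiseDistance could also compute these)
        distances = torch.sqrt(torch.sum((ref - x) ** 2, dim=1))
        # torch.min needs dim=0 to return both the minimum value and its index
        _, min_index = torch.min(distances, dim=0)
        return ref[min_index]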
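A quick way to see the problem, sketched with hypothetical toy data and the question's code wrapped in a function (the name closest_vector is an assumption): even though x requires grad, the returned row does not, because the only path from x to the output goes through an integer index.

```python
import torch

def closest_vector(ref, x):
    # The question's hard nearest-row lookup
    distances = torch.sqrt(torch.sum((ref - x) ** 2, dim=1))
    _, min_index = torch.min(distances, dim=0)  # min_index is an integer tensor
    return ref[min_index]

# Hypothetical toy data: three 2-D candidate vectors, one per row
ref = torch.tensor([[0.0, 0.0], [1.0, 1.0], [3.0, 0.0]])
x = torch.tensor([0.9, 1.2], requires_grad=True)

out = closest_vector(ref, x)
print(out)                # tensor([1., 1.])
print(out.requires_grad)  # False: out is not connected to x in the autograd graph
```

Calling out.sum().backward() here raises a RuntimeError. If ref itself required grad, the selected row of ref would receive a gradient, but x still would not.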
I believe this is not differentiable with respect to x because it relies on argmin (the index returned by torch.min), which is a discrete choice. Is there a differentiable way of finding the closest vector?
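One common workaround, swapped in here as a general technique rather than anything from the question, is to relax argmin into a softmin: weight each row by softmax(-d_i / temperature), where d_i is the squared distance from x to row i, and return the weighted average of the rows. This is differentiable with respect to both x and ref, but it returns a blend of rows unless the temperature is small, and a very small temperature makes the gradients vanish. The straight_through_closest variant returns the exact nearest row in the forward pass and uses the soft version's gradient in the backward pass (a straight-through estimator, which is a biased heuristic). The function names, toy data, and temperature values below are illustrative assumptions. Squared distances are used because sqrt has an infinite gradient at zero distance.

```python
import torch

def squared_distances(ref, x):
    # Squared Euclidean distance from x (shape (d,)) to each row of ref (shape (n, d))
    return ((ref - x) ** 2).sum(dim=1)

def soft_closest(ref, x, temperature=0.1):
    # Softmin weights: closer rows get exponentially larger weight
    weights = torch.softmax(-squared_distances(ref, x) / temperature, dim=0)
    # Convex combination of the rows; approaches the nearest row as temperature -> 0
    return weights @ ref

def straight_through_closest(ref, x, temperature=0.1):
    # Forward value: the exact nearest row (hard argmin), with no gradient of its own
    hard = ref[torch.argmin(squared_distances(ref, x))].detach()
    # (soft - soft.detach()) is exactly zero in value but carries the soft gradient
    soft = soft_closest(ref, x, temperature)
    return hard + (soft - soft.detach())

# Hypothetical toy data: three 2-D candidate vectors, one per row
ref = torch.tensor([[0.0, 0.0], [1.0, 1.0], [3.0, 0.0]])
x = torch.tensor([0.9, 1.2], requires_grad=True)

out = straight_through_closest(ref, x, temperature=1.0)
out.sum().backward()
print(out.detach())  # tensor([1., 1.]): the exact nearest row
print(x.grad)        # finite, nonzero gradient taken from the soft relaxation
```

The temperature is the main design knob: larger values give smoother, more informative gradients but a blurrier soft output, while smaller values make soft_closest nearly identical to the hard lookup at the cost of near-zero gradients.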