I have a loss function that requires me to compute a batched pairwise distance. In particular, I used the implementation from the Batched Pairwise Distance post:
```python
import numpy as np
import torch

def batch_pairwise_squared_distances(x, y):
    '''
    Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
    Input:  x is a bxNxd matrix
            y is an optional bxMxd matrix
    Output: dist is a bxNxM matrix where dist[b,i,j] is the squared norm
            between x[b,i,:] and y[b,j,:],
            i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
    '''
    x_norm = (x ** 2).sum(2).view(x.shape[0], x.shape[1], 1)
    y_t = y.permute(0, 2, 1).contiguous()
    y_norm = (y ** 2).sum(2).view(y.shape[0], 1, y.shape[1])
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)
    dist[dist != dist] = 0  # replace NaN values with 0
    return torch.clamp(dist, 0.0, np.inf)
```
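For reference, `torch.cdist` computes the same quantity directly, so it can serve as a numerical cross-check on the expansion `||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y` (a sketch; squaring the `p=2` result is my assumption of the equivalent call, not something from the original thread):

```python
import torch

def batch_pairwise_squared_distances_ref(x, y):
    # torch.cdist returns bxNxM Euclidean distances; squaring gives squared norms.
    return torch.cdist(x, y, p=2) ** 2

b, N, M, d = 2, 4, 3, 5
x = torch.randn(b, N, d)
y = torch.randn(b, M, d)
ref = batch_pairwise_squared_distances_ref(x, y)

# Spot-check one entry against the definition dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
manual = ((x[0, 1] - y[0, 2]) ** 2).sum()
print(torch.allclose(ref[0, 1, 2], manual, atol=1e-4))
```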
When I try to train my model with this, the weights become NaN after a few iterations. However, when I remove the `torch.bmm(x, y_t)` term, the model is able to train. Does anyone know what in `torch.bmm()` can cause this issue to occur?
At first I thought maybe I had to normalize my inputs `x, y`, but that did not make a difference. I also tried using a much lower learning rate, but that did not help either.
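For anyone debugging a similar failure, PyTorch's anomaly detection can pinpoint the operation whose gradient first produces NaN (a generic debugging sketch, not specific to this loss; the `sqrt` example below is an illustration of the mechanism):

```python
import torch

# In anomaly mode, backward() raises at the first op whose gradient contains
# NaN, with a traceback pointing to the forward op that created it.
torch.autograd.set_detect_anomaly(True)

x = torch.tensor([-1.0], requires_grad=True)
y = torch.sqrt(x)  # forward already yields NaN (sqrt of a negative number)
try:
    y.backward()
    caught = False
except RuntimeError:
    caught = True  # anomaly mode flags the NaN gradient here
print(caught)
```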