torch.bmm in a batched pairwise distance function causes NaN during training

I have a loss function that requires computing batched pairwise distances. In particular, I used the implementation from Batched Pairwise Distance:

```python
import numpy as np
import torch

def batch_pairwise_squared_distances(x, y):
    '''
    Modified from
    Input: x is a bxNxd matrix, y is a bxMxd matrix
    Output: dist is a bxNxM matrix where dist[b,i,j] is the squared distance between x[b,i,:] and y[b,j,:]
    i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
    '''
    x_norm = (x**2).sum(2).view(x.shape[0], x.shape[1], 1)  # (b, N, 1): ||x[b,i,:]||^2
    y_t = y.permute(0, 2, 1).contiguous()                   # (b, d, M)
    y_norm = (y**2).sum(2).view(y.shape[0], 1, y.shape[1])  # (b, 1, M): ||y[b,j,:]||^2
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)        # ||x||^2 + ||y||^2 - 2 x.y
    dist[dist != dist] = 0  # replace NaN values with 0
    return torch.clamp(dist, 0.0, np.inf)  # clip small negatives caused by rounding
```
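The function relies on the expansion ||x-y||^2 = ||x||^2 + ||y||^2 - 2 x.y, with torch.bmm supplying the cross term. As a hedged sketch (both function names are hypothetical), the snippet below checks that expansion against a direct broadcasting computation that forms every difference explicitly. The direct form is a sum of squares, so it cannot go negative; the expanded form subtracts terms on the order of 1e5 for coordinates in [0, 255], so float32 rounding can leave small negative values, which is what the final clamp removes.

```python
import torch

def pairwise_sq_dist_expanded(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same identity as the function above: ||a||^2 + ||b||^2 - 2 a.b, with torch.bmm.
    x_norm = (x ** 2).sum(-1, keepdim=True)                  # (B, N, 1)
    y_norm = (y ** 2).sum(-1, keepdim=True).transpose(1, 2)  # (B, 1, M)
    return x_norm + y_norm - 2.0 * torch.bmm(x, y.transpose(1, 2))

def pairwise_sq_dist_direct(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical alternative: build every difference by broadcasting.
    # (B, N, 1, d) - (B, 1, M, d) -> (B, N, M, d); needs O(B*N*M*d) memory.
    diff = x.unsqueeze(2) - y.unsqueeze(1)
    return (diff ** 2).sum(-1)  # a sum of squares, so never negative

torch.manual_seed(0)
x = torch.rand(4, 360, 2) * 255  # random data with the question's shapes and range
y = torch.rand(4, 360, 2) * 255
a = pairwise_sq_dist_expanded(x, y)
b = pairwise_sq_dist_direct(x, y)
print("max abs difference:", (a - b).abs().max().item())
```

For d = 2 the direct form costs little extra memory, but it grows with d, which is one reason the bmm expansion is the usual choice.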

When I train my model with this loss, the weights become NaN after a few iterations. However, when I remove the torch.bmm(x, y_t) term, the model trains normally. Does anyone know what in torch.bmm() could cause this?

At first I thought I might need to normalize the inputs x and y, but that did not make a difference. A much lower learning rate did not help either.
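One detail of the function is worth checking in isolation: `dist[dist != dist] = 0` only cleans the forward value. In the backward pass, torch.bmm(x, y_t) sends a gradient of x^T @ grad to y_t, so a NaN anywhere in x still reaches y.grad (NaN * 0 is NaN) even though the affected distances were zeroed. The minimal CPU sketch below uses a hypothetical helper name and toy values; it illustrates the mechanism and is not a diagnosis of the model in this question.

```python
import torch

def masked_expanded_sq_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the question's function: expand, mask NaNs, clamp.
    x_norm = (x ** 2).sum(2, keepdim=True)   # (B, N, 1)
    y_norm = (y ** 2).sum(2).unsqueeze(1)    # (B, 1, M)
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.permute(0, 2, 1))
    dist[dist != dist] = 0                   # hides NaN in the forward value only
    return torch.clamp(dist, min=0.0)

# One NaN in x (which does not need gradients) and a finite y that does.
x = torch.tensor([[[float("nan"), 1.0], [2.0, 3.0]]])
y = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]], requires_grad=True)

dist = masked_expanded_sq_dist(x, y)
dist.sum().backward()
print(torch.isnan(dist).any().item())    # False: the forward output is clean
print(torch.isnan(y.grad).any().item())  # True: bmm's backward multiplies the NaN by a zero gradient
```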

Could you check the shape, min and max values of x and y_t before the NaN is created?
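A minimal sketch of how such a check could be automated, assuming a standard PyTorch training loop (the helper names are hypothetical): `assert_finite` raises as soon as a tensor holds NaN or inf, and `torch.autograd.detect_anomaly()` makes backward raise a RuntimeError naming the first backward function that returned NaN. Anomaly mode slows execution, so it is meant for debugging only.

```python
import torch

def assert_finite(name: str, t: torch.Tensor) -> None:
    """Hypothetical helper: fail fast if `t` contains any NaN or +/-inf."""
    finite = torch.isfinite(t)
    if not finite.all():
        n_bad = int((~finite).sum())
        raise RuntimeError(f"{name} has {n_bad} non-finite value(s), shape {tuple(t.shape)}")

def backward_with_anomaly_check(loss: torch.Tensor) -> None:
    """Hypothetical helper: run backward under anomaly detection, which raises a
    RuntimeError naming the backward op whose output first contains NaN."""
    with torch.autograd.detect_anomaly():
        loss.backward()

if __name__ == "__main__":
    # Toy trigger: sqrt's backward divides the upstream gradient (0 here, because
    # of the * 0) by 2 * sqrt(0) = 0, giving 0 / 0 = NaN.
    t = torch.zeros(3, requires_grad=True)
    loss = (torch.sqrt(t) * 0).sum()
    try:
        backward_with_anomaly_check(loss)
    except RuntimeError as err:
        print("anomaly detected:", err)
```

In the training loop, one would call assert_finite on x, y_t, and the loss each iteration and route loss.backward() through the anomaly-checked helper, so the first non-finite value is reported where it appears rather than iterations later in the weights.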

Printing print(x.shape, x.min(), x.max(), y.shape, y_t.min(), y_t.max()) gives:

torch.Size([32, 360, 2]) tensor(0., device='cuda:0') tensor(255., device='cuda:0') torch.Size([32, 360, 2]) tensor(-46.9311, device='cuda:0', grad_fn=<MinBackward1>) tensor(371.8865, device='cuda:0', grad_fn=<MaxBackward1>)
torch.Size([32, 360, 2]) tensor(0., device='cuda:0') tensor(255., device='cuda:0') torch.Size([32, 360, 2]) tensor(-101.9687, device='cuda:0', grad_fn=<MinBackward1>) tensor(349.1602, device='cuda:0', grad_fn=<MaxBackward1>)
torch.Size([11, 360, 2]) tensor(0., device='cuda:0') tensor(255., device='cuda:0') torch.Size([11, 360, 2]) tensor(nan, device='cuda:0', grad_fn=<MinBackward1>) tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)

I am running this loss function on top of an object detection model, so once the values become NaN, the region proposal stage no longer proposes 32 boxes. That is why the shape in the last line is [11, 360, 2].

I also computed the distances with torch.cdist, calling it on each sample in a for loop over the batch of 32. This gives the same values as batch_pairwise_squared_distances, but it does not make the weights NaN. This suggests that something about torch.bmm is making training unstable.
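The for loop is not required: torch.cdist accepts batched inputs of shape (B, N, d) and (B, M, d) and returns (B, N, M) Euclidean distances, which can be squared to compare with batch_pairwise_squared_distances. The sketch below uses random stand-in data, not the question's tensors. Per the PyTorch documentation, cdist's default compute_mode ('use_mm_for_euclid_dist_if_necessary') itself switches to a matrix-multiplication formulation when either point set has more than 25 points, as with 360 here; compute_mode='donot_use_mm_for_euclid_dist' forces the explicit-difference path.

```python
import torch

torch.manual_seed(0)
B, N, M, d = 32, 360, 360, 2
x = torch.randn(B, N, d)  # random stand-ins for the question's x and y
y = torch.randn(B, M, d)

# Per-sample loop, as described above: cdist on (N, d) and (M, d) slices.
looped = torch.stack([torch.cdist(x[i], y[i]) for i in range(B)])  # (B, N, M)

# The same computation on the whole batch in one call.
batched = torch.cdist(x, y)  # (B, N, M) Euclidean distances

# Squared distances via the bmm expansion used in batch_pairwise_squared_distances.
expanded = (
    (x ** 2).sum(-1, keepdim=True)
    + (y ** 2).sum(-1).unsqueeze(1)
    - 2.0 * torch.bmm(x, y.transpose(1, 2))
)

# Force the explicit-difference path instead of the matmul shortcut.
no_mm = torch.cdist(x, y, compute_mode="donot_use_mm_for_euclid_dist")

print("max |cdist^2 - expansion|:", (batched ** 2 - expanded).abs().max().item())
```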