# Torch.bmm in batched pairwise distance function causing NaN when training

I have a loss function that requires me to compute a batched pairwise distance. In particular, I used the implementation from the Batched Pairwise Distance thread:

```python
def batch_pairwise_squared_distances(x, y):
    '''
    Modified from https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
    Input: x is a bxNxd matrix, y is an optional bxMxd matrix
    Output: dist is a bxNxM matrix where dist[b,i,j] is the squared norm
            between x[b,i,:] and y[b,j,:], i.e. dist[b,i,j] = ||x[b,i,:]-y[b,j,:]||^2
    '''
    x_norm = (x ** 2).sum(2).view(x.shape[0], x.shape[1], 1)
    y_t = y.permute(0, 2, 1).contiguous()
    y_norm = (y ** 2).sum(2).view(y.shape[0], 1, y.shape[1])
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)
    dist[dist != dist] = 0  # replace NaN values with 0
    return dist
```
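For context, a likely pitfall with this expansion (not something stated in the thread, just a common issue with the `x_norm + y_norm - 2*bmm` formulation): floating-point cancellation can make some entries slightly negative, and anything downstream that takes a square root of those entries produces NaN, which then poisons the backward pass; masking NaNs in-place after the fact does not undo the NaN gradients. A minimal sketch of a clamp-based variant (the `_safe` name is my own):

```python
import torch

def batch_pairwise_squared_distances_safe(x, y):
    # x: (b, N, d), y: (b, M, d) -> dist: (b, N, M)
    x_norm = (x ** 2).sum(2, keepdim=True)   # (b, N, 1)
    y_norm = (y ** 2).sum(2).unsqueeze(1)    # (b, 1, M)
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y.transpose(1, 2))
    # Clamp instead of masking: cancellation can produce small negative
    # values, and sqrt of a negative value yields NaN in forward/backward.
    return dist.clamp(min=0.0)
```

Clamping keeps the forward values non-negative before any square root is applied, whereas `dist[dist != dist] = 0` only patches entries that are already NaN.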

When I try to train my model with this, the weights become NaN after a few iterations. However, when I remove `torch.bmm(x, y_t)`, the model is able to train. Does anyone know what about `torch.bmm()` could cause this issue?

At first I thought I might have to normalize my inputs `x, y`, but that did not make a difference. I also tried a much lower learning rate, but that did not make a difference either.

Could you check the shape, min and max values of `x` and `y_t` before the Nan is created?

Printing out `print(x.shape, x.min(), x.max(), y.shape, y_t.min(), y_t.max())` looks like this:

```
torch.Size([32, 360, 2]) tensor(0., device='cuda:0') tensor(255., device='cuda:0') torch.Size([32, 360, 2]) tensor(-46.9311, device='cuda:0', grad_fn=<MinBackward1>) tensor(371.8865, device='cuda:0', grad_fn=<MaxBackward1>)
torch.Size([32, 360, 2]) tensor(0., device='cuda:0') tensor(255., device='cuda:0') torch.Size([32, 360, 2]) tensor(-101.9687, device='cuda:0', grad_fn=<MinBackward1>) tensor(349.1602, device='cuda:0', grad_fn=<MaxBackward1>)
torch.Size([11, 360, 2]) tensor(0., device='cuda:0') tensor(255., device='cuda:0') torch.Size([11, 360, 2]) tensor(nan, device='cuda:0', grad_fn=<MinBackward1>) tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)
```

I am running this loss function on top of an object detection model, so once the values become NaN, the region proposal network no longer proposes 32 boxes; that is why the shape at the end is `[11, 360, 2]`.

Furthermore, I ran `torch.cdist`, which computes the pairwise distance for a single input, using a for loop over the batch of 32. This gives the same values as the `batch_pairwise_squared_distances` function, yet it does not cause the weights to become NaN. This tells me there is something about `torch.bmm` that is making the model unstable during training.
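As a side note on the loop-based comparison: `torch.cdist` also accepts batched `(b, N, d)` inputs directly, so the per-sample loop may be unnecessary. A small sketch checking that the batched call matches the loop (squaring `cdist`'s output to get squared distances, since `cdist` returns the distance itself):

```python
import torch

x = torch.randn(4, 360, 2)
y = torch.randn(4, 360, 2)

# torch.cdist handles a leading batch dimension directly;
# square its output to obtain squared Euclidean distances
d_batched = torch.cdist(x, y) ** 2

# the per-sample loop described above, for comparison
d_loop = torch.stack([torch.cdist(x[b], y[b]) ** 2 for b in range(x.shape[0])])
```

Note that `cdist` has a `compute_mode` argument controlling whether it uses matrix multiplication internally; the default may fall back to a direct (slower but more numerically stable) computation in some cases, which could explain a behavioral difference from the `torch.bmm`-based expansion.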