Backprop produces NaNs on a pointcloud deformation loss

I have a neural network that I am using to deform one point set onto another. I have two loss functions. One is based on the probability of generating one point set from the other; that one works fine. The other is meant to penalize points that start close together from being moved far apart, while allowing points that start far apart to move more freely relative to each other. The functions that go into the losses are below.

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io
import numpy as np
V = torch.autograd.Variable
FT = torch.FloatTensor

def displacement(a, b):
    #Pairwise displacements between every point in a and every point in b
    #(outer-product-style broadcasting)
    return b[None, :, :] - a[:, None, :]

def distance(a, b):
    #Pairwise Euclidean distance matrix
    disp = displacement(a, b)
    return ((disp**2).sum(dim=-1)).sqrt()

def gaussian(d, bw): return torch.exp(-0.5*(d/bw)**2) / (bw * np.sqrt(2*np.pi))

def loss_deform(inpt, out):
    #Doesn't work for backprop
    p_cross = distance(inpt, inpt)   #pairwise distances within the input cloud
    q_cross = distance(out, out)     #pairwise distances within the deformed cloud
    diff = torch.sigmoid((p_cross - q_cross).abs()) * p_cross
    return -diff.sum()

def loss_prob(out, next_frame):
    #Works just fine
    dist = distance(out, next_frame)
    probs = gaussian(dist, bw=2)
    sum_probs = probs.sum()
    return -sum_probs

When I try backpropagating through loss_deform, all of the gradients are NaN after the first step. The problem appears to come from p_cross - q_cross, or simply from backpropping through q_cross: it still breaks when I backprop through just the sum of q_cross. I think I'm missing something important about what I can safely backprop through, but I'm not sure what.
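
Here is about the smallest standalone thing I can write that shows the same behaviour, with no network involved, just random points (the 5x2 shape is arbitrary):

x = V(FT(np.random.randn(5, 2)), requires_grad=True)
d = distance(x, x)   #pairwise distances of a point set with itself
d.sum().backward()
print(x.grad)        #comes back as all NaN
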
A full sample run would be something like this:

from torch.optim import SGD

net = PointTransform()
optimizer = SGD(net.parameters(), lr=0.001)
#fish_x and fish_y are just (N, 2) point sets to test the deformation
#I want to deform fish_x to be fish_y
fish_x_T = V(FT(fish_x), requires_grad=False)
fish_y_T = V(FT(fish_y), requires_grad=False)
test = net(fish_x_T)

optimizer.zero_grad()
l = loss_deform(fish_x_T, test)
print(l)
l.backward()
optimizer.step()
optimizer.zero_grad()
test = net(fish_x_T)
l = loss_deform(fish_x_T, test)
print(l)
l.backward()
optimizer.step()
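
After the first l.backward(), the parameter gradients already come back as NaN across the board. This is the quick check I have been using (just inspection, nothing that changes the model; the self-inequality comparison is the usual NaN test):

#Run immediately after the first l.backward()
for name, p in net.named_parameters():
    g = p.grad.data
    print(name, (g != g).any())   #NaN is the only value not equal to itself
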

The network is very simple, just:

class PointTransform(nn.Module):
    def __init__(self):
        super(PointTransform, self).__init__()
        #What do I want here
        self.fc1 = nn.Linear(2,10)
        self.bn1 = nn.BatchNorm1d(10)
        self.fc2 = nn.Linear(10,10)
        self.bn2 = nn.BatchNorm1d(10)
        self.fc3 = nn.Linear(10,2)
        
    def forward(self, inpt):
        #couple feed forward layers?
        x = F.relu(self.fc1(inpt))
        x = self.bn1(x)
        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = self.fc3(x)
        return x

This should be just a simple gradient descent problem. My main question is why that loss_deform loss isn't working under backprop. And how would you suggest implementing an outer-product-style point deformation loss that penalizes close points for moving far apart while allowing distant points freedom of motion?
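
To be concrete about the behaviour I'm after, here is a rough sketch of the intent (loss_deform_intended is just an illustrative name and the bandwidth of 2 is a guess; it presumably trips over the same backprop issue as my version above):

def loss_deform_intended(inpt, out, bw=2.0):
    #Weight each pair of points by how close they are in the input cloud,
    #then penalize any change in that pair's distance in the deformed cloud.
    p_cross = distance(inpt, inpt)   #pairwise distances before deformation
    q_cross = distance(out, out)     #pairwise distances after deformation
    weight = gaussian(p_cross, bw)   #close pairs get a large weight
    return (weight * (p_cross - q_cross).abs()).sum()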