Cannot backward propagate after first batch

I am wondering whether the grid_sample function in torch.nn.functional is differentiable. I built a simple network to learn a vector field, and then use this vector field to interpolate a given image. The loss is computed between the interpolated image and another given image. However, after feeding in the first batch, the gradients of the network become NaN. I do not have any NaN values in my input and I am not sure what causes this.
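
For reference, an isolated check with toy tensors (not my actual model) does seem to show that grid_sample returns gradients for both of its inputs, so I am not sure what goes wrong in my setup:

import torch
import torch.nn.functional as nnf

# toy tensors only: does grid_sample backpropagate at all?
src = torch.rand(1, 1, 8, 8, requires_grad=True)
# sampling locations normalized to [-1, 1]; leaf tensor so .grad is populated
grid = (torch.rand(1, 8, 8, 2) * 2 - 1).requires_grad_()

out = nnf.grid_sample(src, grid, mode='bilinear', align_corners=True)
out.sum().backward()
print(torch.isnan(src.grad).any(), torch.isnan(grid.grad).any())  # both False here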

The code is attached below; SpatialTransformer is the interpolation function I use to deform the input image:

import torch
import torch.nn as nn
import torch.nn.functional as nnf
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        #Encoding
        self.conv1 = nn.Conv2d(2, 4, 3, stride=1, padding=1,bias = True)
        self.conv1.weight.data.fill_(0.1)
        self.bn = nn.BatchNorm2d(4)
        self.act = nn.LeakyReLU()

        self.conv2 = nn.Conv2d(4, 16, 3, stride=1, padding=1,bias = True)
        nn.init.xavier_uniform_(self.conv2.weight)
        self.bn2 = nn.BatchNorm2d(16)
        
        
        self.conv3 = nn.Conv2d(16, 3, 3, stride=1, padding=1,bias = True)
        nn.init.xavier_uniform_(self.conv3.weight)
        self.bn3 = nn.BatchNorm2d(3)
        
        

    def forward(self, x):
        
        x = self.conv1(x)
        x = self.act(x)
        
        x = self.conv2(x)
        x = self.act(x)
        
        x = self.conv3(x)
        x = self.act(x)
        return x


class SpatialTransformer(nn.Module):
    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode
        # create sampling grid
        vectors = [torch.arange(start = 0, end = s, dtype=torch.float) for s in size]
        #vectors = [torch.arange(0,s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.cat((grids[2], grids[1], grids[0]), dim=0)
        grid = torch.unsqueeze(grid, 0)
        grid = torch.unsqueeze(grid, 2)
        # 4D/5D identity transformation grid
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)

    def interpolation(self, src, flow):
        new_locs = self.grid + flow
        shape = flow.shape[2:]
        
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[len(shape)-i-1] - 1) - 0.5)
        new_locs[:, 2,:,:,:] = 0

        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)

        return nnf.grid_sample(src, new_locs,  mode=self.mode)


#################Network optimization########################
net = Net()
trainloader = torch.utils.data.DataLoader(train, batch_size=para.solver.batch_size, shuffle=True, num_workers=1)
if(para.model.loss == 'L2'):
    criterion = nn.MSELoss()
elif (para.model.loss == 'L1'):
    criterion = nn.L1Loss()
if(para.model.optimizer == 'Adam'):
    optimizer = optim.Adam(net.parameters(), lr= para.solver.lr)
elif (para.model.optimizer == 'SGD'):
    optimizer = optim.SGD(net.parameters(), lr= para.solver.lr, momentum=0.9)
running_loss = 0 
printfreq = 1
sigma = 0.03

transformer = SpatialTransformer([1,100,100])
# ##################Training###################################
for epoch in range(para.solver.epochs):
    total= 0; ave = 0
    for i, data in enumerate(trainloader):
        inputs = data
        outputs = net(inputs) 
        
        b, c, w, h = outputs.shape
        outputs = outputs.permute(0, 3, 1, 2).reshape(b,3,1,100,100)
        source = data[:,0,:,:].reshape(b,1,1,100,100)
        target = data[:,1,:,:].reshape(b,1,1,100,100)
        deformed = transformer.interpolation(target,outputs)
        loss = criterion(deformed, target)
        print('deformed:',torch.max(deformed))
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        total+=running_loss

You are zeroing the gradients right before calling backward(), hence they will all be zeros.
Put the line optimizer.zero_grad() above the call to net(inputs) so autograd will keep track of your gradients.

Gradients are accumulated when you call loss.backward(). Thus, calling optimizer.zero_grad() before it is okay, and it is suggested to do it in order to clear gradients from the previous batch.
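
A tiny standalone sketch of that ordering (toy linear model, not the code above): calling zero_grad() first and backward() afterwards still leaves non-zero gradients.

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(2, 4)
loss = model(x).sum()

optimizer.zero_grad()        # clears leftovers from the previous iteration
loss.backward()              # gradients are written/accumulated here
print(model.weight.grad)     # non-zero: zero_grad() before backward() is fine
optimizer.step()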

Thanks! I did put it before net(inputs). However, after the first batch is fed in, the weights still become NaN somehow.

I am wondering if there is any illegal operation in nnf.grid_sample. After replacing this function, the gradients became normal and the NaNs are gone.

@ptrblck Could you help me take a look at this? After I apply nnf.grid_sample to the network outputs, it seems the weights won't update somehow. I am not sure where it went wrong. Thanks!

Since you are dealing with invalid gradients (NaNs), I would check the code for operations which might create Infs or NaNs.
E.g. could you check if you might be dividing by zero or by a very small number in:

new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[len(shape)-i-1] - 1) - 0.5)

which could blow up the activations and might cause problems?
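
For example, with the size = [1, 100, 100] used above, the depth dimension would make that denominator zero. A rough standalone sketch (random stand-in values, not your tensors):

import torch

shape = (1, 100, 100)                 # matches SpatialTransformer([1, 100, 100])
new_locs = torch.rand(1, 3, *shape)   # stand-in for grid + flow

for i in range(len(shape)):
    denom = shape[len(shape) - i - 1] - 1
    print(i, 'denominator:', denom)   # prints 99, 99, then 0
    new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / denom - 0.5)

print(torch.isfinite(new_locs).all())  # tensor(False): the last channel is inf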

Thanks, Piotr! That helped. That's exactly where the problem is; I got inf values.