How to debug the origin of NaNs in the gradient of a custom module

Hi everyone,

I’m trying to optimize a network with a custom module and architecture; the code is attached below. The nodes have a specific behaviour in which positive and negative weights have different effects. I modelled this by using two separate weight tensors and masking the gradient of each.

When I run the code, the gradients of many parameters contain nans. Note that when I print a full gradient matrix, only some columns are nan rather than the whole matrix, which makes me think the nans propagate from a single initial error. Here are the gradient norms of the parameters:
[tensor(0.0267), tensor(0.0106), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(nan), tensor(5.1160e-06), tensor(nan), tensor(4.0687e-06), tensor(nan)]

I have tried using torch.autograd.set_detect_anomaly(True) to debug, and it raises an error pointing at AddBackward0 on the line sum_f_in = torch.matmul(x**2, self.weight_in**2)+self.bias_in**2. But after checking manually, I found no overflow or anything similar. Even when setting self.weight_in=0 and self.bias_in=0, it still raises the same error.

So I’m at a loss as to what else to try to figure out where the error comes from. Any help would be greatly appreciated, thanks!

Here is the code to reproduce the error:

import torch.nn as nn
import torch
torch.manual_seed(0)
torch.autograd.set_detect_anomaly(True)
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

class Node(nn.Module):
    #node of the network, takes a matrix with n features in and returns one with m features
    def __init__(self, weight, bias):
        super().__init__()

        self._c = .015436882530272481
        self._fth = 100.
        self._fmax = 4100
        self._in_a = .03051282
        self._in_b = 35.4594017
        self._out_a = 5.419425e-5

        #weight_ex keeps only the weights > 0, weight_in keeps the weights < 0 (both stored as absolute values)
        self.weight_ex = nn.Parameter(torch.where(weight>0, weight.abs(), 0))
        self._mask_ex_w = (weight>0).to(DEVICE)#mask for gradient
        self.weight_in = nn.Parameter(torch.where(weight<0, weight.abs(), 0))
        self._mask_in_w = (weight<0).to(DEVICE)
        self.bias_ex =  nn.Parameter(torch.where(bias>0, bias.abs(), 0))
        self._mask_ex_b = (bias>0).to(DEVICE)
        self.bias_in =  nn.Parameter(torch.where(bias<0, bias.abs(), 0))
        self._mask_in_b = (bias<0).to(DEVICE)

        self.set_hook()

    def forward(self, x):
        '''Input of shape (batch_dim, time_dim, n_features)'''

        #reducing inputs
        sum_f_ex = torch.matmul(x**2, self.weight_ex**2)+self.bias_ex**2
        sum_f_in = torch.matmul(x**2, self.weight_in**2)+self.bias_in**2
        f_ex = torch.sqrt(sum_f_ex)
        f_in = torch.sqrt(sum_f_in)

        diff = f_ex-self._fth
        xc = self._c*torch.sqrt(torch.where(diff>0, diff, 0))+f_in*self._out_a
        res = torch.where(f_ex>(f_in*self._in_a+self._in_b), xc, 0)
        return res

    def set_hook(self):
        #register hooks for gradient masking
        def _hook_ex_w(grad):
            return grad*self._mask_ex_w
        def _hook_in_w(grad):
            return grad*self._mask_in_w
        def _hook_ex_b(grad):
            return grad*self._mask_ex_b
        def _hook_in_b(grad):
            return  grad*self._mask_in_b

        self.weight_ex.register_hook(_hook_ex_w)
        self.weight_in.register_hook(_hook_in_w)
        self.bias_ex.register_hook(_hook_ex_b)
        self.bias_in.register_hook(_hook_in_b)

class Network(nn.Module):
    #create the network
    def __init__(self, shapes):
        super().__init__()
        self.all_nodes = nn.Sequential()

        for i in range(len(shapes)):
            #build nodes one after the other
            self.all_nodes.append(Node(
                torch.FloatTensor(*shapes[i]).uniform_(-1500, 1500),
                torch.FloatTensor(shapes[i][1]).uniform_(-500, 500)))

        #params for sigmoid
        self._decision_line = nn.Parameter(torch.tensor(.7).float())
        self._slope = nn.Parameter(torch.tensor(1.).float())

    def forward(self, x):
        _, time_dim, _ = x.shape
        res = self.all_nodes(x)#propagate through node
        out  =  torch.sum(res, dim=[1, 2])/time_dim#average over time
        return torch.sigmoid(self._slope*(out-self._decision_line))#activation function

if __name__=='__main__':
    #create model and data
    shapes = [[12, 7], [7, 2], [2, 1]]
    data = torch.FloatTensor(32, 210, 12).uniform_(0, 1).to(DEVICE)
    labels = (torch.rand(32)<.5).float().to(DEVICE)
    model = Network(shapes).to(DEVICE)
    loss_fn = nn.BCELoss()
    #run model and propagate loss
    preds = model(data).ravel()
    loss = loss_fn(preds, labels)
    loss.backward()
    print([torch.norm(p.grad.cpu()) for p in model.parameters()])

Hi! I think the cause of your issue is torch.sqrt. Its derivative is inf at +0.0, -inf at -0.0, and nan at negative values. So it’s likely that some tensor passed to torch.sqrt contains zeros, which produces inf values in the gradient. Those then turn into nan values during further backpropagation, for example when an inf gets multiplied by a zero.
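To see this in isolation, here is a minimal sketch (toy tensors, not your model) of the gradient of torch.sqrt at the problematic points, and of an inf turning into nan once the chain rule multiplies it by zero. The second part mimics a term like self.bias_ex**2 when the bias is exactly 0.

```python
import torch

# Gradient of sqrt at +0.0, -0.0, a negative value and a regular positive value.
x = torch.tensor([0.0, -0.0, -1.0, 4.0], requires_grad=True)
torch.sqrt(x).sum().backward()
sqrt_grad = x.grad
print(sqrt_grad)  # infinite at the two zeros, nan at -1.0, 0.25 at 4.0

# An infinite gradient becomes nan as soon as it is multiplied by zero.
# Here d(w**2)/dw = 2*w = 0 at w = 0, so the chain rule computes inf * 0.
w = torch.tensor(0.0, requires_grad=True)
torch.sqrt(w ** 2).backward()
print(w.grad)  # nan
```

This is why a parameter that is exactly zero (like the masked-out entries of your weights and biases) can end up with a nan gradient even though nothing overflows in the forward pass.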

Running your example with a small value (eps) added to each tensor passed to torch.sqrt seems to solve the issue:

    def forward(self, x):
        '''Input of shape (batch_dim, time_dim, n_features)'''
        eps = 1e-6  #small offset that keeps sqrt away from 0, where its gradient is infinite
        #reducing inputs
        sum_f_ex = torch.matmul(x**2, self.weight_ex**2)+self.bias_ex**2
        sum_f_in = torch.matmul(x**2, self.weight_in**2)+self.bias_in**2
        f_ex = torch.sqrt(sum_f_ex + eps)
        f_in = torch.sqrt(sum_f_in + eps)

        diff = f_ex-self._fth
        xc = self._c*torch.sqrt(torch.where(diff>0, diff, 0) + eps)+f_in*self._out_a
        res = torch.where(f_ex>(f_in*self._in_a+self._in_b), xc, 0)
        return res
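If you want to double check which parameters are affected before and after the change, something like the following could help. It is only a sketch: nonfinite_grads is a hypothetical helper and SqrtToy is a toy stand-in for your Node, not your actual module.

```python
import torch
import torch.nn as nn

def nonfinite_grads(module):
    """Hypothetical helper: names of parameters whose gradient contains nan or inf."""
    return [name for name, p in module.named_parameters()
            if p.grad is not None and not torch.isfinite(p.grad).all()]

class SqrtToy(nn.Module):
    # Toy stand-in for Node: a sqrt applied to a sum of squares that can be exactly 0.
    def __init__(self, eps):
        super().__init__()
        self.eps = eps
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        return torch.sqrt(x ** 2 + self.bias ** 2 + self.eps)

x = torch.zeros(4, 3)  # all-zero input, like the zeros a previous Node can output

unsafe = SqrtToy(eps=0.0)
unsafe(x).sum().backward()
bad_unsafe = nonfinite_grads(unsafe)

safe = SqrtToy(eps=1e-6)
safe(x).sum().backward()
bad_safe = nonfinite_grads(safe)

print(bad_unsafe)  # ['bias']
print(bad_safe)    # []
```

Calling the helper on your model right after loss.backward() gives the parameter names directly, which is a bit easier to read than the list of norms.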

Unrelated, but I think the way you create your hooks could cause memory issues.

Your hooks hold a reference to self (the nn.Module), which holds a reference to its parameters, which in turn hold a reference to their hooks. So you have a reference cycle that can be hard for the garbage collector to detect, and that can lead to memory (in particular the model parameters and the masks) not being freed properly even when the module is no longer referenced.

I’m not 100% sure this causes issues in your case, but it’s better to be careful about reference cycles when creating hooks (I had a terrible memory leak because of that recently).
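One possible way to avoid the cycle, as a sketch: build each hook from a small factory function so that the closure captures only the mask tensor and never self. The names make_mask_hook and MaskedWeight are hypothetical, and the module is reduced to a single weight to keep the example short.

```python
import torch
import torch.nn as nn

def make_mask_hook(mask):
    # The closure references only `mask`, not the module that owns the parameter.
    def hook(grad):
        # Match the gradient's device and dtype in case the model was moved after creation.
        return grad * mask.to(device=grad.device, dtype=grad.dtype)
    return hook

class MaskedWeight(nn.Module):
    # Reduced stand-in for Node: keeps only the positive weights, as weight_ex does.
    def __init__(self, weight):
        super().__init__()
        positive = weight > 0
        self.weight_ex = nn.Parameter(
            torch.where(positive, weight.abs(), torch.zeros_like(weight)))
        self.weight_ex.register_hook(make_mask_hook(positive))

    def forward(self, x):
        return x @ self.weight_ex

weight = torch.tensor([[1.0, -2.0], [-3.0, 4.0]])
layer = MaskedWeight(weight)
layer(torch.ones(1, 2)).sum().backward()
masked_grad = layer.weight_ex.grad
print(masked_grad)  # gradient is zero wherever the original weight was not positive
```

With this layout the parameter still references its hook, but the hook no longer points back to the module, so there is no cycle through self.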

This works, thanks a lot!! I’ll keep that in mind for the hooks and check whether they cause any memory leak.
