Gradients not flowing

Hi all, I’m trying to do learning on graphs, and the weights in my model are not updating. I have simplified the model and the weights still don’t seem to update. I’m wondering two things: (1) am I doing any operations that break the computational graph? Most of the tensors seem to have gradients enabled (requires_grad=True). (2) Is there a good way to check how gradients are flowing? Right now I’m just printing the parameters and watching whether they get updated.

I even tried my model with a very large learning rate (1e-1) to see if the weights get updated, but no luck.
I removed quite a bit of the graph-learning code to simplify the model and isolate where the breakage is, so please pardon the unused variables. Here’s a snippet of my code:

class NeuraLBP(torch.nn.Module):
    def __init__(self, d_in_channels, p_in_channels, out_channels, h_size, e_size, n_embed, **kwargs):
        super(NeuraLBP, self).__init__(**kwargs)
        
        # Feature sizes
#         self.in_channels = in_channels
        self.out_channels = out_channels
        self.h_size = h_size
        self.n_embed = n_embed
        
        # Layers
        # Embedding for dummy nodes
        self.embed = torch.nn.Embedding(n_embed, out_channels)
        
        # Translation of node features
        self.dax_lin = torch.nn.Linear(d_in_channels, h_size, bias=False)
        self.pax_lin = torch.nn.Linear(p_in_channels, h_size, bias=False)
        
        # Custom MPNN from another paper
        self.mpnn = MPNN(in_channels=h_size, out_channels=out_channels, edge_feature=e_size, hidden_size=h_size)
        self.mpnn2 = MPNN(in_channels=out_channels, edge_feature=e_size, out_channels=out_channels, hidden_size=h_size)
        self.mpnn3 = MPNN(in_channels=out_channels, edge_feature=e_size, out_channels=out_channels, hidden_size=h_size)        
        # Weight translation layer
        self.weight_transform = torch.nn.Sequential(torch.nn.Linear(out_channels * 2, h_size * 4),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 4, h_size * 4),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 4, h_size * 2),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 2, 1),
                                                    torch.nn.Tanh())

    
    def reset_parameters(self):        
        return
        
    def forward(self,
               pax_feat,
               dax_feat,
               pax_embed,
               assign_edge_index,
               assign_edge_attr,
               real_edges_original,
               real_edge_attr,
               sorted_edges,
               sorted_features,
               n_nodes):
        
        x = torch.cat((self.pax_lin(pax_feat), self.dax_lin(dax_feat)), dim=0)
        out = x
        all_feat = torch.cat((out[:pax_feat.size(0), :], self.embed(pax_embed), out[pax_feat.size(0):, :]), dim=0)
                
        wts = self.weight_transform(torch.cat((all_feat[sorted_edges[:, 0], :], all_feat[sorted_edges[:, 1], :]), dim=1))
        wts = wts.view(n_nodes, n_nodes)

        b_a, b_b = self.iter_simp_min_sum_batch_scatter(wts.unsqueeze(0), wts.t().unsqueeze(0), wts.unsqueeze(0).clone().detach(), 5)

        wts = wts + b_b.permute(0, 2, 1).squeeze(0)

        return wts

    def iter_simp_min_sum_batch_scatter(self, m_alpha_beta, m_beta_alpha, weights, n_iter):
        n = m_alpha_beta.size(1)
        for _ in range(n_iter):
            # Message passing
            beta_alpha_maxes, beta_alpha_indices = torch.topk(m_beta_alpha, 2, dim=1)
            m_alpha_beta_k = weights.permute(0, 2, 1) - beta_alpha_maxes[:, 0, :].unsqueeze(1)
            m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1),
                                                         src=beta_alpha_maxes[:, 0, :].unsqueeze(1))
            m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1),
                                                         src=beta_alpha_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2,
                                                                                                                  1)

            alpha_beta_maxes, alpha_beta_indices = torch.topk(m_alpha_beta, 2, dim=1)
            m_beta_alpha_k = weights - alpha_beta_maxes[:, 0, :].unsqueeze(1)
            m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1),
                                                         src=alpha_beta_maxes[:, 0, :].unsqueeze(1))
            m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1),
                                                         src=alpha_beta_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2,
                                                                                                                  1)

            m_alpha_beta = m_alpha_beta_k
            m_beta_alpha = m_beta_alpha_k

            
        return m_alpha_beta, m_beta_alpha

And here’s my training loop:

model = NeuraLBP(d_in_channels=14, p_in_channels=6, out_channels=512, h_size=512, e_size=1, n_embed=1000) 
model.reset_parameters()
model.train()
model = model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-1)
loss_fn = torch.nn.CrossEntropyLoss()

list_loss = []
list_acc = []
for i in range(5):
    for d in d_loader:
        optim.zero_grad()
        pax_feat = d.pax_feat.to(device)
        dax_feat = d.dax_feat.to(device)
        pax_embed = d.pax_embed.to(device)
        assign_edge_index = d.assign_edge_index.to(device)
        assign_edge_attr = d.assign_edge_feat[:, 0].view(-1, 1).to(device) * -1.0
        real_edges_original = d.real_edges_original.to(device)
        real_edge_attr = d.real_edge_attr[:, 0].view(-1, 1).to(device) * -1.0
        # Edge indices must be int64 to index all_feat; features are float32
        sorted_edges = d.sorted_edges.long().to(device)
        sorted_features = d.sorted_features.float().to(device) * -1.0
        n_nodes = d.dax_feat.size(0)
        out = model(pax_feat,
                   dax_feat,
                   pax_embed,
                   assign_edge_index,
                   assign_edge_attr,
                   real_edges_original,
                   real_edge_attr,
                   sorted_edges,
                   sorted_features,
                   n_nodes)
        out = torch.softmax(out, dim=1)

        # Organize the labels - CE labels
        row, col = torch.where(d.lbl_mat > 0)
        labels = torch.full((d.lbl_mat.size(0),), -100, dtype=torch.long)  # -100 is CrossEntropyLoss's default ignore_index
        labels[row] = col
        labels = labels.to(device)

        loss = loss_fn(out, labels)
        list_loss.append(float(loss))
        print(float(loss))
        with torch.no_grad():
            acc = acc_func(out, labels)
            list_acc.append(acc)

        loss.backward()
        optim.step()

Hi,

I did not read all the code, but you use .detach() in there, which explicitly prevents gradients from flowing. Could that be the reason?


When I pass the arguments into the function, I call self.iter_simp_min_sum_batch_scatter(wts.unsqueeze(0), wts.t().unsqueeze(0), wts.unsqueeze(0).clone().detach(), 5), so technically only one input is detached from the computational graph, right?

The parameters correspond to (m_alpha_beta, m_beta_alpha, weights, n_iter), so only weights should be non-learnable; m_alpha_beta and m_beta_alpha should still be part of the graph. Or do I have to explicitly copy wts into another variable and detach that copy to keep it from affecting the others?
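Whether m_alpha_beta and m_beta_alpha stay in the graph can also be checked in isolation. Inside iter_simp_min_sum_batch_scatter they only contribute through the values returned by torch.topk, which are differentiable, but the gradient reaches only the entries that were selected. A minimal check on a hypothetical 3x3 tensor (dim=0 here plays the role of dim=1 on the batched (1, n, n) tensors in the code above):

```python
import torch

# Hypothetical stand-in for one (n, n) slice of m_beta_alpha
m = torch.tensor([[0.9, 0.1, 0.5],
                  [0.2, 0.8, 0.3],
                  [0.4, 0.7, 0.6]], requires_grad=True)

# Largest two entries of each column, like torch.topk(m_beta_alpha, 2, dim=1) per batch item
vals, idx = torch.topk(m, 2, dim=0)
vals.sum().backward()

print(vals.requires_grad)  # True: the top-k values remain in the graph
print(idx.dtype)           # torch.int64: the indices are plain integers with no gradient
print(m.grad)              # 1 where an entry was selected, 0 for each column's smallest entry
```

This suggests the topk calls do not break the graph by themselves, but along the message-passing path only the selected entries receive gradient.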

No, .detach() is out-of-place (it returns a new tensor and leaves the original untouched), so it won’t change the others.
Is there any place where you get a tensor that does not require gradients and have to set requires_grad again? I assume not?
In that case, what gradients do you get? Maybe the gradients of your net are actually 0?
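The out-of-place behaviour of .detach() can be confirmed by reproducing the three arguments from the call in forward on a toy tensor w (a hypothetical 3x3 stand-in for wts) and checking which results still reach w:

```python
import torch

w = torch.randn(3, 3, requires_grad=True)

# The same three expressions as in the forward pass
m_alpha_beta = w.unsqueeze(0)              # view of w, still in the graph
m_beta_alpha = w.t().unsqueeze(0)          # view of w, still in the graph
weights = w.unsqueeze(0).clone().detach()  # new tensor cut from the graph

print(m_alpha_beta.requires_grad, m_beta_alpha.requires_grad, weights.requires_grad)
# True True False

# detach() did not touch w itself, so w still receives gradient through the other two
(m_alpha_beta.sum() + m_beta_alpha.sum() + weights.sum()).backward()
print(w.grad)  # every entry is 2: one contribution each from m_alpha_beta and m_beta_alpha
```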

No, there’s no place where I reset requires_grad.

How do I check which gradients I’m getting? I do get a loss value, but I’m not sure what the gradients are or how to check them.

After you zero out your gradients and call .backward() on your loss, each parameter in your model will have a .grad field that contains its gradient. This field is what the optimizer uses when you call opt.step().
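A small sketch of that last point, using a hypothetical toy nn.Linear and plain torch.optim.SGD (chosen because its update is simply p - lr * p.grad; Adam, as used in the training loop above, rescales the gradient with running moment estimates, but it still reads it from .grad):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)                      # hypothetical toy model
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # no momentum, no weight decay

x, y = torch.randn(8, 4), torch.randn(8, 1)
opt.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()                                    # fills p.grad for every parameter

# Snapshot each parameter and its .grad before the optimizer uses them
before = [p.detach().clone() for p in model.parameters()]
grads = [p.grad.detach().clone() for p in model.parameters()]
opt.step()

# Plain SGD applies p <- p - lr * p.grad, so the update is fully determined by .grad
checks = [torch.allclose(p.detach(), p0 - 0.1 * g)
          for p, p0, g in zip(model.parameters(), before, grads)]
print(checks)  # [True, True]
```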

If that’s the case, I should print p.grad for each parameter p after calling loss.backward() to see the gradients?

Yes, this will print all of them:

for p in model.parameters():
    print(p.grad)
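A slightly more informative variant, sketched here on a hypothetical small Sequential shaped like weight_transform, pairs each gradient with its parameter name and reduces it to a norm and a max absolute value, which is easier to scan than full tensors. The helper name grad_report is illustrative, not a PyTorch API.

```python
import torch

def grad_report(model):
    """Print name, grad norm and max |grad| per parameter; None means backward never reached it."""
    rows = []
    for name, p in model.named_parameters():
        if p.grad is None:
            rows.append((name, None, None))
        else:
            rows.append((name, p.grad.norm().item(), p.grad.abs().max().item()))
    for name, norm, gmax in rows:
        print(f"{name:12s} norm={norm} max={gmax}")
    return rows

# Hypothetical stand-in for weight_transform (much smaller sizes)
net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(),
                          torch.nn.Linear(16, 1), torch.nn.Tanh())
loss = net(torch.randn(32, 8)).pow(2).mean()
loss.backward()
rows = grad_report(net)
```

Call it right after loss.backward() and before optim.step() / the next optim.zero_grad().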

Wow, you are right, the gradients are actually 0. Is there a reason for this? Would the gradients be 0 if there’s a break in the computational graph?

They should be None before the first backward pass. If they become 0 after that first backward pass, that means gradients were propagated (a parameter cut off from the loss by a break in the graph would keep .grad as None).

You will need to check your function, but it is quite easy to write functions that have a gradient of 0, for example step functions.
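A toy illustration of both points (hypothetical parameters a, b, c): a parameter that never reaches the loss keeps .grad as None, while one that passes through a step function such as torch.round is connected but receives an all-zero gradient, because PyTorch defines the derivative of round as 0.

```python
import torch

a = torch.nn.Parameter(torch.tensor([0.3, 1.7]))  # used smoothly
b = torch.nn.Parameter(torch.tensor([0.3, 1.7]))  # used only through a step function
c = torch.nn.Parameter(torch.tensor([0.3, 1.7]))  # never used

loss = (a * 2.0).sum() + torch.round(b * 2.0).sum()
loss.backward()

print(a.grad)  # tensor([2., 2.]): gradient flows
print(b.grad)  # tensor([0., 0.]): connected, but the step function has zero derivative
print(c.grad)  # None: not part of the graph at all
```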
One tool is some_tensor.register_hook(), which lets you get the gradient for any Tensor in your model.
If you give it the print function, it will print that gradient. If you need something fancier, you can use:

def my_hook(grad):
    print("Gradient for some_tensor")
    print(grad)

I see, so I have to register a hook on the tensor first? That will print the gradient with respect to that tensor, right, not the parameters?

Also, I wonder if I could be facing zero gradients because of initializations in my weight matrix?

You have to register the hook before calling backward. And it will print the gradient of the tensor you called register_hook() on.

Also, I wonder if I could be facing zero gradients because of initializations in my weight matrix?

It is possible, if you’re in a local minimum where the gradient in every direction is 0.

I see, okay, so I should be doing something like:

output = model(inputs)

loss = loss_function(output, labels)
model.weight_transform.register_hook()
loss.backward()

print(model.weight_transform.grad)

Hmm okay, I’ll check whether a better initialization helps get gradients. How can I tell whether the gradients at the last stage of the output are likely to be 0? Is there any way to check this? One explanation could be that the input is 0, right?

No, this:

def my_hook(grad):
    print("In the hook")
    print(grad)

output = model(inputs)

loss = loss_function(output, labels)
# register_hook exists on Tensors, not on nn.Modules, so hook a tensor such as a layer's weight
model.weight_transform[0].weight.register_hook(my_hook)
loss.backward()  # The hook will be called in the middle of the backward pass
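To see where along the network the gradient shrinks, the same mechanism can be attached to intermediate activations inside forward. The sketch below uses a hypothetical small model shaped like weight_transform (Linear, ReLU, Linear, Tanh) and records the gradient norm reaching each activation; the Probe class, the grad_norms dict and the activation names are illustrative.

```python
import torch

class Probe(torch.nn.Module):
    """Hypothetical small model whose forward registers a hook on each intermediate tensor."""
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(8, 16)
        self.lin2 = torch.nn.Linear(16, 1)
        self.grad_norms = {}

    def _watch(self, name, t):
        def hook(grad):
            # grad is d(loss)/d(t); returning None leaves the gradient unchanged
            self.grad_norms[name] = grad.norm().item()
        t.register_hook(hook)
        return t

    def forward(self, x):
        h = self._watch("after_relu", torch.relu(self.lin1(x)))
        z = self._watch("before_tanh", self.lin2(h))
        return self._watch("output", torch.tanh(z))

model = Probe()
loss = model(torch.randn(32, 8)).pow(2).mean()
loss.backward()  # hooks fire during this call, from the output backwards
for name, norm in model.grad_norms.items():
    print(f"{name:12s} {norm:.3e}")
```

Comparing consecutive entries (for example before_tanh against output) shows which layer is squashing the signal.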

I see, OK. Do you have any debugging tips for why the gradients hit 0? At the start of training, with some initializations, I see some gradients, but they diminish pretty quickly. This seems to happen when my weight_transform layer ends in a non-linearity like Tanh() or Sigmoid().

It’s hard to say, but non-linearities can lead to gradients that go to 0. That happens, for example, if you give a ReLU layer inputs that are all negative.
Or if you give tanh values that are large in absolute value.
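Both effects are easy to see on a few scalar inputs (values chosen arbitrarily): the derivative of tanh is 1 - tanh(x)^2, which collapses toward 0 as |x| grows, and ReLU passes no gradient at all for negative inputs.

```python
import torch

x = torch.tensor([0.0, 2.0, 5.0, 10.0], requires_grad=True)
torch.tanh(x).sum().backward()
print(x.grad)                          # derivative of tanh at each input, shrinking fast
print(1 - torch.tanh(x.detach()) ** 2) # matches the analytic 1 - tanh(x)^2

r = torch.tensor([-3.0, -0.5, 0.5, 3.0], requires_grad=True)
torch.relu(r).sum().backward()
print(r.grad)  # tensor([0., 0., 1., 1.]): negative inputs get no gradient
```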

Note that if you initially have gradients and then they go to 0, that means your model is properly converging to a minimum of the loss function you provided.

Or if you give tanh values that are large in absolute value.

Hmm, this could explain it too.

Note that if you initially have gradients and then they go to 0, that means your model is properly converging to a minimum of the loss function you provided.

But sadly, my accuracy doesn’t improve, and the loss stays nonzero. So I think it’s likely some form of local minimum? I’m also not sure how to tell whether the network has sufficient representational power.

It’s hard to say indeed…
But it does sound like you’re getting stuck in a local minimum.
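One thing worth ruling out before settling on a local minimum: the training loop above applies torch.softmax to the model output and then passes the result to torch.nn.CrossEntropyLoss, which applies log_softmax internally and expects unnormalized scores. Feeding it probabilities confines every input to [0, 1], which flattens the predicted distribution, typically shrinks the gradients, and puts a floor under the loss. A toy comparison on hypothetical scores for 2 samples over 4 classes:

```python
import math
import torch

loss_fn = torch.nn.CrossEntropyLoss()
labels = torch.tensor([3, 0])
# Hypothetical raw scores; both samples currently favour the wrong class
raw = torch.tensor([[4.0, 1.0, 0.0, -1.0],
                    [0.5, 2.5, -0.5, 0.0]])

# As in the training loop: softmax first, then CrossEntropyLoss (log_softmax again)
logits = raw.clone().requires_grad_()
loss_double = loss_fn(torch.softmax(logits, dim=1), labels)
loss_double.backward()
g_double = logits.grad.norm().item()

# Raw scores passed straight to CrossEntropyLoss
logits = raw.clone().requires_grad_()
loss_fn(logits, labels).backward()
g_single = logits.grad.norm().item()

print(f"grad norm with extra softmax: {g_double:.4f}, without: {g_single:.4f}")

# Inputs in [0, 1] differ by at most 1, so the loss cannot drop below log(1 + (C - 1) / e)
n_classes = raw.size(1)
floor = math.log(1 + (n_classes - 1) / math.e)
print(f"loss with extra softmax: {loss_double.item():.3f} (cannot go below {floor:.3f})")
```

If this applies, passing the raw model output to the loss and keeping torch.softmax only for the accuracy computation would be the usual arrangement.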

Agreed, I’ll take a closer look. What does a normal, expected maximum gradient look like? I’ve been watching the maximum gradient propagated backwards to get a rough idea of whether any meaningful learning is happening.

The scale of the gradient depends on many things, in particular the values of the Tensor itself.
Also, the learning rate you use can completely change the scale of the gradients you should expect to see.

I’m afraid I don’t have an answer beyond that: whatever works for your model :)
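Because raw gradient magnitudes depend on the parameter scale and on the optimizer, one option (a common heuristic, not a PyTorch feature) is to snapshot the parameters around optim.step() and compare the size of the applied update with the size of each parameter. This works for Adam as well as SGD because it measures the change actually applied rather than .grad. The model, the data and the helper name update_ratios are hypothetical.

```python
import copy
import torch

def update_ratios(model, optim, loss_of):
    """Take one optimizer step and return ||applied update|| / ||param|| per parameter."""
    optim.zero_grad()
    loss_of(model).backward()
    before = {n: p.detach().clone() for n, p in model.named_parameters()}
    optim.step()
    return {n: ((p.detach() - before[n]).norm() / (before[n].norm() + 1e-12)).item()
            for n, p in model.named_parameters()}

torch.manual_seed(0)
base = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
x, y = torch.randn(64, 8), torch.randn(64, 1)

def mse(m):
    return torch.nn.functional.mse_loss(m(x), y)

# Same initial weights and data; only the Adam learning rate differs
results = {}
for lr in (1e-1, 1e-3):
    model = copy.deepcopy(base)
    results[lr] = update_ratios(model, torch.optim.Adam(model.parameters(), lr=lr), mse)
    print(lr, {n: f"{r:.1e}" for n, r in results[lr].items()})
```

On Adam's first step each element with a nonzero gradient moves by roughly lr, so the ratios here mostly track the learning rate rather than the raw gradient size.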