Which operation detached the tensor from the computation graph?

I am implementing a classification model based on BiLSTM and RoBERTa. It takes sentences as input and computes their embeddings with RoBERTa; these embeddings are then passed to a first BiLSTM. A restricted self-attention (exactly as presented in Section 3.4 of this paper: https://arxiv.org/pdf/2010.03138.pdf) is applied on top of the first BiLSTM. The result (first BiLSTM output concatenated with the attention output) is then fed to a second BiLSTM. A 2-layer feed-forward network is then applied at each time step of the last layer of the second BiLSTM. The output layer of this FFN has size 2, and we apply softmax to get probabilities.

The issue is that when I train the model, calling optimizer.step() doesn't update the model parameters. When I call list(model.parameters())[0].grad I get None, so I presume the issue is caused by a tensor that was detached from the computation graph during the forward pass. Can anyone help me out?
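For reference, this is roughly the gradient check I run after calling loss.backward() in the training loop (the loop itself is not shown):

# rough sketch of the check run after loss.backward(); training loop not shown
for name, param in model.named_parameters():
    print(name, param.requires_grad, param.grad is None)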

Below is the model class definition with its forward method. I am not including the training code as I believe it's not the root cause of the issue. Thanks a bunch for your help.

import torch
from torch.nn import Softmax, Linear, Dropout, LSTM, ReLU
from transformers import RobertaModel  # HuggingFace transformers


class model(torch.nn.Module):
    def __init__(self, 
                 roberta_version,
                 embeds_size, 
                 lstm_h_dim, 
                 lstm_bidirectional, 
                 lstm_n_layers,
                 lstm_ffn_h_dim, 
                 lstm_ffn_out_dim, 
                 dropout,
                 lstm_rest_att_win_size=2):

        super().__init__()

        # roberta version
        self.roberta_version = roberta_version

        # RoBERTa models
        self.roberta_model = RobertaModel.from_pretrained(roberta_version)

        # embeddings size
        self.embeds_size = embeds_size

        # lstm and ffn dimensions
        self.lstm_h_dim = lstm_h_dim
        self.lstm_bidirectional = lstm_bidirectional
        self.lstm_n_layers = lstm_n_layers
        self.lstm_rest_att_win_size = lstm_rest_att_win_size
        self.lstm_ffn_h_dim = lstm_ffn_h_dim
        self.lstm_ffn_out_dim = lstm_ffn_out_dim

        # dropout value
        self.dropout = dropout
        self.dropout_layer = Dropout(self.dropout)

        # first LSTM
        self.lstm_1 = LSTM(self.embeds_size,
                            self.lstm_h_dim,
                            num_layers=self.lstm_n_layers,
                            bidirectional=self.lstm_bidirectional,
                            batch_first=True,
                            dropout=self.dropout)

        # second LSTM
        self.lstm_2 = LSTM(2*self.lstm_h_dim,
                            self.lstm_h_dim,
                            num_layers=self.lstm_n_layers,
                            bidirectional=self.lstm_bidirectional,
                            batch_first=True,
                            dropout=self.dropout)

        # Restricted attention
        # To get similarity between two sentences
        self.sim = Linear(2*lstm_h_dim, 1)
        # The rest of restricted attention module is in the forward method

        # Last classification fully-connected layers
        self.fc1 = Linear(self.lstm_h_dim, self.lstm_ffn_h_dim) 
        self.relu = ReLU()
        self.fc2 = Linear(self.lstm_ffn_h_dim, self.lstm_ffn_out_dim)
        self.softmax = Softmax(dim=2)



    def forward(self, b_sentences_embeds_feat):
        """
        forward pass

        Args:
            b_sentences_embeds_feat (torch.tensor): tensor of sentences embeds (b for batch) 
                with features (optional) of shape (batch_size, nbr_sentences, embeds_size)
        """
        # Feed to first LSTM
        lstm_out, (lstm_hidden, lstm_cell) = self.lstm_1(b_sentences_embeds_feat)
            # lstm_out of shape (batch_size, nbr_in_sentences, 2*self.lstm_h_dim)

        if self.lstm_bidirectional:
            lstm_out_forward = lstm_out[:, :, :self.lstm_h_dim]
            lstm_out_backward = lstm_out[:, :, self.lstm_h_dim:]
            lstm_out_combined = lstm_out_forward + lstm_out_backward
            # lstm_out_combined of shape (batch_size, nbr_in_sentences, self.lstm_h_dim)

        # Apply restricted attention
        batch_size = len(lstm_out_combined)
        # number of sentences per batch element, inferred from the input shape
        self.nbr_in_sentences = b_sentences_embeds_feat.shape[1]
        # compute similarities for each sentence with the neighbouring 
        # sentences in the window
        similarities = torch.zeros((batch_size, 
                                    self.nbr_in_sentences, 
                                    self.lstm_rest_att_win_size), requires_grad=True)

        assert self.lstm_rest_att_win_size % 2 == 0
        half_win_size = self.lstm_rest_att_win_size // 2

        for batch_i in range(batch_size):
            for sen_i in range(self.nbr_in_sentences):
                # Add similarity values of left side of window
                for win_i in range(half_win_size, 0, -1): # i.e half_win_size, half_win_size-1, ..., 1
                    if (sen_i - win_i) >= 0:
                        similarities[batch_i][sen_i][half_win_size - win_i] = \
                                self.sim(torch.cat((lstm_out_combined[batch_i][sen_i], 
                                                    lstm_out_combined[batch_i][sen_i - win_i]), 
                                                    dim=0))
                # Add similarity values of right side of window
                for win_i in range(1, half_win_size+1): # i.e 1, 2, ..., half_win_size
                    if (sen_i + win_i) < self.nbr_in_sentences:
                        similarities[batch_i][sen_i][half_win_size + win_i -1] = \
                                self.sim(torch.cat((lstm_out_combined[batch_i][sen_i], 
                                                    lstm_out_combined[batch_i][sen_i + win_i]), 
                                                    dim=0))
        # apply softmax to get attention weights
        att_weights = similarities
        att_weights = self.softmax(att_weights) # of shape (batch_size, 
                                                #           nbr_in_sentences, 
                                                #           self.lstm_rest_att_win_size)
        # add weight 0 for the core sentence, useful for computation later
        full_att_weights = torch.zeros((batch_size, 
                                        self.nbr_in_sentences, 
                                        1 + self.lstm_rest_att_win_size), requires_grad=True)
        for batch_i in range(batch_size):
            for sen_i in range(self.nbr_in_sentences):
                full_att_weights[batch_i][sen_i][:half_win_size] = \
                                att_weights[batch_i][sen_i][:half_win_size]
                full_att_weights[batch_i][sen_i][half_win_size+1:] = \
                                att_weights[batch_i][sen_i][half_win_size:]

        # define padded_output which is the normal output of previous LSTM with
        # pre and post padding of attention window size for easier multiplication
        # later with attention weights
        padded_output = torch.zeros((batch_size, 
                                    self.nbr_in_sentences + (2 * self.lstm_rest_att_win_size), 
                                    self.lstm_h_dim), requires_grad=True)

        for batch_i in range(batch_size):
            for sen_i in range(self.nbr_in_sentences):
                padded_output[batch_i][self.lstm_rest_att_win_size + sen_i] = \
                                                    lstm_out_combined[batch_i][sen_i]

        # get local context by multiplying the attention weights by
        # the associated sentences hidden states (i.e., embeddings)
        local_contexts = torch.zeros(lstm_out_combined.shape, requires_grad=True)
        # local contexts of shape (batch_size, nbr_in_sentences, self.lstm_h_dim)

        for batch_i in range(batch_size):
            for sen_i in range(self.nbr_in_sentences):
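                # Add weighted hidden states from the left side of the window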
                for win_i in range(1, half_win_size+1):
                    local_contexts[batch_i][sen_i] = \
                            torch.add(local_contexts[batch_i][sen_i],
                                      torch.mul(padded_output[batch_i][sen_i - win_i], 
                                                full_att_weights[batch_i][sen_i][half_win_size - win_i]))
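                # Add weighted hidden states from the right side of the window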
                for win_i in range(1, half_win_size+1):
                    local_contexts[batch_i][sen_i] = \
                            torch.add(local_contexts[batch_i][sen_i],
                                      torch.mul(padded_output[batch_i][sen_i + win_i], 
                                                full_att_weights[batch_i][sen_i][half_win_size + win_i]))

        # Concatenate LSTM output with local contexts for each sentence
        output_with_rest_att = torch.cat((lstm_out_combined, local_contexts), axis=2)
        # output_with_rest_att of shape (batch_size, nbr_in_sentences, 2*self.lstm_h_dim)

        # Feed to second LSTM
        lstm_final_out, (lstm_hidden, lstm_cell) = self.lstm_2(output_with_rest_att)

        # Feed to hidden layer
        hidden_ffn = self.fc1(lstm_out_combined) 
            # of shape (batch_size, nbr_in_sentences, self.lstm_ffn_h_dim)
        # Apply ReLU
        hidden_relu = self.relu(hidden_ffn)
        hidden_drop = self.dropout_layer(hidden_relu)
        # Feed to last classification layer
        logits = self.fc2(hidden_drop) # shape (batch_size, nbr_in_sentences, 2)
        # Apply softmax to get probabilities
        probs = self.softmax(logits) # shape (batch_size, nbr_in_sentences, 2)

        # return probabilities and logits
        return probs, logits

@ptrblck @tom @fmassa @albanD @smth Can anyone please help? I've been struggling with this for days!

It’s hard to tell what’s going on without a self-contained example.
Does the loss at the end require gradients (i.e., is loss.requires_grad True)?

If not, I would recommend going through the steps in the forward pass and printing whether each intermediate tensor requires grad.
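For example, something along these lines inside your forward (just a sketch, reusing the intermediate names from your code):

# sketch: print requires_grad / grad_fn for intermediates inside forward();
# an intermediate whose grad_fn is None is not connected to the ops upstream of it
print("lstm_out", lstm_out.requires_grad, lstm_out.grad_fn)
print("similarities", similarities.requires_grad, similarities.grad_fn)
print("local_contexts", local_contexts.requires_grad, local_contexts.grad_fn)
print("logits", logits.requires_grad, logits.grad_fn)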

If it does, you can add backward hooks that print:

def add_detect_backward(name, tensor):
    # register a hook that prints when the gradient for `tensor` is computed
    def bwh(grad):
        print("Computing backward of", name)
    tensor.register_hook(bwh)

a = torch.randn(5, 5, requires_grad=True)
add_detect_backward("a", a)
b = 5 * a
add_detect_backward("b", b)
c = b.sum()
c.backward()

This should give you a good picture of how far the gradient flows and where you accidentally drop it.
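In your model, you could register these hooks on the intermediates inside forward, before computing the loss, for example (a sketch using the names from your code; note that register_hook raises an error on a tensor that does not require grad, which is itself a useful signal):

# sketch: register the debug hooks on intermediates inside forward()
add_detect_backward("lstm_out_combined", lstm_out_combined)
add_detect_backward("att_weights", att_weights)
add_detect_backward("local_contexts", local_contexts)
add_detect_backward("logits", logits)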