Gradients of the first model are None when updating the parameters of two models with one optimizer

I’m trying to train two sequential neural networks with a single optimizer. I read that this can be done by defining the optimizer as follows:

optimizer_domain = torch.optim.SGD(list(sentences_model.parameters()) + list(words_model.parameters()),lr=0.001)
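For example, in a minimal sketch with two hypothetical linear layers standing in for my two models (the names and shapes here are just placeholders), the joint optimizer updates both as long as the output of the first model flows directly into the second:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the two models, just to show the joint optimizer.
words_model = nn.Linear(10, 5)
sentences_model = nn.Linear(5, 2)

optimizer_domain = torch.optim.SGD(
    list(sentences_model.parameters()) + list(words_model.parameters()), lr=0.001
)

x = torch.randn(4, 10)
target = torch.randn(4, 2)

# The output of the first model flows directly into the second,
# so backward() populates .grad for the parameters of both models.
out = sentences_model(words_model(x))
loss = nn.MSELoss()(out, target)

optimizer_domain.zero_grad()
loss.backward()
print(words_model.weight.grad is None)      # False
print(sentences_model.weight.grad is None)  # False
optimizer_domain.step()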

In my tests this indeed works for simple models. However, my setup is a bit more complex; a schematic overview is shown below:

The (simplified) training code is shown below:

criterion = nn.MSELoss()
optimizer_domain = torch.optim.SGD(list(words_model.parameters()) + list(sentences_model.parameters()),lr=0.001)
epochs = 2
costval_shared_d = []
data = Data(df_xy, yi_train)
train_loader=DataLoader(data,batch_size=batch_size_Tanh,shuffle=True)

for j in range(epochs):
    for x,y in train_loader:
        words_model_output = words_model(x, y)

        if all_words_of_sentence_have_passed:  # pseudocode: once every word of the sentence has been processed
            adjusted_sentences = some_functions(words_model_output)
            train_loader_sentences = DataLoader(adjusted_sentences, batch_size=batch_size_sentences, shuffle=True)
            for x_sentence, y_sentence in train_loader_sentences:
                output_sentences = sentences_model(x_sentence)

                # calculate the loss between the prediction and the target
                cost_shared_d = criterion(output_sentences, y_sentence)

                print(words_model.param1.grad, sentences_model.param1.grad)  # Here the gradients of all parameters of words_model are None

                optimizer_domain.zero_grad()

                cost_shared_d.backward()

                print(words_model.param1.grad, sentences_model.param1.grad)  # Here too, grad is None for words_model

                optimizer_domain.step()

                costval_shared_d.append(cost_shared_d.item())

As the code above shows, I first load a dataset of (numerical representations of) all words and pass them through words_model, which decides whether each word should be masked in the sentence. Once it has been decided which words of the sentence are masked, the new sentence with masked words is passed through some functions to obtain a representation of the sentence. That representation is then put into a train_loader that feeds the sentence to sentences_model, the second neural network, which tries to classify the domain of the sentence.

Then the loss is calculated and the optimizer takes a step. However, I see that only the parameters of sentences_model are updated, not the parameters of words_model, whose gradients are None.

I’m not sure how I can also train the first neural network, words_model, at the same time. Since the input of the second network is the output of the first, and the combined network is trained on a classification task that only uses the output of the second network, it is important that both networks are updated simultaneously.

I suspect that I break the gradient chain (the computation graph) somewhere when I apply the functions to the outputs of the first model, but I can’t find the underlying cause.
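For instance, a generic example (not my actual code, just a toy stand-in) of such a break would be re-creating a tensor from plain Python data:

import torch
import torch.nn as nn

words_model = nn.Linear(8, 2)                 # hypothetical stand-in
out = words_model(torch.randn(1, 8))

kept = out * 2.0                              # differentiable op: graph preserved
broken = torch.tensor(out.tolist())           # rebuilt from Python data: graph broken

print(kept.grad_fn)    # <MulBackward0 ...>
print(broken.grad_fn)  # None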

Any thoughts or help would be greatly appreciated.

This might be the case, and I would start with the DataLoader, as wrapping model output activations in a new DataLoader is quite an unusual approach.
Check whether the x_sentence tensor is still attached to the computation graph by inspecting its .grad_fn attribute.
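A quick check you could drop into your inner loop (using the variable names from your snippet) would be something like:

for x_sentence, y_sentence in train_loader_sentences:
    # A tensor that is still part of the computation graph has a grad_fn;
    # if this prints None, the connection to words_model is already broken here.
    print(x_sentence.grad_fn, x_sentence.requires_grad)
    break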

Thanks a lot for your quick reply. Good point, that approach was indeed a bit cumbersome.

For experimentation, I’ve now put only the words of a single sentence/review into the train_loader:

data = Data(x_input.astype(np.float32), y_input)
train_loader = DataLoader(data, batch_size=batch_size, shuffle=False)

costval = []
mask_list = []

word_count = 0
epochs = 1
adjusted_reviews = []

domains = [torch.tensor([0, 1], dtype=torch.float32)]
optimizer_domain = torch.optim.SGD(list(model_sentences.parameters()) + list(model_words.parameters()), lr=0.001)

reviews = ['Hi, this is a test sentence']
costval_shared_d = []

for j in range(epochs):
    for review in reviews:
        new_review = review.split()  # list of words that may get masked

        for x, y in train_loader:
            y_pred = model_words(x.float(), y[0])

            # mask the word if the first NN predicts [1, 0]
            if torch.all(y_pred.eq(torch.tensor([1, 0]))):
                new_review[word_count] = '[MASK]'
            word_count = word_count + 1

        adjusted_review = " ".join(new_review)

        inputs = tokenizer(adjusted_review, return_tensors="pt")
        outputs = bertmodel(**inputs)

        h_cls = outputs.last_hidden_state[0][0]
        y_domain_class = model_sentences(h_cls)

        optimizer_domain.zero_grad()

        for domain in domains:
            cost_shared_domain_class = criterion(y_domain_class, domain)
            cost_shared_domain_class.backward()
            optimizer_domain.step()

        print('Grad 1st NN (words): ', model_words.l1.weight.grad)
        print('Grad 2nd NN (sentences): ', model_sentences.l1.weight.grad)

        costval_shared_d.append(cost_shared_domain_class.item())

When I change the following piece of code:

y_pred = model_words(x.float(), y[0])
if torch.all(y_pred.eq(torch.tensor([1, 0]))):
    new_review[word_count] = '[MASK]'
word_count = word_count + 1

adjusted_review = " ".join(new_review)
inputs = tokenizer(adjusted_review, return_tensors="pt")
outputs = bertmodel(**inputs)

h_cls = outputs.last_hidden_state[0][0]
y_domain_class = model_sentences(h_cls)

to this (after also adjusting the dimensions of the NN):

y_pred = model_words(x.float(),y[0])
h_cls = y_pred
y_domain_class = model_sentences(h_cls)

Then the gradients of both NNs are updated. However, I need those transformations in between, otherwise the input of my second NN is meaningless.

Any ideas?

It seems the main difference would be:

h_cls = outputs.last_hidden_state[0][0]
# vs.
h_cls = y_pred

Did you check if outputs.last_hidden_state[0][0] is still attached to the computation graph?

Can I check that like this?

Yes, checking the .grad_fn is one way to see whether the tensor was created by a differentiable operation.
h_cls seems to be attached to a computation graph.
You would need to check the bertmodel usage next and see if its inputs are somehow attached to y_pred or model_words, but I don’t see how that would be possible, as it seems you are using y_pred only in the if condition in the first approach?

Yes, but it seems the computation-graph attachment comes from the BertModel, as the following code shows the attachment:

for x, y in train_loader:
    y_pred = model_words(x.float(), y[0])

    tokens_tensor[0][word_count] = torch.where(torch.all(y_pred[0].eq(torch.tensor([1, 0]))), tokens_tensor[0][word_count], torch.tensor(103))
    word_count = word_count + 1

outputs = bertmodel(tokens_tensor, segments_tensors)
h_cls = outputs.last_hidden_state[0][0]
print(h_cls.grad_fn)  # Output: <SelectBackward object at 0x0000022E9700B5C0>

y_domain_class = model_sentences(h_cls.float())

But without the BertModel, in this code snippet, h_cls is not attached:

for x, y in train_loader:
    y_pred = model_words(x.float(), y[0])

    tokens_tensor[0][word_count] = torch.where(torch.all(y_pred[0].ne(torch.tensor([1, 0]))), tokens_tensor[0][word_count], torch.tensor(103))
    word_count = word_count + 1

h_cls = tokens_tensor
print(h_cls.grad_fn)  # Output: None
y_domain_class = model_sentences(h_cls.float())

Btw, I replaced the if statement with a torch.where condition, removed the string operations, and now only work with tensor operations.

As you say, the only ‘information’ I get from y_pred is the condition: a word should be masked if y_pred equals [1, 0]. Is there any way to work around this and pass more information on to the second NN? I already tried putting all operations between the two NNs inside the first NN, but it didn’t make any difference.

To pass the gradients on to the second NN, can I only use y_pred itself and apply operations to it that don’t break the graph? If so, I’m wondering how I could achieve that.

I don’t think that’s the case, since you are not passing y_pred to bertmodel.
You are only using it as the condition in torch.where, and tokens_tensor is not attached to model_words.
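A small self-contained sketch of that situation (the module and token values are just stand-ins for your setup):

import torch
import torch.nn as nn

model_words = nn.Linear(4, 2)                            # stand-in for the first NN
y_pred = model_words(torch.randn(1, 4))

tokens_tensor = torch.tensor([[101, 2023, 2003, 102]])   # plain token ids, no grad history

# The condition is a BoolTensor without gradient information and both branches
# are plain integer tensors, so the result is not connected to model_words.
masked = torch.where(torch.all(y_pred[0].eq(torch.tensor([1.0, 0.0]))),
                     tokens_tensor[0][1], torch.tensor(103))

print(y_pred.grad_fn)   # <AddmmBackward0 ...> -> attached
print(masked.grad_fn)   # None -> not attached to model_words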

So if I understand correctly, the only way to pass the gradients is by taking the output of the first NN and performing only differentiable modifications on it, so that the gradient flow doesn’t break?

There is no workaround for when only the condition comes from the first NN?

Yes, this is correct.

Yes, this is also correct. The comparison will create a BoolTensor, which is detached from the computation graph.
However, even if the result were a FloatTensor with an attached backward function, the gradients would be zero everywhere except at the exact point where the condition flips (where they would be +/-Inf).
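For illustration:

import torch

y_pred = torch.randn(2, requires_grad=True)

cond = y_pred.eq(torch.tensor([1.0, 0.0]))   # comparison -> BoolTensor
print(cond.dtype)            # torch.bool
print(cond.grad_fn)          # None -> detached from the graph

# Casting back to float does not help: the hard comparison acts as a step
# function, so there is no useful gradient w.r.t. y_pred anyway.
print(cond.float().grad_fn)  # None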
