Calculate first model's gradients based on another model's loss

Hello,
I am new to PyTorch and I want to implement an algorithm of the following form:

# input1 ...
output1 = model1(input1)

output = model1(some_input)  # I will call this output "input2" from now on

output2 = model2(input2)
loss2 = fn(output2)
loss2.backward()

If I am not wrong, loss2.backward() should also compute the gradients w.r.t. the parameters of model1, since the output from model1 is the input to model2. But when I check the parameters using the code,

for params in model1.parameters():
   print(params.grad.norm())

the value of the norm is constant even after the backward pass, indicating no gradients were calculated for model1's parameters. I wonder why this is happening. I am using the Hugging Face transformers library for the models.
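To double-check my understanding of the mechanism, I first tried this toy version (with two made-up nn.Linear models instead of the actual ones), and there loss2.backward() does fill model1's gradients:

import torch
import torch.nn as nn

# toy stand-ins for model1 and model2
model1 = nn.Linear(4, 4)
model2 = nn.Linear(4, 1)

input1 = torch.randn(2, 4)
input2 = model1(input1)        # differentiable intermediate output
loss2 = model2(input2).sum()   # plays the role of fn(output2)
loss2.backward()

for params in model1.parameters():
    print(params.grad.norm())  # non-zero norms: gradients reached model1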
And here is my reproducible code with the actual models:

import torch
import torch.nn.functional as F
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel

model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')
model2 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

# fake inputs
en_input = torch.tensor([[1, 2], [3, 4]])
de_output = torch.tensor([[3], [4]])
lm_labels = torch.tensor([[5], [6]])

output1 = model1(input_ids=en_input, decoder_input_ids=de_output, labels=lm_labels)

output = model1.generate(input_ids=en_input, decoder_start_token_id=0)  # this is "input2"

# random fake labels for model2, matching the generated sequence length
new_labels = torch.rand(2, output.shape[1]).long()

output2 = model2(input_ids=en_input, decoder_input_ids=output, labels=new_labels)

prediction_scores = output2[1]
predictions = F.log_softmax(prediction_scores, dim=2)
loss2 = (predictions.sum() - output.sum()).sum()  # some loss
loss2.backward(retain_graph=True)

# now I check model1's parameters as above

You are right in the assumption that passing the output of one model to the other keeps the computation graph intact, so model1 should get valid gradients as long as the graph is not detached anywhere in between.
I don't know what model1.generate does, so could you check if output has a valid .grad_fn?
If that's not the case, it would mean that output is detached and Autograd won't backpropagate through it. On the other hand, if the grad_fn is set, could you check the .grad attributes of model1's parameters before and after the first backward() call? They should be None before and contain valid values afterwards.
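Something like this would cover both checks (assuming the variable names from your snippet):

print(output.grad_fn)  # None would mean the graph is cut at this point

for name, param in model1.named_parameters():
    print(name, param.grad is None)  # True before the first backward(), tensors with values afterwards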


Thanks a lot for replying! It turns out model1.generate() produces integer tensors (token ids), which have no grad_fn associated with them, so I wasn't able to backpropagate to model1.
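For anyone who runs into the same issue, this is how I confirmed it with the code above:

output = model1.generate(input_ids=en_input, decoder_start_token_id=0)
print(output.dtype)    # torch.int64 -> discrete token ids
print(output.grad_fn)  # None -> the computation graph ends here

Since Autograd cannot differentiate through discrete token ids, gradients can only flow back to model1 through a continuous output, such as the prediction scores from a regular forward pass.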