Find derivative of model's parameters with respect to a vector

Thanks a lot for your suggestion @soulitzer! I tried it using the manual optimizer implementation from Updatation of Parameters without using optimizer.step() - #4 by albanD, but I get the following error:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

It seems that A is not part of the computation graph, so torch.autograd.grad is unable to compute its gradient. Is this happening because I wrap my update step in ‘with torch.no_grad()’?
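
To check my understanding, here is a minimal sketch (with small placeholder tensors instead of my actual models) of the pattern I suspect is the problem: once a parameter is overwritten in-place under torch.no_grad(), a loss built from it afterwards no longer depends on A, and torch.autograd.grad raises the same error unless allow_unused=True is passed.

import torch

A = torch.rand(2, requires_grad=True)
w = torch.nn.Parameter(torch.ones(2))

# first loss depends on both w and A
loss1 = (w * A).sum()
loss1.backward(inputs=[w], retain_graph=True, create_graph=True)

# manual update under no_grad, as in my training loop below
with torch.no_grad():
  w.copy_(w - 0.001 * w.grad)

# second loss is built only from the updated parameter, so A is not in its graph
loss3 = w.sum()
# without allow_unused=True this raises:
# "One of the differentiated Tensors appears to not have been used in the graph"
print(torch.autograd.grad(loss3, A, allow_unused=True))  # prints (None,)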
And here is the full piece of code that reproduces the error:

# reproducible code to trigger the error
import torch
import torch.nn.functional as F
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
model2 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints

A = torch.rand(2, requires_grad=True)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.0001)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.0001)
optimizer3 = torch.optim.SGD([A], lr=0.001)

en_input = torch.tensor([[1,2], [3,4]])
en_masks = torch.tensor([[0,0], [0,0]])
de_output = torch.tensor([[3,1], [4,2]])
de_masks = torch.tensor([[0,0], [0,0]])
lm_labels = torch.tensor([[5,7], [6,8]])

torch.autograd.set_detect_anomaly(True)

def update_function(param, grad, loss, learning_rate):
  return param - learning_rate * grad

def train1():
  for i in range(2):
    #optimizer1.zero_grad()
    out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                        decoder_attention_mask=de_masks, labels=lm_labels.clone())
        
    prediction_scores = out[1]
    predictions = F.log_softmax(prediction_scores, dim=2)
    loss1 = ((predictions.sum() - de_output.sum()) * A).sum()

    loss1.backward(inputs=list(model1.parameters()), retain_graph=True, create_graph=True)
    #optimizer1.step()
    # manually update weights without optimizer.step()
    with torch.no_grad():
      for p in model1.parameters():
        if p.grad is not None:
          new_val = update_function(p, p.grad, loss1, 0.001)
          p.copy_(new_val)

def train2():
  for i in range(2):
    #optimizer2.zero_grad()
    outputs = model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
    predictions = F.log_softmax(outputs.logits, dim=2)
    values, new_labels = torch.max(predictions, 2)

    output = outputs.decoder_hidden_states[-1]
    out = model2(input_ids=en_input, decoder_inputs_embeds=output, labels=new_labels)
    prediction_scores = out[1]
    predictions = F.log_softmax(prediction_scores, dim=2)
    loss2 = (predictions.sum() - new_labels.sum()).sum()

    loss2.backward(retain_graph=True, create_graph=True)
    #optimizer2.step()
    with torch.no_grad():
      for p in model2.parameters():
        if p.grad is not None:
          new_val = update_function(p, p.grad, loss2, 0.001)
          p.copy_(new_val)
       
def train3():
  optimizer3.zero_grad()
  output = model2(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                      decoder_attention_mask=de_masks, labels=lm_labels.clone())
        
  prediction_scores_ = output[1]
  predictions_ = F.log_softmax(prediction_scores_, dim=2)
  loss3 = (predictions_.sum() - de_output.sum()).sum()
  A.retain_grad()
  A.grad = torch.autograd.grad(loss3, A)  # --> error raised here
  optimizer3.step()  # weight update

train1()
train2()
train3()