Thanks a lot for your suggestion, @soulitzer! I tried the optimizer implementation from Updatation of Parameters without using optimizer.step() - #4 by albanD as you suggested, but I get the following error:
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
It seems that A is not part of the computation graph, so torch.autograd.grad cannot compute its gradient. Is this happening because I wrap my update step in 'with torch.no_grad()'?
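To make sure I understand the error message itself, I tried this tiny standalone sketch (the tensors a and x below are just placeholders, not from my real code):

import torch

a = torch.rand(2, requires_grad=True)   # the tensor I want gradients for
x = torch.rand(2, requires_grad=True)

loss = x.sum()                           # 'a' never participates in this graph
# torch.autograd.grad(loss, a)           # raises the same RuntimeError as above
grads = torch.autograd.grad(loss, a, allow_unused=True)
print(grads)                             # prints (None,) because 'a' was unused

So if A really is unreachable from my loss, allow_unused=True would only hand me back None instead of an actual gradient, which is not what I want.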
Here is a runnable piece of code that reproduces the error:
# minimal reproducible example
import torch
import torch.nn.functional as F
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
model2 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
A=torch.rand(2, requires_grad=True)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.0001)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.0001)
optimizer3 = torch.optim.SGD([A], lr=0.001)
en_input=torch.tensor([[1,2], [3,4]])
en_masks=torch.tensor([[0,0], [0,0]])
de_output=torch.tensor([[3,1], [4,2]])
de_masks=torch.tensor([[0,0], [0,0]])
lm_labels=torch.tensor([[5,7], [6,8]])
torch.autograd.set_detect_anomaly(True)
def update_function(param, grad, loss, learning_rate):
    # plain SGD-style update: param <- param - lr * grad
    return param - learning_rate * grad
def train1():
    for i in range(2):
        # optimizer1.zero_grad()
        out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
                     decoder_attention_mask=de_masks, labels=lm_labels.clone())
        prediction_scores = out[1]
        predictions = F.log_softmax(prediction_scores, dim=2)
        loss1 = ((predictions.sum() - de_output.sum()) * A).sum()
        # keep the graph so the parameter update stays differentiable w.r.t. A
        loss1.backward(inputs=list(model1.parameters()), retain_graph=True, create_graph=True)
        # optimizer1.step()
        # updating model1's weights manually instead of optimizer1.step()
        with torch.no_grad():
            for p in model1.parameters():
                if p.grad is not None:
                    new_val = update_function(p, p.grad, loss1, 0.001)
                    p.copy_(new_val)
def train2():
    for i in range(2):
        # optimizer2.zero_grad()
        # generate pseudo-labels and decoder hidden states from model1
        outputs = model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
        predictions = F.log_softmax(outputs.logits, dim=2)
        values, new_labels = torch.max(predictions, 2)
        output = outputs.decoder_hidden_states[-1]
        out = model2(input_ids=en_input, decoder_inputs_embeds=output, labels=new_labels)
        prediction_scores = out[1]
        predictions = F.log_softmax(prediction_scores, dim=2)
        loss2 = (predictions.sum() - new_labels.sum()).sum()
        loss2.backward(retain_graph=True, create_graph=True)
        # optimizer2.step()
        # updating model2's weights manually instead of optimizer2.step()
        with torch.no_grad():
            for p in model2.parameters():
                if p.grad is not None:
                    new_val = update_function(p, p.grad, loss2, 0.001)
                    p.copy_(new_val)
def train3():
    optimizer3.zero_grad()
    output = model2(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
                    decoder_attention_mask=de_masks, labels=lm_labels.clone())
    prediction_scores_ = output[1]
    predictions_ = F.log_softmax(prediction_scores_, dim=2)
    loss3 = (predictions_.sum() - de_output.sum()).sum()
    A.retain_grad()
    A.grad = torch.autograd.grad(loss3, A)  # --> RuntimeError raised here
    optimizer3.step()  # weight update for A
train1()
train2()
train3()
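For what it's worth, my plan for a sanity check is a probe like the following inside train3, right after loss3 is computed (just a sketch, reusing loss3 and A from above):

    g = torch.autograd.grad(loss3, A, allow_unused=True, retain_graph=True)
    print(g)  # I would expect (None,) here if A really is disconnected from loss3's graph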