Hello,
I am trying to find a double derivative using the torch.autograd.grad fucntion. It requires a step where I have to find the double derivative of the model’s parameters wrt to a vector (in this case A). Can someone please guide me how to do so?
# reproduce error
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
model2 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
A=torch.rand(1, requires_grad=True)
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.0001)
en_input=torch.tensor([[1,2], [3,4]])
en_masks=torch.tensor([[0,0], [0,0]])
de_output=torch.tensor([[3,1], [4,2]])
de_masks=torch.tensor([[0,0], [0,0]])
lm_labels=torch.tensor([[5,7], [6,8]])
torch.autograd.set_detect_anomaly(True)
def train1():
acc=torch.zeros(1)
for i in range(2):
optimizer1.zero_grad()
out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
decoder_attention_mask=de_masks, labels=lm_labels.clone())
prediction_scores = out[1]
predictions = F.log_softmax(prediction_scores, dim=2)
p=((predictions.sum() - de_output.sum())*A).sum()
p=torch.unsqueeze(p, dim=0)
acc = torch.cat((p,acc))
loss=acc.sum()
loss.backward(inputs=list(model1.parameters()), retain_graph=True, create_graph=True) #calculates gradients
# I want to do something like this. First find derivative wrt to model's weights and then wrt A, so essentially a double derivative.
delL_delWo=(torch.autograd.grad(loss, model1.parameters(), create_graph=True, allow_unused=True)) # model1's weights -> Wo
del2_Loss_delWo_delA=torch.autograd.grad(delL_delWo, A, allow_unused=True) #calculating gradients wrt A i.e del^2 Loss/delWo delA
optimizer1.step() # wt updation
return del2_Loss_delWo_delA
train1()
SInce grad can be implicitly created for scalar inputs, I am confused as to how to find the second derivative i.e del2_Loss_delWo_delA, as delL_delWo will consist of a tuple consisting of 2-D tensors (similar to model1’s parameters.)