RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [800, 306]]

Hello,

I’m implementing Bahdanau attention for a regression task and I’m running into this issue.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [800, 306]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Apparently I’m doing an inplace operation on some tensor that requires gradients, as I read in some related posts. I have already added plenty of clone() calls where I thought I could be modifying a tensor, but it didn’t work. Can someone help me find the mistake? Thanks,
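From what I read, autograd detects this via an internal per-tensor version counter that every inplace op bumps; that’s what the “is at version 2; expected version 1” part refers to. A minimal sketch (using Tensor._version, which I understand to be an internal, undocumented attribute):

import torch

t = torch.zeros(3)
print(t._version)  # 0
t.add_(1)          # any inplace op bumps the internal version counter
print(t._version)  # 1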

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)

    def forward(self, x, hs):
        out, hs = self.lstm(x, hs)
        return out, hs

    def init_hidden(self):
        # (h_0, c_0), each of shape (num_layers, batch_size=1, hidden_size)
        return (torch.zeros(self.num_layers, 1, self.hidden_size),
                torch.zeros(self.num_layers, 1, self.hidden_size))

class BahdanauDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, drop_prob=0.1):
        super(BahdanauDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop_prob = drop_prob
        self.input_size = input_size  # number of machines

        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.weight = nn.Parameter(torch.FloatTensor(1, hidden_size))
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.drop_prob)
        self.lstm = nn.LSTM(self.hidden_size + self.input_size, self.hidden_size, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        # inputs are actually the previous decoder outputs

        encoder_outputs = encoder_outputs.clone().squeeze()

        embedded = inputs[0].clone()  # shape (num_machines, 1)

        embedded = self.dropout(embedded)

        # Calculating alignment scores
        x = torch.tanh(self.fc_hidden(hidden[0].clone()) + self.fc_encoder(encoder_outputs))
        a = self.weight.unsqueeze(2).detach()
        alignment_scores = x.clone().bmm(a)

        # Softmaxing alignment scores to get attention weights
        attn_weights = F.softmax(alignment_scores.clone().view(1, -1), dim=1)

        # Multiplying the attention weights with encoder outputs to get the context vector
        b = attn_weights.clone().unsqueeze(0).detach()
        c = encoder_outputs.clone().unsqueeze(0).detach()
        context_vector = torch.bmm(b, c)

        output = torch.cat((embedded.clone().view(1, -1), context_vector[0].clone()), 1).unsqueeze(0)

        # Passing the concatenated vector as input to the LSTM cell
        output, hidden = self.lstm(output, hidden)

        output = self.linear(output[0].clone())

        return output, hidden, attn_weights

    def init_in(self):
        return torch.zeros(1, self.input_size, 1)

    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size),
                torch.zeros(1, 1, self.hidden_size))

And the training code:

def train(model_encoder, model_decoder, epochs,
          train_set, train_window, device,
          valid_data=None, lr=0.001, batch_size=1,
          print_every=10, loss_fct='MSE'):

    if device == 'cuda':
        model_encoder.cuda()
        model_decoder.cuda()

    if loss_fct == 'MSE':
        criterion = nn.MSELoss()

    torch.autograd.set_detect_anomaly(True)
    opt_encod = optim.Adam(model_encoder.parameters(), lr=lr)
    opt_decod = optim.Adam(model_decoder.parameters(), lr=lr)

    train_loss = []
    valid_loss = []

    for e in range(epochs):

        # initial values
        hs_encoder = model_encoder.init_hidden()
        in_decoder = model_decoder.init_in()

        if device == 'cuda':
            hs_encoder = tuple(i.to(device) for i in hs_encoder)
            in_decoder = in_decoder.to(device)

        t_loss = []
        for x, y in get_batches_dataloader(train_set, train_window, batch_size=batch_size):

            if device == 'cuda':
                x = x.to(device)
                y = y.to(device)

            opt_encod.zero_grad()
            opt_decod.zero_grad()

            # Create batch_size dimension if it doesn't exist
            if len(x.shape) == 2:
                x = x.unsqueeze(0)

            out_encoder, hs_encoder = model_encoder(x, hs_encoder)
            hs_decoder = hs_encoder
            out_decoder, hs_decoder, aw = model_decoder(in_decoder, hs_decoder, out_encoder)
            out_decoder = out_decoder.unsqueeze(0)
            in_decoder = out_decoder.detach()

            loss = criterion(out_decoder.view(1, 1, -1), y.unsqueeze(0))

            loss.backward(retain_graph=True)
            opt_encod.step()
            opt_decod.step()
            t_loss.append(loss.item())


If you have an operation like a += b, replace it with a = a + b.
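For example, a minimal standalone repro of this class of error (toy tensors, nothing to do with your model; pow is just one op that saves its input for backward):

import torch

a = torch.ones(3, requires_grad=True)
b = a * 2
c = (b ** 2).sum()   # the backward of b ** 2 needs the original b
b += 1               # inplace: b is now at version 1, expected version 0
c.backward()         # RuntimeError: ... modified by an inplace operation

Writing b = b + 1 instead creates a new tensor and leaves the saved b untouched, so backward() succeeds.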

Besides the inplace-op check: could you explain why (and whether) retain_graph=True is needed? It often points towards wrong usage, e.g. when it is added just to avoid another error.

You’re right, I had added it following the suggestion in the error message I was getting before:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

I could fix the inplace operation error described above by detaching the input of the last linear layer in the decoder. However, I’m not sure about the meaning of dropping this tensor from the gradient computation graph. Could you explain whether it makes sense not to propagate the gradient through this tensor?

class BahdanauDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, drop_prob=0.1):
        super(BahdanauDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop_prob = drop_prob
        self.input_size = input_size  # number of machines

        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.weight = nn.Parameter(torch.FloatTensor(1, hidden_size))
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.drop_prob)
        self.lstm = nn.LSTM(self.hidden_size + self.input_size, self.hidden_size, batch_first=True)
        self.classifier = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        # inputs are actually the previous decoder outputs
        encoder_outputs = encoder_outputs.squeeze()

        embedded = inputs[0].clone()  # shape (num_machines, 1)

        embedded = self.dropout(embedded)

        # Calculating alignment scores
        x = torch.tanh(self.fc_hidden(hidden[0]) + self.fc_encoder(encoder_outputs))
        alignment_scores = x.bmm(self.weight.unsqueeze(2))

        # Softmaxing alignment scores to get attention weights
        attn_weights = F.softmax(alignment_scores.view(1, -1), dim=1)

        # Multiplying the attention weights with encoder outputs to get the context vector
        context_vector = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

        # Concatenating the context vector with the embedded input
        output = torch.cat((embedded.view(1, -1), context_vector[0]), 1).unsqueeze(0)

        # Passing the concatenated vector as input to the LSTM cell
        output, hidden = self.lstm(output, hidden)

        output = self.classifier(output[0].clone().detach())  # CHANGED HERE

        return output, hidden, attn_weights

    def init_in(self):
        return torch.zeros(1, self.input_size, 1)

    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size),
                torch.zeros(1, 1, self.hidden_size))

The above works, even after removing retain_graph=True from the backward() call in the training loop, but I’m not sure why.

Thanks for the update.
No, I don’t think you should detach() this tensor, since the computation graph will be cut at this point and none of the modules used before the detach() call will get valid gradients (self.classifier would be the only module being trained).
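A minimal illustration with toy tensors (not your model) of how detach() cuts the gradient flow:

import torch

w = torch.randn(3, requires_grad=True)
v = torch.randn(3, requires_grad=True)

out = (w * 2).detach() + v   # the graph is cut right after (w * 2)
out.sum().backward()
print(w.grad)  # None -> nothing before the detach() receives a gradient
print(v.grad)  # tensor([1., 1., 1.])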
The initial error (without using retain_graph=True) is raised e.g. if you try to calculate the gradients via a second backward() call while the intermediate forward activations were already freed by the first backward() call.
This is often an issue in RNNs, e.g. if you are reusing the hidden states without detaching them (the computation graph would grow in each iteration, and backward() would then try to calculate the gradients for the current as well as all previous iterations).
Given that, I think you should check whether output, hidden = self.lstm(output, hidden.detach()) works (and pass output directly to self.classifier again).

Thank you for your answer.
Just to confirm: you’re saying that I should detach the hidden state at the input of the LSTM layer in the decoder class, right?
And what is the proper way of detaching the hidden state if it is a tuple? I’ve tried doing it as you suggested, but I’m getting the following:

AttributeError: 'tuple' object has no attribute 'detach'

Thank you,

Yes, assuming you want to use the previous “values” of the hidden state, but do not want to extend the computation graph.

You can recreate the tuple via:

hidden = tuple(h.detach() for h in hidden)
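For example, a minimal sketch of the pattern with hypothetical toy modules and shapes (not your exact code):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(list(lstm.parameters()) + list(head.parameters()))

hidden = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))
for step in range(3):
    x = torch.randn(1, 5, 4)                    # (batch, seq_len, features)
    y = torch.randn(1, 5, 1)
    hidden = tuple(h.detach() for h in hidden)  # cut the graph at the batch boundary
    optimizer.zero_grad()
    out, hidden = lstm(x, hidden)
    loss = criterion(head(out), y)
    loss.backward()                             # no retain_graph=True needed
    optimizer.step()

Removing the detach line reproduces the “Trying to backward through the graph a second time” error from above, since loss would still depend on the previous iteration’s (already freed) graph through hidden.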

Thank you for the answer.

I’ve included this line just before the LSTM layer, so the code becomes the following:


class BahdanauDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1, drop_prob=0.1):
        super(BahdanauDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop_prob = drop_prob
        self.input_size = input_size  # number of machines

        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.weight = nn.Parameter(torch.FloatTensor(1, hidden_size))
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.drop_prob)
        self.lstm = nn.LSTM(self.hidden_size + self.input_size, self.hidden_size, batch_first=True)
        self.classifier = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        # inputs are actually the previous decoder outputs
        encoder_outputs = encoder_outputs.squeeze()

        embedded = inputs[0].clone()  # shape (num_machines, 1)

        embedded = self.dropout(embedded)

        # Calculating alignment scores
        x = torch.tanh(self.fc_hidden(hidden[0]) + self.fc_encoder(encoder_outputs))
        alignment_scores = x.bmm(self.weight.unsqueeze(2))

        # Softmaxing alignment scores to get attention weights
        attn_weights = F.softmax(alignment_scores.view(1, -1), dim=1)

        # Multiplying the attention weights with encoder outputs to get the context vector
        context_vector = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

        # Concatenating the context vector with the embedded input
        output = torch.cat((embedded.view(1, -1), context_vector[0]), 1).unsqueeze(0)

        # Passing the concatenated vector as input to the LSTM cell
        hidden = tuple(h.detach() for h in hidden)  # HERE
        output, hidden = self.lstm(output, hidden)

        output = self.classifier(output[0])

        return output, hidden, attn_weights

    def init_in(self):
        return torch.zeros(1, self.input_size, 1)

    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size),
                torch.zeros(1, 1, self.hidden_size))


However, I’m getting the same sort of error as before:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [500, 309]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!