Training an LSTM: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Hi everyone,

I'm training an LSTM and I'm getting the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 1024]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

My code is as follows:

  1. Defining the net:
import torch.nn as nn

class LSTM(nn.Module):
    #defining initialization method
    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, dropout=.5):
        
        #class constructor
        super(LSTM, self).__init__()
        self.vocab_size = vocab_size
        self.output_size = output_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        #defining the embedding layer
        #produces a lower dimensional representation of the input vector
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        #defining lstm cell
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=dropout, batch_first = True)
        
        self.dropout = nn.Dropout(0.2)
        self.fc1 = nn.Linear(hidden_dim, output_size)
        self.sigmoid = nn.Sigmoid()
    
    #defining forward propagation
    def forward(self, x, hidden):
        
        #getting batch_size from input vector
        batch_size = x.size(0)
        
        #embedding lookup followed by the LSTM; ct is the new (hidden state, cell state) tuple
        embeddings = self.embed(x)
        out, ct = self.lstm(embeddings, hidden)
        #stacking output
        lstm_out = out.contiguous().view(-1, self.hidden_dim)
        out = self.dropout(lstm_out)
        out = self.fc1(out)
        
        #reshaping back to (batch_size, seq_len, output_size) and keeping the last time step
        out = out.view(batch_size, -1, self.output_size)
        out = out[:, -1]
        
        return out, ct
        
    #defining method to initialize hidden states
    def init_hidden(self, batch_size):
        
        #getting a parameter tensor so the new hidden states match its dtype and device
        weight = next(self.parameters()).data
        
        # initialize hidden state and cell state with zeros, and move to GPU if available
        #if train_on_gpu:
        #    hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda(),
        #              weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda())
        #else:
        hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),
                  weight.new(self.n_layers, batch_size, self.hidden_dim).zero_())
        
        return hidden
  2. Instantiating the net:
n_epochs = 15
learning_rate = 0.0005
vocab_size = len(vocab_to_int)
output_size = len(vocab_to_int)
embedding_dim = 256
hidden_dim = 256
n_layers = 2
#instantiating lstm
lstm = LSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

#defining criterion
criterion = nn.CrossEntropyLoss()

#defining optimizer
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)
  3. Training the net:
for epoch in range(1, n_epochs):
    hidden = lstm.init_hidden(batch_size)

    #now looping through data
    for batch_i, (features, targets) in enumerate(trainloader, 1):
        #watch out for incomplete batches
        print(batch_i)
        n_batches = len(trainloader.dataset) // batch_size
        if (batch_i > n_batches):
            break
        print(features)
        print(features.shape)
        print(targets)
        print(targets.shape)
        #defining a forward pass
        output, hidden = lstm(features, hidden)
        lstm.zero_grad()
        loss = criterion(output, targets)
        loss.backward(retain_graph=True)
        optimizer.step()

Any ideas? Thanks!

It seems that it's fixed by adding this line in my training code:

hidden = tuple([each.data for each in hidden])

The resulting training code is now:

for epoch in range(1, n_epochs):
    hidden = lstm.init_hidden(batch_size)

    #now looping through data
    for batch_i, (features, targets) in enumerate(trainloader, 1):
        #watch out for incomplete batches
        print(batch_i)
        n_batches = len(trainloader.dataset) // batch_size
        if (batch_i > n_batches):
            break
        print(features)
        print(features.shape)
        print(targets)
        print(targets.shape)
        #defining a forward pass
        
        hidden = tuple([each.data for each in hidden])
        
        optimizer.zero_grad()
        output, hidden = lstm(features, hidden)
        loss = criterion(output, targets)
        loss.backward(retain_graph=True)
        optimizer.step()
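
In case it is useful, here is a tiny sketch (plain tensors only, no LSTM) of what each.data gives you: a tensor that shares storage with the original but carries no autograd history, so the hidden state can be passed into the next batch without dragging the previous graph along.

import torch

x = torch.ones(3, requires_grad=True)
h = x * 2                        # h is part of x's autograd graph
print(h.grad_fn)                 # <MulBackward0 ...>

h_detached = h.data              # same values and storage, but no graph history
print(h_detached.grad_fn)        # None
print(h_detached.requires_grad)  # False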

Can I ask why this worked? I implemented it in my code because I had the same problem.

I had a similar issue and I figured out that in my case, the method that caused the inplace operation was the optimizer.step() call. Here is a short example that produces the error:

import torch
import torch.nn as nn
import torch.nn.functional as f


def main_bptt():
    batch_size = 1
    latent_size = 5
    hidden_size = 16

    lstm = nn.LSTM(input_size=latent_size, hidden_size=hidden_size, batch_first=True)

    hidden = torch.zeros((1, batch_size, hidden_size))
    cell_state = torch.zeros((1, batch_size, hidden_size))

    target = torch.randn(1, 1, hidden_size)

    optimizer = torch.optim.Adam(lstm.parameters())
    optimizer.zero_grad()

    for i in range(100):
        latent_input = torch.randn(1, 1, latent_size)

        output, (hidden, cell_state) = lstm(latent_input, (hidden, cell_state))

        loss = f.mse_loss(output, target)
        loss.backward(retain_graph=True)
        optimizer.step()

if __name__ == "__main__":
    main_bptt()

If you comment out optimizer.step() then the code works just fine.

@sennettm89 I guess the approach from @Antonio_Linares works because the line hidden = tuple([each.data for each in hidden]) effectively detaches the hidden and cell state from the graph built in previous iterations, so the inplace parameter update in optimizer.step() no longer touches anything the next backward pass needs. In the example code you could also add this line before the optimizer step: hidden, cell_state = hidden.detach(), cell_state.detach(), which does the same thing. Then the code runs fine, but I don't think it actually computes full BPTT anymore, since gradients only flow through the current step.
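
To make that concrete, here is one way the example above could be rewritten with the detach approach (just a sketch): the carried states are detached at the top of the loop, so each backward() only sees the current step (truncated BPTT) and retain_graph=True is no longer needed.

import torch
import torch.nn as nn
import torch.nn.functional as f


def main_truncated_bptt():
    batch_size = 1
    latent_size = 5
    hidden_size = 16

    lstm = nn.LSTM(input_size=latent_size, hidden_size=hidden_size, batch_first=True)
    optimizer = torch.optim.Adam(lstm.parameters())

    hidden = torch.zeros((1, batch_size, hidden_size))
    cell_state = torch.zeros((1, batch_size, hidden_size))
    target = torch.randn(1, 1, hidden_size)

    for i in range(100):
        latent_input = torch.randn(1, 1, latent_size)

        # cut the carried states off from the previous iteration's graph,
        # so backward() only traverses the current step (truncated BPTT)
        hidden, cell_state = hidden.detach(), cell_state.detach()

        optimizer.zero_grad()
        output, (hidden, cell_state) = lstm(latent_input, (hidden, cell_state))

        loss = f.mse_loss(output, target)
        loss.backward()  # retain_graph is no longer needed
        optimizer.step()


if __name__ == "__main__":
    main_truncated_bptt()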