Graph not resetting between backward passes?

While trying to extend the word_language_model example, I'm hitting an error:
Trying to backward through the graph second time, but the buffers have already been freed. Please specify retain_variables=True when calling backward for the first time.

I guess I'm missing something obvious here, but why doesn't running the model again refill the buffers? I was going to implement training in a similar loop.

Code to reproduce:

import torch as th
import torch.nn as nn
from torch.autograd import Variable

# borrowed from `word_language_model`
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers):
        super(RNNModel, self).__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, bias=False)
        self.decoder = nn.Linear(nhid, ntoken)
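        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)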

    def forward(self, input, hidden):
        emb = self.encoder(input)
        output, hidden = self.rnn(emb, hidden)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                    Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
        else:
            return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())

vocab_size = 256
batch_size = 64

model = RNNModel('GRU', vocab_size, 100, 100, 3)

optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

x = Variable(th.LongTensor(50, 64))
x[:] = 1

state = model.init_hidden(batch_size)

print('first pass')
logits, state = model(x, state)
loss = criterion(logits.view(-1, vocab_size), x.view(-1))
loss.backward()

print('second pass')
logits, state = model(x, state)
loss = criterion(logits.view(-1, vocab_size), x.view(-1))
loss.backward()   # <-- error occurs here

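You need to call state.detach_() before the second forward pass. Otherwise, state still holds on to the graph that created it, i.e. that of the first pass, so the second backward() tries to go back through the first pass's graph as well, whose buffers were already freed by the first backward().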
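To make the fix concrete, here is a minimal sketch of a training loop that carries the hidden state across iterations and cuts the graph at the start of each one. It assumes a current PyTorch release (plain tensors instead of Variable) and uses a bare nn.GRU rather than the RNNModel above; the sizes are arbitrary. It uses the out-of-place state = state.detach(), which for this purpose has the same effect as the in-place state.detach_() suggested above.

```python
import torch
import torch.nn as nn

# Minimal sketch (assumes a recent PyTorch; sizes are arbitrary).
torch.manual_seed(0)
vocab_size, seq_len, batch_size, hidden_size = 256, 50, 4, 32

embed = nn.Embedding(vocab_size, hidden_size)
rnn = nn.GRU(hidden_size, hidden_size, num_layers=1)
decoder = nn.Linear(hidden_size, vocab_size)
params = list(embed.parameters()) + list(rnn.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)
criterion = nn.CrossEntropyLoss()

x = torch.ones(seq_len, batch_size, dtype=torch.long)  # same dummy input as the question
state = torch.zeros(1, batch_size, hidden_size)          # (num_layers, batch, hidden)

losses = []
for step in range(3):
    # Keep the values of the hidden state but drop its history, so the
    # backward() below only traverses the graph built in this iteration.
    state = state.detach()
    optimizer.zero_grad()
    output, state = rnn(embed(x), state)
    logits = decoder(output)
    loss = criterion(logits.view(-1, vocab_size), x.view(-1))
    loss.backward()  # without the detach above, this fails on the second iteration
    optimizer.step()
    losses.append(loss.item())
```

The hidden state's values still flow from one iteration to the next; only the gradient path is truncated, which is the usual truncated-BPTT setup.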

Um, sorry. detach_() will only work from the next release, which is going to be up soon. For now, use repackage_hidden from the example.
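For reference, a helper in the spirit of the example's repackage_hidden can be sketched as below. This is an assumed modernized version, not the example's exact code: the original wrapped h.data in a new Variable, while this sketch calls detach(), available in current PyTorch. It recurses into tuples so it also covers an LSTM's (h, c) pair.

```python
import torch
import torch.nn as nn

def repackage_hidden(h):
    """Return hidden state(s) with the same values but no autograd history.

    Works for a single tensor (GRU/RNN) or a tuple of tensors (LSTM's (h, c)).
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)

# Usage sketch with an LSTM, whose hidden state is an (h, c) tuple.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
inp = torch.randn(5, 3, 8)  # (seq_len, batch, features)
hidden = (torch.zeros(2, 3, 16), torch.zeros(2, 3, 16))

for _ in range(2):
    hidden = repackage_hidden(hidden)  # cut the graph from the previous pass
    out, hidden = lstm(inp, hidden)
    out.sum().backward()               # only backpropagates through this pass
```

Because the helper returns new objects instead of modifying in place, the call site has to reassign the result (hidden = repackage_hidden(hidden)).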


Thanks! I didn't realize that I was reusing the old graph with this.
PyTorch is awesome!

