PyTorch LSTM yields retain_graph error: how do I get around this?

I am training a simple LSTM model, but PyTorch gives me an error saying that I need to set retain_graph=True. However, that makes the model take longer to train, and I do not think I need it.

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class SequenceModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size = 3, hidden_size = 3, bidirectional=False)
        # random initial (h, c), shape (num_layers, batch, hidden_size)
        self.hidden = (torch.randn(1, 1, 3).double(), torch.randn(1, 1, 3).double())

    def forward(self,x):
        # the returned final state is stored and fed into the next call
        lstm_out, self.hidden = self.lstm(x.view(-1, 1, 3),self.hidden)
        return lstm_out

    def loss(self,logits,labels):
        return F.cross_entropy(logits, labels)

model = SequenceModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model = model.double()

model.train()
epochs = 1000
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()

    # inputs / outputs: training tensors defined elsewhere (not shown)
    logits = model(inputs)
    logits = logits.reshape(-1,3)
    loss = model.loss(logits,outputs.long())

    loss.backward()
    optimizer.step()

The error I get is:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I do not want to use retain_graph=True because it makes training slower, and I do not think a simple LSTM like mine should need it. What am I doing wrong?

Hi,

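The problem is that the hidden state in your model is carried over from one invocation to the next, so the graphs of all your training steps are linked: the state returned by one step is still attached to that step's graph, so the next step's backward tries to go back through it, but its buffers were already freed by the previous backward().
Because the LSTM module runs the forward over the whole sequence in a single call, you do not need to save the final hidden states: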

lstm_out, _ = self.lstm(x.view(-1, 1, 3),self.hidden)
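A minimal sketch reproducing both the error and the fix, assuming toy random data in place of the question's `inputs`/`outputs` (which are not shown) and a bare `nn.LSTM` instead of the `SequenceModel` wrapper. `train` is a hypothetical helper: with `carry_state=True` it reuses the returned state like the original code, otherwise it discards it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data standing in for the question's `inputs` / `outputs`:
# 5 time steps with 3 features, and one class label (0-2) per step.
inputs = torch.randn(5, 3, dtype=torch.double)
outputs = torch.randint(0, 3, (5,))

lstm = nn.LSTM(input_size=3, hidden_size=3).double()
optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-4)

def train(steps, carry_state):
    hidden = None  # None -> nn.LSTM uses zero initial states
    for _ in range(steps):
        optimizer.zero_grad()
        out, new_hidden = lstm(inputs.view(-1, 1, 3), hidden)
        loss = F.cross_entropy(out.reshape(-1, 3), outputs)
        loss.backward()  # frees this step's graph buffers
        optimizer.step()
        if carry_state:
            # Pattern from the question: the next step starts from a state
            # that is still attached to this step's (already freed) graph.
            hidden = new_hidden
    return loss.item()

try:
    train(2, carry_state=True)
    carried_ok = True
except RuntimeError:
    carried_ok = False  # "Trying to backward through the graph a second time..."

# Discarding the returned state: every step builds an independent graph.
final_loss = train(100, carry_state=False)
print(carried_ok, final_loss)
```

No retain_graph=True is needed in the second case, because no step ever reaches back into an earlier step's graph.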

Thanks, in that case should I remove self.hidden altogether, i.e.

lstm_out, _ = self.lstm(x.view(-1, 1, 3))

so that it doesn’t take the random self.hidden from the constructor?

Yes, I think the default (all zeros) is usually fine.
So unless you want to do something fancy, you don't need to provide them!
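A small check of that default, assuming a single-layer, unidirectional LSTM with batch size 1 as in the question: calling the module without an initial state gives the same output as passing zero tensors of shape (num_layers, batch, hidden_size) explicitly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=3)
x = torch.randn(5, 1, 3)  # (seq_len, batch, input_size)

# No initial state given: nn.LSTM uses zeros for h0 and c0.
out_default, _ = lstm(x)

# The same thing spelled out: zeros of shape (num_layers, batch, hidden_size).
h0 = torch.zeros(1, 1, 3)
c0 = torch.zeros(1, 1, 3)
out_explicit, _ = lstm(x, (h0, c0))

print(torch.allclose(out_default, out_explicit))
```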

Sorry, I meant that after every training step, the LSTM should take the most up-to-date hidden states (h, c), right?

That's why I was passing the up-to-date self.hidden as the initial-state argument to the LSTM. Or am I missing something?

In the most basic form of LSTM, IIRC, the hidden state for the first time step should be all zeros.
Since the module you use performs all the time steps in one call, you only ever need to provide a Tensor full of 0s for it.
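To illustrate "all the time steps in one call", here is a sketch (toy random input, single layer, batch of 1, all assumptions) comparing one call over the whole sequence with a manual loop that feeds the returned (h, c) into the next time step. The outputs agree up to floating-point noise, so the state passing the question worries about already happens inside the single call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=3)
x = torch.randn(5, 1, 3)  # 5 time steps, batch of 1

# One call: the module runs all 5 time steps internally,
# feeding (h, c) from each step into the next.
out_full, (h_full, c_full) = lstm(x)

# The same recurrence done by hand, one time step per call,
# starting from zeros and carrying the state forward.
state = (torch.zeros(1, 1, 3), torch.zeros(1, 1, 3))
step_outputs = []
for t in range(x.size(0)):
    out_t, state = lstm(x[t:t + 1], state)
    step_outputs.append(out_t)
out_steps = torch.cat(step_outputs, dim=0)

print(torch.allclose(out_full, out_steps, atol=1e-5))
```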

If you want to train the initial hidden state, you can declare it as nn.Parameter and give it as input to the LSTM (but still ignore the returned state).

But after the first iteration/step of training, should the initial hidden state still be all zeros? I'm guessing not, since the first hidden layer's weight values change during training. In that case:

On my first training step:

my h0 has changed from all zeros to some non-zero values, let's say h0*.

On my second training step:

how do I specify that I want the updated h0*,

i.e. lstm_out, _ = self.lstm(x.view(-1, 1, 3),h0*)

without sharing the hidden states from one invocation to the next?

I’m sure I’m confusing something here - but I just wanted to double check.

If you want to learn the initial hidden state, starting from 0:

class SequenceModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size = 3, hidden_size = 3, bidirectional=False)
        # learnable (h0, c0), initialised to zeros, registered as parameters
        self.hidden = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, 3)),
                                        nn.Parameter(torch.zeros(1, 1, 3))])

    def forward(self,x):
        # nn.LSTM expects (h0, c0) as a tuple of tensors; the returned state is discarded
        lstm_out, _ = self.lstm(x.view(-1, 1, 3), tuple(self.hidden))
        return lstm_out

Now the initial states will be picked up by .parameters(), so the optimizer will update them along with the rest of the parameters.
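A usage sketch for the learned initial state, assuming toy random data, float32 instead of double, and a larger learning rate than the question's so the change is visible after a few steps. It shows that `hidden.0` and `hidden.1` appear among the model's parameters and that the optimizer moves them away from zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SequenceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=3)
        # Learnable (h0, c0), initialised to zeros.
        self.hidden = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, 3)),
                                        nn.Parameter(torch.zeros(1, 1, 3))])

    def forward(self, x):
        lstm_out, _ = self.lstm(x.view(-1, 1, 3), tuple(self.hidden))
        return lstm_out

torch.manual_seed(0)
model = SequenceModel()
# Hypothetical toy data: 5 time steps, 3 features, one class label per step.
inputs = torch.randn(5, 3)
labels = torch.randint(0, 3, (5,))

# h0 and c0 are listed alongside the LSTM weights.
names = [name for name, _ in model.named_parameters()]
print(names)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
for _ in range(3):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(inputs).reshape(-1, 3), labels)
    loss.backward()  # a fresh graph each step, no retain_graph needed
    optimizer.step()

# After a few steps the learned initial state is no longer all zeros.
print(model.hidden[0].abs().sum().item() > 0)
```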

Thank you very much. I’m guessing this is the proper way to train an LSTM when you’re training with more than one step/iteration.

Not necessarily. Always using zeros for the initial state is valid as well.
It also reduces the number of parameters, which might help avoid overfitting.
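For scale, a quick count for the model in this thread (single layer, hidden size 3, batch of 1, all taken from the code above): the LSTM itself has 96 parameters, and learning (h0, c0) adds 6 more. `count_params` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def count_params(module):
    # Total number of scalar values across all registered parameters.
    return sum(p.numel() for p in module.parameters())

lstm = nn.LSTM(input_size=3, hidden_size=3)
# Zero initial state: only the LSTM weights and biases are trained.
fixed_zero = count_params(lstm)

# Learned initial state: add h0 and c0, each of shape (1, 1, 3).
initial_state = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, 3)),
                                  nn.Parameter(torch.zeros(1, 1, 3))])
learned = fixed_zero + count_params(initial_state)

print(fixed_zero, learned)  # 96 102
```

The 96 comes from weight_ih_l0 (12x3), weight_hh_l0 (12x3), bias_ih_l0 (12) and bias_hh_l0 (12); the extra parameters grow with num_layers and hidden_size.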

Thanks. You learn something new every day!