RNN Batch Training: Backward pass, retain_graph?

First post here, forgive me if I’m breaking any conventions…

I’m trying to train a simple LSTM on time series data where the input (x) is 2-dimensional and the output (y) is 1-dimensional. I’ve set the sequence length to 60 and the batch size to 30, so x has size [60,30,2] and y has size [60,30,1]. Each sequence is fed through the model one timestep at a time, and the resulting 60 losses are averaged. I’m hoping to backpropagate the gradient of this averaged loss to do a parameter update.

for i in range(num_epochs):
    model.hidden = model.init_hidden()
    for j in range(data.n_batches):
        x, y = data.next_batch(j)
        lst = torch.zeros(1, requires_grad=True)

        # feed the sequence one timestep at a time and accumulate the loss
        for t in range(x.shape[0]):
            y_pred = model(x[t:t+1, :, :])
            lst = lst + loss_fn(y_pred, y[t].view(-1))

        lst /= x.shape[0]
        optimizer.zero_grad()
        lst.backward()
        optimizer.step()

This gives me the error that I’m trying to backward through the graph a second time and must specify retain_graph=True. My questions are:

  1. Why is retain_graph=True necessary? To my understanding, I’m “unfolding” the network for 60 timesteps and only doing one backward pass at the end of each sequence. What exactly needs to be remembered from batch to batch?
  2. Is there a more optimal/“better” way of doing truncated backpropagation? I was thinking I could backpropagate the loss every time a timestep is unfolded, but I’m not sure whether that would be a big improvement. See here (https://r2rt.com/styles-of-truncated-backpropagation.html) for what I mean - specifically the picture before the section titled “Experiment design” - and the rough sketch after this list.
  3. Any other comments or suggestions on the code are appreciated… I’m relatively new to PyTorch, so I’m not sure what the best practices are.
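
For reference, here’s a rough, untested sketch of the per-timestep idea from question 2. It assumes model.hidden is the (h, c) tuple my model keeps, and that the graph gets cut at each step by detaching it:

for t in range(x.shape[0]):
    # detach so this step's graph doesn't reach back through earlier timesteps
    model.hidden = tuple(h.detach() for h in model.hidden)

    y_pred = model(x[t:t+1, :, :])
    step_loss = loss_fn(y_pred, y[t].view(-1))

    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()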

Thanks!

Try moving init_hidden into the inner loop. You need to initialize your hidden state for each batch, not just for each epoch.
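
Something like this, reusing the loop structure from your post:

for i in range(num_epochs):
    for j in range(data.n_batches):
        model.hidden = model.init_hidden()  # fresh hidden state for every batch, not every epoch
        x, y = data.next_batch(j)
        # ... rest of the inner loop unchanged ...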

Thanks, that works! Though for others reading this, a better solution is to detach the hidden state at the start of each batch rather than re-initializing it. That lets the RNN “carry over” the final state from the previous batch, which is standard truncated BPTT.
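
For example, assuming model.hidden is the (h, c) tuple returned by init_hidden (detach() keeps the values but cuts the autograd history, so gradients stop at the batch boundary):

for i in range(num_epochs):
    model.hidden = model.init_hidden()
    for j in range(data.n_batches):
        # carry the values over from the previous batch, but stop gradients
        # from flowing back into the previous batch's graph
        model.hidden = tuple(h.detach() for h in model.hidden)
        x, y = data.next_batch(j)
        # ... rest of the inner loop unchanged ...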

Good point with detach()!

What is actually the benefit and intuition behind carrying the hidden state over? Say I’m just doing sentence classification, and the sentences are independent of each other. Why should one batch depend on the final hidden state of the previous batch? Here, re-initializing seems like the more consistent approach to me.

Sure, for a language model where you break long documents into chunks, so that consecutive chunks depend on each other, carrying the hidden state over seems like the more natural way to go.

I agree, the benefits aren’t as clear when batches are truly independent from one another. In my case, I’m doing time series prediction where the batches are created sequentially, so carrying the hidden state over does seem more natural, as you suggested.

However, for sentence classification one could also argue that the hidden states capture syntactic rules that are universal to how most sentences are formed, so few sentences are truly “independent” of one another and carrying the hidden state over could still be beneficial. I’m sure there’s research out there that answers this.

@spaceman thanks! Yes, that would be my intuition as well.