Gradient of hidden state in TBPTT

Hi everyone!

I’m trying to implement Truncated Backpropagation Through Time (TBPTT) for an LSTM model based on this post.

This is my code:

batch_size, sequence_length = inputs.shape[0], inputs.shape[1]

# When k1 < k2, consecutive truncation windows overlap, so the graph
# is reused across backward passes and must be retained
retain_graph = self.__k1 < self.__k2

for i in range(batch_size):

    # get_init_state creates the initial zero tensor for the hidden state
    states = [(None, self.__model.get_init_state(batch_size=1))]
    # Use a fresh name here: rebinding `inputs` would clobber the whole batch
    sequence, label = inputs[i], torch.tensor([labels[i].item()])

    # sequence_length is the number of timesteps
    for j in range(sequence_length):

        # Detach to truncate the graph, then track gradients from here on
        state = states[-1][1].detach()
        state.requires_grad = True
        state.retain_grad()

        time_step = sequence[j, :].unsqueeze(0).unsqueeze(1)

        # predict calls forward on the model (LSTM)
        output, new_state = self.__model.predict(time_step, state)
        output, new_state = output.to(self.__device), new_state.to(self.__device)
        new_state.retain_grad()
        states.append((state.clone(), new_state.clone()))

        # Delete old states
        while len(states) > self.__k2:
            del states[0]

        # Backprop every k1 steps
        if (j + 1) % self.__k1 == 0:

            loss = self.__criterion(output, label)
            self.__optimizer.zero_grad()
            loss.backward(retain_graph=retain_graph)

            # Backprop over the last k2 states
            for k in range(self.__k2 - 1):

                if states[-k - 2][0] is None:
                    break

                # *** This gradient is None and is causing the problem ***
                curr_gradient = states[-k - 1][0].grad

                # This does not work since curr_gradient is None
                states[-k - 2][1].backward(curr_gradient, retain_graph=True)

            self.__optimizer.step()

The problem is that the gradients of the stored states are not populated when loss.backward() is called. As a result, when I try to backpropagate through the earlier states (at states[-k - 2][1].backward(curr_gradient, retain_graph=True)), it rightfully raises an error because curr_gradient is None.

I think the problem may be that the stored states are not part of the computational graph of loss, so their gradients are never populated during backpropagation. I already made sure that each and every state has requires_grad = True.
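
To make the issue concrete, here is a stripped-down sketch of what I suspect is happening, with plain tensors in place of my model (no LSTM involved):

import torch

# A leaf tensor standing in for the detached hidden state
state = torch.zeros(1, 4, requires_grad=True)

# One "RNN step": any differentiable function of the state
new_state = torch.tanh(state * 2.0)
new_state.retain_grad()       # keep .grad on this non-leaf tensor

stored = new_state.clone()    # what my code actually stores in `states`

loss = new_state.sum()
loss.backward()

print(state.grad)       # populated (leaf tensor)
print(new_state.grad)   # populated (retain_grad was called on it)
print(stored.grad)      # None: the clone is a different graph node

If this is indeed the cause, then the clones stored in states never receive gradients, no matter what the original tensors do.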

Does anybody have any hint for solving this problem? In general, are hidden states bound to the same computational graph of the outputs when using RNNs?

Thanks!

Hi, did you figure out how to implement TBPTT for LSTMs?
I'm posting a reply to bump this thread. Thanks.

Hi! Unfortunately, I ended up abandoning my attempt to implement TBPTT. Just in case you missed it, here is a more active thread on the topic that you may find interesting: Implementing Truncated Backpropagation Through Time - #7 by riccardosamperna

Thanks for your response. I did check that thread, but the implementation there is not quite clear and it raises errors on newer versions of PyTorch. The Ignite engine has an implementation of TBPTT, but only for a very fixed scenario (i.e., a supervised trainer with k1 = k2).
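
In case it helps anyone who finds this thread later: the special case k1 = k2 is simple enough to write by hand, since you only need to detach the hidden state at each truncation boundary. Below is a rough sketch; the model, optimizer, and data are made-up placeholders, not the code from this thread:

import torch
import torch.nn as nn

model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(list(model.parameters()) + list(head.parameters()))

k = 10                            # truncation length (k1 == k2)
inputs = torch.randn(4, 100, 8)   # (batch, time, features)
targets = torch.randn(4, 100, 1)

hidden = None                     # nn.LSTM creates zero states when None
for start in range(0, inputs.size(1), k):
    chunk = inputs[:, start:start + k]
    target = targets[:, start:start + k]

    output, hidden = model(chunk, hidden)
    loss = criterion(head(output), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Truncate: keep the values of (h, c) so the sequence continues where
    # it left off, but cut them out of the old graph so the next backward()
    # stops at this boundary instead of reaching back to t = 0
    hidden = tuple(h.detach() for h in hidden)

The only TBPTT-specific line is the final detach(). The overlapping-window case (k1 < k2) is what makes things hard, since graphs from consecutive windows have to be kept alive and stitched together manually, as attempted in the code above.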