Gradient of hidden state in TBPTT

Hi everyone!

I’m trying to implement Truncated Backpropagation Through Time (TBPTT) for an LSTM model based on this post.

This is my code:

batch_size, sequence_length = inputs.shape[0], inputs.shape[1]
retain_graph = self.__k1 < self.__k2

for i in range(batch_size):

     # get_init_state creates the initial zero tensor for the hidden state
     states = [(None, self.__model.get_init_state(batch_size=1))]
     # Use a separate name so the batch tensor `inputs` is not overwritten
     sample, label = inputs[i], torch.tensor([labels[i].item()])
     
     # sequence_length is the number of timesteps
     for j in range(sequence_length):

           state = states[-1][1].detach()
           state.requires_grad = True
           state.retain_grad()

           time_step = sample[j, :].unsqueeze(0).unsqueeze(1)

           # predict calls forward on the model (LSTM)
           output, new_state = self.__model.predict(time_step, state)
           output, new_state = output.to(self.__device), new_state.to(self.__device)
           new_state.retain_grad()
           states.append((state.clone(), new_state.clone()))

           # Delete old states
           while len(states) > self.__k2:
               del states[0]

           # Backprop every k1 steps
           if (j + 1) % self.__k1 == 0:
            
               loss = self.__criterion(output, label)
               self.__optimizer.zero_grad()
               loss.backward(retain_graph=retain_graph)

               # Backprop over the last k2 states
               for k in range(self.__k2 - 1):

                   if states[-k - 2][0] is None:
                       break
                   
                   # *** This gradient is None and is causing the problem ***
                   curr_gradient = states[-k - 1][0].grad

                   # This does not work since curr_gradient is None
                   states[-k - 2][1].backward(curr_gradient, retain_graph=True)

               self.__optimizer.step()

The problem is that the gradients of the states are not populated by loss.backward(). So when I then try to backpropagate through the stored states (at states[-k - 2][1].backward(curr_gradient, retain_graph=True)), it rightfully raises an error, since curr_gradient is None.
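For reference, here is a minimal sketch of the gradient hand-off that the inner loop is trying to perform, reduced to two time steps. It assumes a plain torch.nn.RNNCell instead of my LSTM wrapper, and the sizes and names (cell, h1_in, x1, x2) are made up for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=3, hidden_size=4)  # stand-in for the real model

h0 = torch.zeros(1, 4)
x1, x2 = torch.randn(1, 3), torch.randn(1, 3)

# Step 1: compute a new state and keep it attached to its own graph.
h1 = cell(x1, h0)

# Step 2: cut the graph, keeping a reference to the detached leaf itself.
h1_in = h1.detach().requires_grad_(True)
h2 = cell(x2, h1_in)

# Backprop through step 2 only. This fills h1_in.grad and the parameter grads.
loss = h2.sum()
loss.backward()

# Hand the gradient across the cut and continue backprop through step 1.
# Parameter gradients from both steps accumulate in cell.parameters().
h1.backward(h1_in.grad)
```

In this sketch the gradient is read from h1_in, the very tensor that was passed into the cell, which is why it is available after loss.backward().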

I think the problem may be that the stored states are not part of the computational graph of loss, so their gradients are never populated during backpropagation. I already made sure every state has requires_grad = True.
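One thing that might be worth checking, offered as a guess rather than a confirmed diagnosis: the list stores state.clone(), and a clone is a separate non-leaf tensor that the loss never uses. A minimal sketch of that behaviour, where the shape and the weight standing in for the model are hypothetical:

```python
import torch

# Leaf tensor standing in for a detached hidden state (hypothetical shape).
state = torch.zeros(1, 4, requires_grad=True)

# A clone is a new, non-leaf tensor derived from `state`.
state_copy = state.clone()

# Stand-in for the model step: the loss is computed from `state`, not the copy.
weight = torch.ones(4, 1)
loss = (state @ weight).sum()
loss.backward()

leaf_grad = state.grad       # populated: `state` is a leaf that feeds the loss
copy_grad = state_copy.grad  # None: the copy is non-leaf and unused by the loss
```

If that is what happens in my loop, storing state itself in the list instead of state.clone() would presumably keep a reference to the tensor whose .grad actually gets filled. Note that reading .grad on a non-leaf tensor also makes PyTorch emit a warning.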

Does anybody have a hint for solving this problem? In general, are hidden states part of the same computational graph as the outputs when using RNNs?
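To make the second question concrete, here is a small check, assuming torch.nn.LSTM is used directly (my own model wraps it and passes a single state tensor, so this is only an approximation). The sizes are made up. It backpropagates from the output alone and looks at whether the initial hidden and cell states receive a gradient:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=5, batch_first=True)

# Leaf hidden and cell states: (num_layers, batch, hidden_size).
h0 = torch.zeros(1, 1, 5, requires_grad=True)
c0 = torch.zeros(1, 1, 5, requires_grad=True)

x = torch.randn(1, 1, 3)  # (batch, one time step, input_size)
output, (h1, c1) = lstm(x, (h0, c0))

# Backpropagate from the output only, not from the returned states.
output.sum().backward()

h0_has_grad = h0.grad is not None
c0_has_grad = c0.grad is not None

# For a single-layer, unidirectional LSTM the last output step matches h1.
output_matches_h1 = torch.allclose(output[:, -1, :], h1[0])
```

In this setup the states passed in are leaves of the same graph that produces the output, so both flags come out True.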

Thanks!