I’m trying to implement Truncated Backpropagation Through Time (TBPTT) for an LSTM model based on this post.
This is my code:
    batch_size, sequence_length = inputs.shape
    retain_graph = self.__k1 < self.__k2
    for i in range(batch_size):
        # get_init_state creates the initial zero tensor for the hidden state
        states = [(None, self.__model.get_init_state(batch_size=1))]
        # renamed from `inputs` to avoid shadowing the outer tensor
        sequence, label = inputs[i], torch.tensor([labels[i].item()])
        # sequence_length is the number of timesteps
        for j in range(sequence_length):
            state = states[-1][1].detach()
            state.requires_grad = True
            state.retain_grad()
            time_step = sequence[j, :].unsqueeze(0).unsqueeze(1)
            # predict calls forward on the model (LSTM)
            output, new_state = self.__model.predict(time_step, state)
            output, new_state = output.to(self.__device), new_state.to(self.__device)
            new_state.retain_grad()
            states.append((state.clone(), new_state.clone()))
            # Delete old states
            while len(states) > self.__k2:
                del states[0]
            # Backprop every k1 steps
            if (j + 1) % self.__k1 == 0:
                loss = self.__criterion(output, label)
                self.__optimizer.zero_grad()
                loss.backward(retain_graph=retain_graph)
                # Backprop over the last k2 states
                for k in range(self.__k2 - 1):
                    if states[-k - 2][0] is None:
                        break
                    # *** This gradient is None and is causing the problem ***
                    curr_gradient = states[-k - 1][0].grad
                    # This does not work since curr_gradient is None
                    states[-k - 2][1].backward(curr_gradient, retain_graph=True)
                self.__optimizer.step()
The problem is that the gradients of the states are not populated when backpropagating at
loss.backward(). Thus, when I try to backpropagate over the last k2 stored states, the call to
backward() rightfully raises an error, since curr_gradient is None.
I think the problem may be that the stored states are not bound to the computational graph of
loss, and therefore their gradients are never populated during backpropagation. I have already made sure that each and every state has
requires_grad = True.
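To illustrate what I mean, here is a minimal, self-contained sketch (toy tensors, not my actual model) showing that .grad only ends up on the exact tensor retain_grad() was called on, while a clone of it stored elsewhere gets nothing:

```python
import torch

# Minimal sketch (not my real model): a detached "state" that re-enters the graph.
state = torch.zeros(3).detach()
state.requires_grad = True
state.retain_grad()

stored = state.clone()        # analogous to what I append to `states`
output = (state * 2).sum()
output.backward()

print(state.grad)             # tensor([2., 2., 2.]) -- populated
print(stored.grad)            # None -- the stored clone never receives a gradient
```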
Does anybody have a hint for solving this problem? In general, are the hidden states bound to the same computational graph as the outputs when using RNNs?
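For completeness, here is a minimal repro (toy tensors, unrelated to my model) of the error itself: calling backward(gradient) with gradient=None on a non-scalar tensor fails, which is exactly what happens when curr_gradient is None:

```python
import torch

h = torch.zeros(2, requires_grad=True)
out = h * 2                   # non-scalar, like an LSTM hidden state

try:
    out.backward(None)        # same as calling backward(curr_gradient) with curr_gradient=None
except RuntimeError as e:
    print(e)                  # "grad can be implicitly created only for scalar outputs"
```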