Hi everyone!

I’m trying to implement Truncated Backpropagation Through Time (TBPTT) for an LSTM model based on this post.

This is my code:

```
batch_size, sequence_length = inputs.shape[0], inputs.shape[1]
retain_graph = self.__k1 < self.__k2
for i in range(batch_size):
    # get_init_state creates the initial zero tensor for the hidden state
    states = [(None, self.__model.get_init_state(batch_size=1))]
    input_seq, label = inputs[i], torch.tensor([labels[i].item()])
    # sequence_length is the number of timesteps
    for j in range(sequence_length):
        state = states[-1][1].detach()
        state.requires_grad = True
        state.retain_grad()
        time_step = input_seq[j, :].unsqueeze(0).unsqueeze(1)
        # predict calls forward on the model (LSTM)
        output, new_state = self.__model.predict(time_step, state)
        output, new_state = output.to(self.__device), new_state.to(self.__device)
        new_state.retain_grad()
        # Store the (input state, output state) pair for this timestep
        states.append((state.clone(), new_state.clone()))
        # Delete old states
        while len(states) > self.__k2:
            del states[0]
        # Backprop every k1 steps
        if (j + 1) % self.__k1 == 0:
            loss = self.__criterion(output, label)
            self.__optimizer.zero_grad()
            loss.backward(retain_graph=retain_graph)
            # Backprop over the last k2 states
            for k in range(self.__k2 - 1):
                if states[-k - 2][0] is None:
                    break
                # *** This gradient is None and is causing the problem ***
                curr_gradient = states[-k - 1][0].grad
                # This does not work since curr_gradient is None
                states[-k - 2][1].backward(curr_gradient, retain_graph=True)
            self.__optimizer.step()
```

The problem is that the gradients of the states are not populated when backpropagating at `loss.backward()`. Thus, when I try to backpropagate over the states (at `states[-k - 2][1].backward(curr_gradient, retain_graph=True)`) it rightfully raises an error, since `curr_gradient` is `None`.
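
To illustrate what I expect to happen, here is a toy single-step version (a plain multiplication standing in for `predict`, not my actual model): the detached state is a leaf of that step's graph, so `loss.backward()` should fill in its `.grad`.

```
import torch

# Toy stand-in for one TBPTT step: detach the carried-over state,
# mark it as requiring grad, run a fake "model" step, and backprop.
prev_state = torch.zeros(1, 4)       # pretend this came from the previous step
state = prev_state.detach()
state.requires_grad = True
state.retain_grad()

output = state * 3.0                 # stand-in for self.__model.predict(...)
loss = output.sum()
loss.backward()

print(state.grad)                    # tensor([[3., 3., 3., 3.]]): the detached state does get a gradient
```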

I think the problem may be due to the fact that the stored states are not bound to the computational graph of `loss`, and therefore their gradients are not populated during backpropagation. I already made sure each and every state has `requires_grad = True`.
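
As a quick sanity check of that hypothesis, I ran this minimal snippet (again a toy op instead of the LSTM). The clone I store in `states` sits on a side branch of the graph, so its `.grad` stays `None` even though the original state receives one, which looks like exactly what I'm seeing.

```
import torch

# Does a clone of the state receive a .grad when only the *original*
# is used to compute the loss? This mirrors states.append((state.clone(), ...)).
state = torch.zeros(1, 4, requires_grad=True)
state.retain_grad()
stored = state.clone()               # what my code keeps in `states`
stored.retain_grad()                 # even with retain_grad(), see below

new_state = state * 2.0              # stand-in for the LSTM step
loss = new_state.sum()
loss.backward()

print(state.grad)                    # tensor([[2., 2., 2., 2.]])
print(stored.grad)                   # None: no gradient flows through the clone branch
```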

Does anybody have any hint for solving this problem? In general, are hidden states bound to the same computational graph of the outputs when using RNNs?

Thanks!