I’m trying to implement Truncated Backpropagation Through Time (TBPTT) for an LSTM model based on this post.
This is my code:
    batch_size, sequence_length = inputs.shape
    retain_graph = self.__k1 < self.__k2
    for i in range(batch_size):
        # get_init_state creates the initial zero tensor for the hidden state
        states = [(None, self.__model.get_init_state(batch_size=1))]
        # renamed from `inputs` to avoid shadowing the outer tensor
        sequence, label = inputs[i], torch.tensor([labels[i].item()])
        # sequence_length is the number of timesteps
        for j in range(sequence_length):
            state = states[-1][1].detach()
            state.requires_grad = True
            state.retain_grad()
            time_step = sequence[j, :].unsqueeze(0).unsqueeze(1)
            # predict calls forward on the model (LSTM)
            output, new_state = self.__model.predict(time_step, state)
            output, new_state = output.to(self.__device), new_state.to(self.__device)
            new_state.retain_grad()
            states.append((state.clone(), new_state.clone()))
            # Delete old states
            while len(states) > self.__k2:
                del states[0]
            # Backprop every k1 steps
            if (j + 1) % self.__k1 == 0:
                loss = self.__criterion(output, label)
                self.__optimizer.zero_grad()
                loss.backward(retain_graph=retain_graph)
                # Backprop over the last k2 states
                for k in range(self.__k2 - 1):
                    if states[-k - 2][0] is None:
                        break
                    # *** This gradient is None and is causing the problem ***
                    curr_gradient = states[-k - 1][0].grad
                    # This does not work since curr_gradient is None
                    states[-k - 2][1].backward(curr_gradient, retain_graph=True)
                self.__optimizer.step()
The problem is that the gradients of the states are not populated when backpropagating at
loss.backward(). Thus, when I try to backpropagate over the last k2 stored states, the call to
backward() rightfully raises an error, since curr_gradient is None.
I think the problem may be that the stored states are not bound to the computational graph of
loss, and therefore their gradients are never populated during backpropagation. I have already made sure that each and every state has
requires_grad = True.
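To illustrate what I mean, here is a minimal, self-contained sketch (toy tensors, not my actual model) showing that .grad only ends up on the exact tensor retain_grad() was called on, while a clone of it stored elsewhere gets nothing:

```python
import torch

# Minimal sketch (not my real model): a detached "state" that re-enters the graph.
state = torch.zeros(3).detach()
state.requires_grad = True
state.retain_grad()

stored = state.clone()        # analogous to what I append to `states`
output = (state * 2).sum()
output.backward()

print(state.grad)             # tensor([2., 2., 2.]) -- populated
print(stored.grad)            # None -- the stored clone never receives a gradient
```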
Does anybody have a hint for solving this problem? In general, are the hidden states bound to the same computational graph as the outputs when using RNNs?
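For completeness, here is a minimal repro (toy tensors, unrelated to my model) of the error itself: calling backward(gradient) with gradient=None on a non-scalar tensor fails, which is exactly what happens when curr_gradient is None:

```python
import torch

h = torch.zeros(2, requires_grad=True)
out = h * 2                   # non-scalar, like an LSTM hidden state

try:
    out.backward(None)        # same as calling backward(curr_gradient) with curr_gradient=None
except RuntimeError as e:
    print(e)                  # "grad can be implicitly created only for scalar outputs"
```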