Hello, I’m trying to implement deep Q-learning with an LSTM cell, similar to this paper: https://arxiv.org/pdf/1509.03044.pdf. An LSTM cell plus a linear layer learns a hidden state by predicting rewards, and that hidden state is then passed into a separate DQN network that predicts the Q values. But I’m having trouble training the networks because of a runtime error during backward().
Here is a snippet of my code:
# train RNN state model:
predict_reward, hidden_state = self.rnn_model(Variable(torch.from_numpy(state_batch).float()))
reward_var = Variable(torch.from_numpy(reward_batch).float(), requires_grad=False)
criterion = nn.MSELoss()
state_loss = criterion(predict_reward, reward_var)
self.rnn_optimizer.zero_grad()
state_loss.backward()
self.rnn_optimizer.step()
# generate target q values
target_q_output = self._generate_target_q_values(next_state_batch, reward_batch)
# get q net output after passing hidden_state into it
q_output = self.qnet(hidden_state)
q_output = q_output[range(self._mini_batch_size), action_indexs]
loss = F.smooth_l1_loss(q_output, target_q_output)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
I get an error on the second-to-last line (loss.backward()):
line 370, in train_minibatch loss.backward()
File "/usr/local/lib/python3.6/site-packages/torch/autograd/variable.py", line 156, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
File "/usr/local/lib/python3.6/site-packages/torch/autograd/__init__.py", line 98, in backward
variables, grad_variables, retain_graph)
File "/usr/local/lib/python3.6/site-packages/torch/autograd/function.py", line 91, in apply
return self._forward_cls.backward(self, *args)
File "/usr/local/lib/python3.6/site-packages/torch/autograd/_functions/basic_ops.py", line 52, in backward
a, b = ctx.saved_variables
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
I think the error happens because hidden_state is still attached to the rnn_model graph (LSTM + linear layer), whose buffers were already freed by the first backward() call, so the second backward() through the Q network tries to traverse that graph again. How can I fix this issue? Thanks!
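For reference, here is a minimal, self-contained sketch of the pattern I suspect is the problem. The networks are toy nn.Linear stand-ins (not my actual LSTM/DQN), and the fix shown, calling detach() on the hidden state before feeding it to the Q network, makes the error go away in this sketch, though detaching also means the Q loss no longer sends gradients back into the reward model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for the real models (hypothetical shapes):
rnn_model = nn.Linear(4, 8)     # plays the role of the LSTM producing the hidden state
reward_head = nn.Linear(8, 1)   # plays the role of the linear layer predicting the reward
qnet = nn.Linear(8, 2)          # plays the role of the DQN head

rnn_optimizer = torch.optim.SGD(
    list(rnn_model.parameters()) + list(reward_head.parameters()), lr=0.01)
q_optimizer = torch.optim.SGD(qnet.parameters(), lr=0.01)

# Fake minibatch data
state_batch = torch.randn(3, 4)
reward_batch = torch.randn(3, 1)
target_q_output = torch.randn(3)
action_indexes = torch.tensor([0, 1, 0])

# First pass: train the reward model; this backward() frees the graph's buffers
hidden_state = rnn_model(state_batch)
predict_reward = reward_head(hidden_state)
state_loss = F.mse_loss(predict_reward, reward_batch)
rnn_optimizer.zero_grad()
state_loss.backward()
rnn_optimizer.step()

# Second pass: detach() cuts hidden_state off from the (already freed) RNN graph,
# so this backward() only traverses qnet and does not raise the runtime error
q_output = qnet(hidden_state.detach())
q_output = q_output[torch.arange(3), action_indexes]
loss = F.smooth_l1_loss(q_output, target_q_output)
q_optimizer.zero_grad()
loss.backward()
q_optimizer.step()
```

The alternative I’m aware of is state_loss.backward(retain_graph=True), which keeps the RNN graph alive so the Q loss can backpropagate through it too, but I’m not sure which behavior is the right one for this architecture.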