RuntimeError: Trying to backward through the graph a second time. Reinforcement learning - DDPG

Hi,

I’m having the following issue in the second batch:

Traceback (most recent call last):
  File "NOT_REAL_PATH\train.py"", line 88, in <module>
    agent_log = agent.observe(observation, reward)
  File "NOT_REAL_PATH\agents\ddpg.py"", line 142, in observe
    return self.update()
  File "NOT_REAL_PATH\agents\ddpg.py"", line 164, in update
    q_loss.backward()
  File "NOT_REAL_ENV_PATH\lib\site-packages\torch\_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "NOT_REAL_ENV_PATH\lib\site-packages\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Process finished with exit code 1

I’m trying to implement Deep Deterministic Policy Gradient as explained here.

Here’s my code:

    def __init__(...):
        ...
        self.Q_loss = torch.nn.MSELoss(reduction='mean')
        ...

    def update(self):
        cumulative_q_loss = 0
        cumulative_q = 0
        dataloader = self.replay_buffer.getDataloader(self.config.replay_buffer['sample_size'])
        for i, batch in enumerate(dataloader):
            states, actions, next_states, rewards = batch

            if i > self.config.replay_buffer['update_steps']:
                break

            # Critic update
            self.critic_optimizer.zero_grad()

            target_policy = self.target_actor(next_states)
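            # TD target: y = r + gamma * Q_target(s', mu_target(s'))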
            y = rewards + self.gamma * self.target_critic(torch.cat((next_states, target_policy), dim=1)).squeeze()

            q_loss = self.Q_loss(self.critic(torch.cat((states, actions), dim=1)).squeeze(), y)

            q_loss.backward()
            self.critic_optimizer.step()

            cumulative_q_loss += float(q_loss.detach().cpu())

            # Actor update
            self.actor_optimizer.zero_grad()

            policy = self.actor(states)
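            # Deterministic policy gradient: minimize -Q(s, mu(s)) averaged over the batch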
            negative_q = -torch.mean(self.critic(torch.cat((states, policy), dim=1)))

            negative_q.backward()
            self.actor_optimizer.step()

            cumulative_q += float(-negative_q.detach().cpu())

            # update target networks
            self.polyak_update(self.critic, self.target_critic)
            self.polyak_update(self.actor, self.target_actor)

        return {'mean_q_loss': cumulative_q_loss / self.config.replay_buffer['update_steps'],
                'mean_q': cumulative_q / self.config.replay_buffer['update_steps']}

    def polyak_update(self, network, target_network):
        """
        A weighted average between two sets of parameters.
        :param network: torch.nn.Module
        :param target_network: torch.nn.Module
        :return:
        """
        tau = self.config.tau

        with torch.no_grad():
            for parameter, parameter_target in zip(network.parameters(), target_network.parameters()):
                # We use an in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                parameter_target.data.mul_(1 - tau)
                parameter_target.data.add_(tau * parameter.data)

So the error says I’m backwarding through the graph a second time. As far as I can tell I’m zeroing out the gradients correctly, and the input data changes every batch, so I don’t see where the same graph would be used twice.

Thanks in advance and happy new year!

In case anyone finds this useful: the solution is to act without tracking gradients:

    with torch.no_grad():
        action = self.actor(state)

The action tensor then doesn’t require a gradient and is stored in the replay buffer that way. As I understand it, it’s important that the tensors fed into the networks during the update have requires_grad=False; otherwise the loss stays connected to the graph that originally produced them, and backwarding through that graph again after it has been freed raises exactly this error.
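
For reference, the acting step now looks roughly like this (a minimal sketch; the method name act is only illustrative and exploration noise is omitted):

    def act(self, state):
        # No autograd graph is built here, so the returned action has
        # requires_grad=False and can be stored in the replay buffer
        # without dragging the actor's forward graph along with it.
        with torch.no_grad():
            action = self.actor(state)
        return action

Alternatively, calling .detach() on the tensors right before pushing them into the replay buffer has the same effect: the stored tensors no longer reference the graph that produced them.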