Understanding and resolving "Trying to backward through the graph a second time" error

I’ve read a lot of posts about similar problems, but none of the solutions seem to help me.

I’m implementing MADDPG and get the “Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)…” error. I have a for loop that iterates through my agents and performs a training step for each one. The first agent trains fine, and then the error occurs when the second agent calls loss.backward().
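The pattern described here (one training step per agent, failing on the second agent) can be reproduced with two losses that share an intermediate tensor. In this minimal sketch, `w`, `shared`, `loss_agent_0` and `loss_agent_1` are made-up names; `shared` plays the role of network outputs that every agent's loss is built from.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)

# One intermediate shared by two losses (like outputs reused by every agent's loss).
shared = (w * 2).exp()          # exp() saves its output for the backward pass
loss_agent_0 = shared.sum()
loss_agent_1 = (shared ** 2).sum()

loss_agent_0.backward()         # frees the saved tensors on this path, including exp()'s

try:
    loss_agent_1.backward()     # reaches exp() again, whose saved output is gone
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message is not None)    # True

# retain_graph=True keeps the saved tensors alive for a second backward pass.
w.grad = None
shared = (w * 2).exp()
shared.sum().backward(retain_graph=True)
(shared ** 2).sum().backward()  # succeeds
print(w.grad is not None)           # True
```

`retain_graph=True` is often suggested for this error, but in a loop that calls `optimizer.step()` between backward passes it can trade this error for another one, because the step modifies parameters in place that the retained graph still needs. Avoiding shared graph-attached inputs is usually the cleaner fix.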

I don’t understand why the error occurs there. As far as I know, that’s the first time a loss is computed for that agent’s optimizer. I read that the issue may be related to the memory samples being used in the gradient computation, but I think I circumvented that with torch.no_grad().
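`torch.no_grad()` only stops graph recording for operations executed inside its block. A minimal sketch, with a hypothetical `nn.Linear` standing in for an actor network: data collected under `no_grad()` and stored as NumPy carries no graph, but a forward pass done later in a learning step outside `no_grad()` records a fresh graph through the network parameters.

```python
import numpy as np
import torch
import torch.nn as nn

actor = nn.Linear(4, 2)  # hypothetical stand-in for agent.actor

# Collection phase: outputs computed under no_grad carry no graph.
with torch.no_grad():
    stored_action = actor(torch.randn(1, 4))
print(stored_action.requires_grad, stored_action.grad_fn)  # False None

# Learning phase: tensors rebuilt from sampled NumPy arrays are plain leaves...
sampled_states = torch.tensor(np.random.rand(8, 4), dtype=torch.float)
print(sampled_states.requires_grad)  # False

# ...but a forward pass outside no_grad records a new graph through the weights.
pi = actor(sampled_states)
print(pi.requires_grad, pi.grad_fn is not None)  # True True
```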

Anyway, here is the relevant code; please let me know where I’m going wrong. I appreciate the help.

for i in range(num_games):
    observations = env.reset()
    score = 0
    done = [False]*num_agents
    episode_step = 0
    while not any(done):
        with torch.no_grad():
            actions = trainer.choose_actions(observations)
            new_observations, reward, done, info = env.step(actions)

            previous_state = observation_list_to_state_vector(observation=observations)
            new_state = observation_list_to_state_vector(observation=new_observations)

            if episode_step >= max_steps:
                done = [True]*num_agents

            memory.store_transition(observations, previous_state, actions, reward, new_observations, new_state, done)

        if total_steps % 100 == 0:
            trainer.learn(memory)  # assumed: the learning call was missing from the post

        observations = new_observations

        score += sum(reward)
        total_steps += 1
        episode_step += 1

    def learn(self, memory : MultiAgentReplayBuffer):
        if not memory.ready_for_sample():
            return  # not enough transitions stored yet
        actor_states, states, actions, rewards, \
            actor_new_states, new_states, dones = memory.sample_buffer()

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        states = torch.tensor(states, dtype=torch.float).to(device)
        actions = torch.tensor(actions, dtype=torch.float).to(device)
        rewards = torch.tensor(rewards, dtype=torch.float).to(device)
        new_states = torch.tensor(new_states, dtype=torch.float).to(device)
        dones = torch.tensor(dones).to(device)

        all_agents_new_actions = []
        all_agents_new_mu_actions = []
        old_agents_actions = []

        for agent_index, agent in enumerate(self.agents):
            # torch.float (float32) matches the network weights; dtype=float would give float64
            new_state_for_agent = torch.tensor(actor_new_states[agent_index], dtype=torch.float).to(device)
            new_pi = agent.target_actor.forward(new_state_for_agent)
            all_agents_new_actions.append(new_pi)

            mu_states = torch.tensor(actor_states[agent_index], dtype=torch.float).to(device)
            pi = agent.actor.forward(mu_states)
            all_agents_new_mu_actions.append(pi)
            old_agents_actions.append(actions[agent_index])  # assumes actions are indexed by agent

        new_actions = torch.cat(all_agents_new_actions, dim=1)
        mu = torch.cat(all_agents_new_mu_actions, dim=1)
        old_actions = torch.cat(old_agents_actions, dim=1)

        for agent_index, agent in enumerate(self.agents):
            self.learn_agent(agent_index, agent, rewards, states, old_actions, new_actions, new_states, mu, dones)

    def learn_agent(self, agent_index, agent : Agent, rewards, states, old_actions, new_actions, new_states, mu, dones):
        new_critic_value = agent.target_critic.forward(new_states, new_actions).flatten()
        # new_critic_value[dones[:,0]] = 0.0
        # zero the bootstrapped value for terminal transitions; zeros_like keeps device and dtype
        new_critic_value = torch.where(dones[:,0], torch.zeros_like(new_critic_value), new_critic_value)
        critic_value = agent.critic.forward(states, old_actions).flatten()

        target = rewards[:, agent_index] + agent.gamma*new_critic_value

        critic_loss = F.mse_loss(target, critic_value).clone()
        actor_loss = -torch.mean(agent.critic.forward(states, mu).flatten()).clone()

        # (not shown in the post) zero_grad(), critic_loss.backward(), actor_loss.backward()
        # and the optimizer steps for agent.critic and agent.actor happen here

        agent.update_network_parameters() # updates target networks every n steps
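For comparison, here is a sketch of one common way to structure the per-agent update so that each agent's backward() only walks a graph built for that agent: the target-network actions are computed under torch.no_grad() (they are constants in the critic target), and mu is rebuilt inside the loop with the other agents' actions detached. The tiny `nn.Linear` actors and critics, the Adam optimizers and the random batch are hypothetical stand-ins for agent.actor, agent.target_actor, agent.critic, agent.target_critic and the replay-buffer sample; done-masking and target-network updates are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_agents, batch, obs_dim, act_dim, gamma = 2, 8, 3, 2, 0.95
state_dim = n_agents * obs_dim
critic_in = state_dim + n_agents * act_dim

# Hypothetical stand-ins for agent.actor, agent.target_actor, agent.critic, agent.target_critic.
actors = [nn.Linear(obs_dim, act_dim) for _ in range(n_agents)]
target_actors = [nn.Linear(obs_dim, act_dim) for _ in range(n_agents)]
critics = [nn.Linear(critic_in, 1) for _ in range(n_agents)]
target_critics = [nn.Linear(critic_in, 1) for _ in range(n_agents)]
actor_opts = [torch.optim.Adam(a.parameters(), lr=1e-3) for a in actors]
critic_opts = [torch.optim.Adam(c.parameters(), lr=1e-3) for c in critics]

# Fake replay-buffer sample: plain tensors, no graph attached.
obs = torch.randn(n_agents, batch, obs_dim)
new_obs = torch.randn(n_agents, batch, obs_dim)
old_actions = torch.randn(batch, n_agents * act_dim)
rewards = torch.randn(batch, n_agents)
states = obs.permute(1, 0, 2).reshape(batch, state_dim)
new_states = new_obs.permute(1, 0, 2).reshape(batch, state_dim)

# Target-network actions are constants in the critic target: record no graph.
with torch.no_grad():
    new_actions = torch.cat([target_actors[j](new_obs[j]) for j in range(n_agents)], dim=1)

for i in range(n_agents):
    # Critic update: the recorded graph contains only critics[i].
    with torch.no_grad():
        next_q = target_critics[i](torch.cat([new_states, new_actions], dim=1)).flatten()
        target = rewards[:, i] + gamma * next_q
    critic_value = critics[i](torch.cat([states, old_actions], dim=1)).flatten()
    critic_loss = F.mse_loss(critic_value, target)
    critic_opts[i].zero_grad()
    critic_loss.backward()
    critic_opts[i].step()

    # Actor update: rebuild mu for this agent only; other agents' actions are detached,
    # so this graph is freshly built and consumed by exactly one backward() call.
    mu = torch.cat(
        [actors[j](obs[j]) if j == i else actors[j](obs[j]).detach() for j in range(n_agents)],
        dim=1,
    )
    actor_loss = -critics[i](torch.cat([states, mu], dim=1)).mean()
    actor_opts[i].zero_grad()
    actor_loss.backward()
    actor_opts[i].step()
```

Rebuilding mu inside the loop costs one extra actor forward pass per agent, but no backward() call ever reaches a graph that an earlier agent's update already consumed or that optimizer.step() has since modified.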

In learn_agent you call backward() on critic_loss and actor_loss, and learn_agent itself is called inside the for agent_index, agent in enumerate(self.agents) loop.
I would start by checking whether any of the inputs to learn_agent are attached to a computation graph and used in the loss computation. If they are, the first agent’s backward() call frees the saved tensors of that shared graph, and the second agent’s backward() call then raises the error you are seeing.
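One way to do that check is to look at each tensor's `grad_fn`: a non-None `grad_fn` means the tensor was produced by recorded operations, so backpropagating through it consumes that graph's saved tensors. The helper name `graph_attached` and the toy inputs below are hypothetical.

```python
import torch
import torch.nn as nn

def graph_attached(**tensors):
    """Return the names of the given tensors that hang off an autograd graph."""
    return [name for name, t in tensors.items()
            if isinstance(t, torch.Tensor) and t.grad_fn is not None]

# Toy inputs mimicking what learn() passes to learn_agent().
target_actor = nn.Linear(3, 2)
rewards = torch.rand(5, 2)                    # built from replay-buffer data
old_actions = torch.rand(5, 4)                # built from replay-buffer data
new_actions = target_actor(torch.rand(5, 3))  # network output, graph attached

print(graph_attached(rewards=rewards, old_actions=old_actions, new_actions=new_actions))
# ['new_actions']
```

Called in learn() just before the agent loop, e.g. with `states=states, old_actions=old_actions, new_actions=new_actions, mu=mu`, it would likely report new_actions and mu in the posted code, since both come from forward passes run outside torch.no_grad().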

You are correct. I also figured out why I wasn’t able to get it working before: I was using .detach() assuming it worked in place. Now that I am using .detach_(), it works.
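The difference is that `Tensor.detach()` returns a new tensor and leaves the original attached to its graph, whereas `Tensor.detach_()` detaches the tensor itself in place. A small sketch with a throwaway `nn.Linear`:

```python
import torch

lin = torch.nn.Linear(3, 2)
x = lin(torch.randn(4, 3))

x.detach()                          # returns a NEW tensor; x itself is unchanged
still_attached = x.grad_fn is not None
print(still_attached)               # True

y = x.detach()                      # keep (or reassign) the result to use it
print(y.grad_fn, y.requires_grad)   # None False

x.detach_()                         # detaches x in place
print(x.grad_fn, x.requires_grad)   # None False
```

So `new_actions = new_actions.detach()` achieves the same for that name as `new_actions.detach_()`. Which tensor gets detached still matters: detaching the target-network actions only removes a graph the critic update never needed, while detaching mu would also cut the path from actor_loss back to agent.actor, so the actors would receive no gradients from that loss.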

Thank you so much!