Multi-agent advantage calculation is leading to an in-place gradient error

I am working on some multi-agent RL training using PPO.

As part of that, I need to calculate the advantage on a per-agent basis, which means taking the data generated by playing the game and masking it down to one agent’s steps at a time.

This has led to an in-place error that’s killing the gradient, and the stack trace from PyTorch’s anomaly detection (torch.autograd.set_detect_anomaly(True)) points at the value function output from my NN.

Here’s a gist of the relevant code, with the learning code separated out: cleanRL · GitHub

I found this post where they were getting a similar error and fixed it by saving intermediate results into a list and then assembling that list of tensors into the complete tensor at the end. However, that solution is not working for me.
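
(The pattern from that post is roughly: accumulate results in a Python list and build the final tensor once at the end, rather than assigning into slices of a preallocated tensor. A generic sketch, with T and some_step() just standing in for the real pieces:)

    pieces = []
    for t in range(T):
        pieces.append(some_step(t))   # each element keeps its own autograd history
    result = torch.stack(pieces)      # build one new tensor at the end, no in-place writes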

Here’s (what I think is) the offending code:

    with torch.no_grad():
        sort_list = []
        advantages = []
        returns = []
        indices = torch.arange(0, rewards.shape[0]).long().to(device)
        next_value = learner.get_value(torch.FloatTensor(next_obs).unsqueeze(0).to(device))
        for player in ['player_0', 'player_1']:
            mask = np.array(rollouts.agent_id) == player
            masked_inds = list(indices[mask])
            lastgaelam = 0
            for t in reversed(range(mask.sum())):
                if t == mask.sum() - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[mask][t + 1]
                    nextvalues = values[mask][t + 1]
                delta = rewards[mask][t] + gamma * nextvalues * nextnonterminal - values[mask][t]
                lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
                advantages.append(lastgaelam)
                sort_list.append(masked_inds[t])
                returns.append(lastgaelam + values[mask][t])

        advantages = torch.cat(advantages)[torch.LongTensor(sort_list)].to(device)
        returns = torch.cat(returns)[torch.LongTensor(sort_list)].to(device)

Here’s how the code looks for single-agent learning; note that it also indexes into and assigns to the advantages tensor:

        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

I guess it’s the masking that is the issue, but I don’t see how else to compute what I need without it.


Here’s an example of multi-agent PPO where all the agents move simultaneously, which means you can vectorize the advantage calculation along the agent dimension of the data buffers: pettingzoo.farama.org/tutorials/cleanrl/implementing_PPO/

That doesn’t work for extensive-form games, though, where the data is naturally interleaved; hence the masking to calculate each agent’s quantities independently.
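
In case it helps picture what the masking is doing, here is roughly the shape of the per-agent calculation, written to mirror the single-agent code above: index each player’s timesteps once, run GAE over that subsequence, and scatter the results back into a preallocated tensor at the original positions. This is only a sketch, assuming rewards/values/dones are 1-D tensors over the interleaved timeline and next_value/next_done are scalars:

    with torch.no_grad():
        advantages = torch.zeros_like(rewards).to(device)
        indices = torch.arange(rewards.shape[0], device=device)
        next_value = learner.get_value(torch.FloatTensor(next_obs).unsqueeze(0).to(device))
        for player in ['player_0', 'player_1']:
            mask = torch.as_tensor(np.array(rollouts.agent_id) == player, device=device)
            idx = indices[mask]                      # this player's positions in the timeline
            r, v, d = rewards[idx], values[idx], dones[idx]
            lastgaelam = 0
            for t in reversed(range(len(idx))):
                if t == len(idx) - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - d[t + 1]
                    nextvalues = v[t + 1]
                delta = r[t] + gamma * nextvalues * nextnonterminal - v[t]
                lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
                advantages[idx[t]] = lastgaelam      # scatter back to the original position
        returns = advantages + values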


Full error if anyone wants it:

/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "/mnt/d/PycharmProjects/ubc/mapo/c4_train.py", line 264, in <module>
    next_done, batch_reward_main, batch_reward_opp, batch_opponents_scores) = generate_data(
  File "/mnt/d/PycharmProjects/ubc/mapo/c4/ppo.py", line 55, in generate_data
    rollout_results = rollout(env, learner, opponent, f"main_v{opponent_id}", device)
  File "/mnt/d/PycharmProjects/ubc/mapo/c4/utils.py", line 138, in rollout
    action, logprob, entropy, value, logits = main_policy.get_action_and_value(observation)
  File "/mnt/d/PycharmProjects/ubc/mapo/c4_train.py", line 184, in get_action_and_value
    return action, probs.log_prob(action), probs.entropy(), self.critic(x), logits.cpu().detach()
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at /opt/conda/conda-bld/pytorch_1670525551200/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/mnt/d/PycharmProjects/ubc/mapo/c4_train.py", line 279, in <module>
    loss, entropy_loss, pg_loss, v_loss, explained_var, approx_kl, meanclipfracs, old_approx_kl = update_model(
  File "/mnt/d/PycharmProjects/ubc/mapo/c4/ppo.py", line 248, in update_model
    loss.backward(retain_graph=True)
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/roque/miniconda3/envs/mapo/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 1]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hi Aaron!

Starting with the forward-call traceback, look at the last couple of lines
in your code (that are then followed by calls into pytorch infrastructure).

I would certainly drill down into what self.critic (x) is doing.

[Edit: Some further words of clarification / explanation: As I’ve come to
understand it, anomaly detection’s forward-call traceback flags the
operation in the forward pass whose backward pass is being blocked
by the inplace modification of some tensor required by the backward
pass (rather than flagging the operation that modifies that tensor).

Given that you are using retain_graph = True, I speculate that you
are doing something like:

critic_loss.backward (retain_graph = True)
critic_optimizer.step()   # modifies critic's parameters inplace
...
actor_loss.backward()     # where actor loss depends on critic
actor_optimizer.step()

If so, you will try to backpropagate again through critic, which has had
its parameters modified inplace by its optimizer. Whether or not modifying
a tensor inplace will cause an inplace-modification error depends on the
details of whether that tensor is needed in the backward-pass computation,
but it is likely that at least some of critic’s parameters will be needed in
the backward pass.

If you have identified the tensor that is causing the inplace-modification
error – likely one of critic’s parameters – print out its ._version before
and after calling critic_optimizer.step(). If you can’t identify the tensor
in question, you can still test this theory by commenting out the call to
critic_optimizer.step() and see if this particular inplace-modification
error goes away. (There may be others.)]
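
Something like this is all I have in mind for that version check (a sketch
with hypothetical names standing in for yours):

    suspect = critic.some_layer.weight        # whichever tensor the error points at
    print("before step:", suspect._version)
    critic_optimizer.step()
    print("after step: ", suspect._version)   # an increase here confirms the inplace modification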

Also, look at the inplace-modification error itself. It is telling you that a
FloatTensor of shape [256, 1] is the tensor that is being modified
inplace. Where in your code do you have a tensor of that shape (that
occurs somewhere in the forward pass)? Look closely at how it is being
used and if you can see an inplace modification.

Note that the error message is complaining that it should be of “version 3”
rather than of “version 4.” Print out the tensor’s ._version property at
various strategic places in your code. The inplace modification is occurring
somewhere between ._version values of 3 and 4. You can insert
intermediate ._version print statements to perform a binary search to
locate exactly where the inplace modification is occurring.

For example, you might try printing out ._version just before and just
after the call to self.critic (x) upon which the forward-call traceback
casts some suspicion.

Note also that you are calling .backward (retain_graph = True).
First, make sure that this is correct logic for your use case. If it is, be
aware that calling optimizer.step() performs an inplace modification
of the parameters being optimized by optimizer. Again, you can check
this by printing out ._version for the problematic tensor before and after
the call to optimizer.step().
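
In case a self-contained illustration helps, here is a minimal sketch (not
your code) of how retain_graph = True plus optimizer.step() produces
exactly this kind of error:

    import torch

    # the second Linear's weight is needed in the backward pass to propagate
    # gradients into the first Linear, so modifying it inplace between two
    # backward calls through the same graph triggers the error
    net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 1))
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    loss = net(torch.randn(16, 4)).sum()
    loss.backward(retain_graph=True)
    opt.step()       # inplace update of the parameters (their ._version increases)
    loss.backward()  # RuntimeError: ... modified by an inplace operation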

For some examples that illustrate these inplace-modification debugging
techniques, see this post:

Good luck!

K. Frank

The retain_graph=True was there because a previous error message told me to set it to get more info, along with turning anomaly-detection mode on. It’s not normally in my code.

Here’s the agent architecture, and the only place I have a [256, 1] tensor is the last layer of the critic:

class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(OrderedDict([
            ('fc1', layer_init(nn.Linear(envs.single_observation_space.shape[0], 512))),
            ('relu1', nn.ReLU()),
            ('fc2', layer_init(nn.Linear(512, 256))),
            ('relu2', nn.ReLU()),
            ('critic_out', layer_init(nn.Linear(256, 1), std=1.0)),
        ]))
        self.actor = nn.Sequential(OrderedDict([
            ('fc1', layer_init(nn.Linear(envs.single_observation_space.shape[0], 512))),
            ('relu1', nn.ReLU()),
            ('fc2', layer_init(nn.Linear(512, 256))),
            ('relu2', nn.ReLU()),
            ('actor_out', layer_init(nn.Linear(256, envs.single_action_space.n), std=0.01)),
        ]))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        if x.ndim == 1:
            x = x.unsqueeze(0)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x), logits.cpu().detach()

So this is wild.

I started tracking the version numbers of the final linear layer of the critic network (and of the tensors output by the value function, plus everything derived from them), and they’re all showing 1s and 0s as their versions.

To make things go faster, I shrank the number of games played to create the on-policy dataset down to 10, and now things just work. But once that number gets big enough (~25 games), the error appears.

Here’s the code (via a zip to preserve file structure – it’s just the files in the gist structured appropriately): ppo_wip.zip - Google Drive


Just saw your edit:

I am updating the entire network with one backward call.

            # Total loss
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef + novelty_loss * 0.01
            # right now the novelty loss is 0 (locked behind an if statement that will always evaluate to False)

            optimizer.zero_grad()
            # print("pre backward: ", learner.critic.critic_out._version)
            loss.backward(retain_graph=True)
            # print("post backward: ", learner.critic.critic_out._version)
            nn.utils.clip_grad_norm_(learner.parameters(), max_grad_norm)
            # print("post clip pre step: ", learner.critic.critic_out._version)
            optimizer.step()
            # print("post step: ", learner.critic.critic_out._version)

Hi Aaron!

Try getting rid of retain_graph = True. Whether or not it helps in
debugging, an unnecessary retain_graph = True could well be the
cause of your problem. (Even a necessary retain_graph = True
could cause trouble, but then fixing things can be more involved.)

This makes sense. The .weight property of a Linear (256, 1) has
shape [1, 256], and, as I understand things, its transpose, with shape
[256, 1], is used in the backward pass, so that all agrees with your
error messages.
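
A quick check of that shape claim:

    import torch
    print(torch.nn.Linear(256, 1).weight.shape)   # torch.Size([1, 256]); its transpose is [256, 1]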

Well, that’s a bit of a puzzle.

But, nevertheless, try printing out
critic.critic_out.weight._version before and after calling
optimizer.step() and see if it increases consistently with the
error message (at least in those runs where the error occurs).

This should be fine, with the exception of retain_graph = True. Since
your code logic doesn’t reuse the graph of loss, you should be safe
getting rid of it.

(If you keep retain_graph = True, you do reuse parts of loss’s graph
the next time through your training loop, and parts of that graph
(presumably critic_out.weight) will have been modified inplace
by optimizer.step().)

So try loss.backward() without retain_graph = True and see if that
fixes your error.

(As an aside, I’m not familiar with clip_grad_norm_(), but it does modify
your gradients inplace. Maybe the gradients only actually get clipped if
the grad_norm exceeds max_grad_norm and maybe that only happens
when you run 25 games rather than 10. Whether this has anything to do
with your actual error, I don’t know.)
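
From its documentation, clip_grad_norm_() returns the total gradient
norm, so one easy check would be to print it and see whether clipping
actually kicks in on the runs that fail:

    total_norm = nn.utils.clip_grad_norm_(learner.parameters(), max_grad_norm)
    print(total_norm)   # gradients are only rescaled when this exceeds max_grad_norm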

Best.

K. Frank

So this just keeps getting wilder. Good news: I know how to get the code to run. Bad news: it makes no sense to me why I have to do what I do to get it to run.

I played around with the batch sizes, and it seems that everything breaks when I have multiple minibatches. If I increase the number of games played per update so that the dataset gets big enough to need splitting into minibatches, the error turns up.

But if I also increase the batch size, it all works.

Alternatively, if I increase the number of optimization steps I can take on a single batch, that works fine. This is very odd.

def update_model(learner, learner_id, optimizer, env, data, n_steps, n_opt_steps, minibatch_size, gamma, gae_lambda,
                 clip_coef, norm_adv, clip_vloss, max_grad_norm, ent_coef, vf_coef, target_kl, device,
                 opponent=None, archive=None, ArchiveKL: Optional["ArchiveKL"]=None):
    # get data for the model we want to update
    rollouts = data[learner_id]
    obs = rollouts.obs
    actions = rollouts.actions
    logprobs = rollouts.logprobs
    rewards = rollouts.rewards
    dones = rollouts.dones
    values = rollouts.values
    logits = rollouts.logits
    next_obs = data['next_obs']
    next_done = data['next_dones']

    # print("extracted from storage pre GAE: ", values._version)
    # print("start of opt: ", learner.critic.critic_out._version)

    # bootstrap value if not done
    with torch.no_grad():
        sort_list = []
        advantages = []
        returns = []
        indices = torch.arange(0, rewards.shape[0]).long().to(device)
        next_value = learner.get_value(torch.FloatTensor(next_obs).unsqueeze(0).to(device))
        for player in ['player_0', 'player_1']:
            mask = np.array(rollouts.agent_id) == player
            masked_inds = list(indices[mask])
            lastgaelam = 0
            for t in reversed(range(mask.sum())):
                if t == mask.sum() - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[mask][t + 1]
                    nextvalues = values[mask][t + 1]
                delta = rewards[mask][t] + gamma * nextvalues * nextnonterminal - values[mask][t]
                lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
                advantages.append(lastgaelam)
                sort_list.append(masked_inds[t])
                returns.append(lastgaelam + values[mask][t])

        advantages = torch.cat(advantages)[torch.LongTensor(sort_list)].to(device)
        returns = torch.cat(returns)[torch.LongTensor(sort_list)].to(device)

    # print("extracted from storage post GAE: ", values._version)
    # print("return version: ", returns._version)

    # print("after gae: ", learner.critic.critic_out._version)

    # flatten the batch
    b_obs = obs.reshape((-1,) + env.single_observation_space.shape)
    b_logprobs = logprobs.reshape(-1)
    b_actions = actions.reshape((-1,) + env.single_action_space.shape)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    b_values = values.reshape(-1)
    b_logits = logits.reshape((-1,) + (env.single_action_space.n,))

    # print("flattened: ", b_values._version)

    # Optimizing the policy and value network
    batch_size = b_obs.shape[0]
    print(batch_size, minibatch_size)
    print(list(range(0, batch_size, minibatch_size)))
    b_inds = np.arange(batch_size)
    clipfracs = []
    for epoch in range(n_opt_steps):
        np.random.shuffle(b_inds)
        for start in range(0, batch_size, minibatch_size):
            end = start + minibatch_size
            mb_inds = b_inds[start:end]

            _, newlogprob, entropy, newvalue, _ = learner.get_action_and_value(b_obs[mb_inds].unsqueeze(1),
                                                                               b_actions.long()[mb_inds])

            # print("getting new log probs: ", learner.critic.critic_out._version)

            logits_archive = []
            if ArchiveKL is not None:
                # get logits for observations from each agent in the novelty archive
                for policy_id, policy in archive.items():
                    if policy_id in ["main", "greedy_main"]:
                        continue
                    # load policy into random agent
                    opponent.load_state_dict(policy)
                    logits_archive.append(opponent.get_action_and_value(b_obs[mb_inds].unsqueeze(1),
                                                                        b_actions.long()[mb_inds])[4])
                logits_archive = torch.stack(logits_archive, dim=0).detach()

            logratio = newlogprob - b_logprobs[mb_inds]
            ratio = logratio.exp()

            with torch.no_grad():
                # calculate approx_kl http://joschu.net/blog/kl-approx.html
                old_approx_kl = (-logratio).mean()
                approx_kl = ((ratio - 1) - logratio).mean()
                clipfracs += [((ratio - 1.0).abs() > clip_coef).float().mean().item()]

            mb_advantages = b_advantages[mb_inds]
            if norm_adv:
                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

            # Policy loss
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value loss
            newvalue = newvalue.view(-1)
            if clip_vloss:
                v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                v_clipped = b_values[mb_inds] + torch.clamp(
                    newvalue - b_values[mb_inds],
                    -clip_coef,
                    clip_coef,
                )
                v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

            # Entropy loss
            entropy_loss = entropy.mean()

            # Novelty loss
            # try to maximize the KL divergence between the current policy and the archive
            novelty_loss = 0
            if ArchiveKL is not None:
                novelty_loss = ArchiveKL.forward(b_logits[mb_inds].unsqueeze(-1), logits_archive).mean()

            # Total loss
            loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef + novelty_loss * 0.01

            optimizer.zero_grad()
            # print("pre backward: ", learner.critic.critic_out._version)
            loss.backward()
            # print("post backward: ", learner.critic.critic_out._version)
            nn.utils.clip_grad_norm_(learner.parameters(), max_grad_norm)
            # print("post clip pre step: ", learner.critic.critic_out._version)
            optimizer.step()
            # print("post step: ", learner.critic.critic_out._version)

        if target_kl is not None:
            if approx_kl > target_kl:
                break

    # print("b_returns final version: ", b_returns._version)
    y_pred, y_true = b_values.cpu().detach().numpy(), b_returns.cpu().detach().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    return (loss.item(), entropy_loss.item(), pg_loss.item(), v_loss.item(),
            explained_var, approx_kl.item(), np.mean(clipfracs), old_approx_kl.item()
            )

Hi Aaron!

Under some conditions, the error occurs; under others, it doesn’t.

This raises the question of whether you’re seeing the same error every
time, or different, but similar, errors.

To be systematic about this, when the error does occur, do you get
essentially the same error message, including the part in the forward-call
traceback?

Does the same tensor get modified inplace and can you use ._version
to track down where the tensor is being modified and verify that it is
being modified in the same place?

Agent.get_action_and_value() calls self.critic (x), but it
appears that you have multiple kinds of Agents. From the snippets
you have posted, I see main_policy.get_action_and_value(),
learner.get_action_and_value(), and
opponent.get_action_and_value().

Just speculation, but if, for example, you randomly choose an opponent
from a pool of opponents, then maybe your batch size or the number of
games played changes the probability that you choose an opponent that
has already been used and has a self.critic whose parameters have
already been modified inplace by an optimization step, so that when you
backward through the opponent.get_action_and_value() call (and
therefore backward through its call to self.critic (x)), you get the
inplace-modification error.

Or something like that.

In any event, you need to determine which specific tensor is being modified
inplace and which specific line of code is making that inplace modification in
order to reliably fix the problem (and to figure out whether you have just one
such error or a cluster of similar errors).

It does look like the self.critic.critic_out.weight tensor (whose
transpose has shape [256, 1]) is being modified inplace. But, since
you have multiple Agents, which Agent’s critic is being modified?
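
One way to check (a sketch using the names from your snippets; adjust to
however your Agents are actually stored) is to print the parameter versions
for each Agent's critic before and after the failing backward call:

    for name, agent in [("learner", learner), ("opponent", opponent)]:
        for pname, p in agent.critic.named_parameters():
            print(name, pname, p._version)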

Good luck!

K. Frank

Fair. I definitely need to clean up the naming terminology.

main_policy and learner are the same network (learner is just the name of the model passed into the update_model function, so I can pass other NNs into the optimization code), and the opponent never gets called (it’s locked behind if statements that evaluate to False while I get normal learning working).

Thanks for all your pointers!