"RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead. Hint: the backtrace further a

I am trying to implement a multi-agent deep reinforcement learning algorithm (Soft Actor-Critic) in PyTorch,
and I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I found that other people have asked similar questions on this site.
I tried the approaches they recommended, but they did not work.

The error appears while the actor network is being updated.

The following code is the learning step of my algorithm,
and the error happens inside it.

I would really appreciate your help.

def learn(self, memory):
    if not memory.ready():
        return

    actor_state, state, action, reward, actor_new_state, state_, done = memory.sample_buffer()

    state = T.tensor(state, dtype=T.float)
    action = T.tensor(action, dtype=T.float)
    reward = T.tensor(reward)
    state_ = T.tensor(state_, dtype=T.float)
    done = T.tensor(done)

    agents_actions = []

    actions_for_value_update = []
    log_probs_for_value_update = []
    for agent_idx, agent in enumerate(self.agents):
        a_state = T.tensor(actor_state[agent_idx], dtype=T.float)
        v_action, v_log_probs = agent.actor.sample_normal(a_state, reparameterize=False)
        v_log_probs = v_log_probs.view(-1)

        actions_for_value_update.append(v_action)
        log_probs_for_value_update.append(v_log_probs)

    actions_for_value_update = T.hstack((actions_for_value_update[:]))

    for agent_idx, agent in enumerate(self.agents):
        agents_actions.append(action[agent_idx])

        value = agent.value(state).view(-1)

        q1_new_policy = agent.critic1.forward(state=state, action=actions_for_value_update)
        q2_new_policy = agent.critic2.forward(state=state, action=actions_for_value_update)
        critic_value = T.min(q1_new_policy, q2_new_policy)
        critic_value = critic_value.view(-1)

        agent.value.optimizer.zero_grad()
        value_target = -log_probs_for_value_update[agent_idx].to(T.float32) + critic_value
        value_loss = 0.5 * F.mse_loss(value, value_target.detach())

        value_loss.backward(retain_graph=True)
        agent.value.optimizer.step()

        agent.update_network_parameters()

    actions = T.cat([acts for acts in agents_actions], dim=1)

    actions_for_actor_update = []
    log_probs_for_actor_update = []


    for agent_idx, agent in enumerate(self.agents):

        a_state = T.tensor(actor_state[agent_idx], dtype=T.float)
        a_action, a_log_probs = agent.actor.sample_normal(a_state.clone(), reparameterize=True)
        a_log_probs = a_log_probs.view(-1)
        actions_for_actor_update.append(a_action)
        # actions_for_actor_update.append(a_action)
        log_probs_for_actor_update.append(a_log_probs)

    actions_for_actor_update = (T.hstack((actions_for_actor_update[:]))).float()

    for agent_idx, agent in enumerate(self.agents):
        #print(agent_idx)
        q1_new_policy = agent.critic1.forward(state=state, action=actions_for_actor_update) 
        q2_new_policy = agent.critic2.forward(state=state, action=actions_for_actor_update)
        critic_value = T.min(q1_new_policy.clone(), q2_new_policy.clone())
        critic_value = critic_value.view(-1).clone()

        actor_loss = log_probs_for_actor_update[agent_idx].to(T.float32).clone() - critic_value
        actor_loss = T.mean(actor_loss)
        agent.actor.optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        agent.actor.optimizer.step()

        value_ = agent.target_value(state_).view(-1)
        value_[done[:, 0]] = 0.0

        agent.critic1.optimizer.zero_grad()
        agent.critic2.optimizer.zero_grad()
        q_hat = 1.0 * reward[:, agent_idx].to(T.float32) + agent.gamma * value_

        q1_old_policy = agent.critic1.forward(state=state, action=actions).view(-1)
        q2_old_policy = agent.critic2.forward(state=state, action=actions).view(-1)
        critic1_loss = 0.5 * F.mse_loss(q1_old_policy, q_hat)
        critic2_loss = 0.5 * F.mse_loss(q2_old_policy, q_hat)

        critic_loss = critic1_loss + critic2_loss
        critic_loss.backward()
        agent.critic1.optimizer.step()
        agent.critic2.optimizer.step()

@ptrblck Hello. In other threads, I saw that you solved errors similar to mine (the error above).
I have tried a lot of things, but nothing has worked.
Could I ask for your help with this error? Thank you.

Often these issues are caused by using retain_graph=True as a workaround for another error:

value_loss.backward(retain_graph=True)

Could you explain why retain_graph=True is used here, please?


First, thanks for replying.

I found that retain_graph=True can be removed from the value-network update (i.e., value_loss.backward()).

However, if I remove retain_graph=True from the actor-network update (i.e., actor_loss.backward()), the following error appears:
“RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.”

This error only appears in the multi-agent environment; it does not appear in the single-agent environment.

Hello @Sangyoon1207

I have run into the same error in a multi-agent environment. I am still looking for a solution, but I strongly suspect it is related to this post. Hopefully you find it helpful.

Hi Fahmyadan and Sangyoon!

Here are some suggestions about how to track down (and maybe fix)
inplace-modification errors. Note that an inplace modification in the forward
pass is not necessarily* an error: it depends on whether and how the modified
tensor is used in the backward pass. Also, inplace operations can be useful
for saving memory. If you replace an innocent inplace operation with an
out-of-place equivalent, your training will use more memory (and, to a
minor extent, take more time).

First, understand what can cause inplace modifications. These include
inplace tensor functions, indicated by names with trailing underscores, such
as .zero_(), .copy_(), and .tanh_(). There are also some functions that
can be called with a boolean inplace flag that tells them to operate inplace,
such as torch.nn.functional.relu (x, inplace = True).

Writing into a tensor using indices, e.g., t[0, 1] = 10.0, modifies the tensor
inplace, as does, for example, .scatter_().
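As a quick check of the list above, the following sketch applies each kind of operation to a small tensor and records its internal _version counter after each step. It assumes a PyTorch version where the private _version attribute is still exposed, and since the exact counts are an implementation detail, it only relies on the counter going up for inplace operations and staying put for the out-of-place tanh().

```python
import torch

t = torch.zeros(2, 3)
versions = [t._version]                         # starts at 0 for a fresh tensor

t.tanh_()                                       # trailing underscore: inplace
versions.append(t._version)

torch.nn.functional.relu(t, inplace=True)       # inplace flag
versions.append(t._version)

t[0, 1] = 10.0                                  # indexed assignment writes into t
versions.append(t._version)

t.scatter_(1, torch.tensor([[0], [2]]), 1.0)    # scatter_ also writes into t
versions.append(t._version)

u = t.tanh()                                    # out-of-place: returns a new tensor
versions.append(t._version)                     # so t's counter is unchanged

print(versions)
```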

Fixing an inplace error can be as easy as replacing an inplace function with
an out-of-place version:

# instead of
t.tanh_()
# use
t = t.tanh()

or clone() the tensor before modifying it:

# instead of
t[0, 1] = 10.0
# use
t = t.clone()
t[0, 1] = 10.0
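Here is a small sketch of why clone() helps, using sigmoid() purely as an illustration: sigmoid() saves its own output for the backward pass, so writing into that output inplace trips autograd's version check, while writing into a clone leaves the saved tensor untouched.

```python
import torch

x = torch.randn(4, requires_grad=True)

# sigmoid saves its output for backward, so overwriting that output breaks autograd
y = torch.sigmoid(x)
y[0] = 10.0
try:
    y.sum().backward()
    error = None
except RuntimeError as e:
    error = e
print('without clone:', type(error).__name__)

# modify a clone instead: the tensor sigmoid saved is never touched
y = torch.sigmoid(x)
z = y.clone()
z[0] = 10.0
z.sum().backward()
print('with clone, x.grad:', x.grad)
```

With the clone, the overwritten element simply receives zero gradient, since its value no longer depends on x.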

Let’s look at how to track down inplace errors – here’s an example script:

import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

class DummyResNet (torch.nn.Module):
    def __init__ (self, printVersions = False):
        super().__init__()
        self.printVersions = printVersions
        self.lin1 = torch.nn.Linear (5, 10)
        self.lin2 = torch.nn.Linear (10, 10)
        self.lin3 = torch.nn.Linear (10, 1)
        self.relu = torch.nn.ReLU()
    def forward (self, x):
        y1 = self.lin1 (x)    # shape [3, 10]
        sc = self.relu (y1)   # shape [3, 10]
        y2 = self.lin2 (sc)   # shape [3, 10]
        y3 = self.relu (y2)   # shape [3, 10]
        y3 += sc   # skip connection
        y = self.lin3 (y3)    # shape [3, 1]
        if  self.printVersions:
            print ('versions:')
            print ('y1:', y1._version)
            print ('sc:', sc._version)
            print ('y2:', y2._version)
            print ('y3:', y3._version)
            print ('y: ', y._version)
        return y

drn = DummyResNet()
inp = torch.randn (3, 5)   # batch size of 3

print ('\nrun model and backward pass ...')
try:
    out = drn (inp)
    loss = out.sum()
    loss.backward()
except Exception as e:
    print ('%s: %s' % (type (e).__name__, e))

print ('\nrun forward pass with versions ...')
drn_versions = DummyResNet (True)
out = drn_versions (inp)

print ('\nrun again with anomaly detection ...')
with torch.autograd.detect_anomaly():
    out = drn (inp)
    loss = out.sum()
    loss.backward()

This is a toy version of a resnet that implements its skip connection with an
inplace add:

        y3 += sc   # skip connection

Here’s the output from the first part of the script:

1.13.1

run model and backward pass ...
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 10]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

The inplace-modification error gives us two useful pieces of information:
First is the shape of the modified tensor (and its version number, see below);
and second is that ReluBackward0 is complaining about the modification.

Our toy model has two relu()s (both of which produce tensors of shape
[3, 10]). The first looks safe, while the second has its output modified
inplace by y3 += sc, and this is indeed the cause of the error.

But, as suggested by the error message, we can gather more information by
using autograd’s anomaly detection.

Later in the example script we rerun the forward / backward pass wrapped in
a with torch.autograd.detect_anomaly(): block. Here is the result:

run again with anomaly detection ...
<string>:46: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
<path_to_pytorch_install>\torch\autograd\__init__.py:197: UserWarning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "<stdin>", line 1, in <module>
  File "<string>", line 47, in <module>
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<string>", line 18, in forward
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<path_to_pytorch_install>\torch\nn\modules\activation.py", line 102, in forward
    return F.relu(input, inplace=self.inplace)
  File "<path_to_pytorch_install>\torch\nn\functional.py", line 1457, in relu
    result = torch.relu(input)
  File "<path_to_pytorch_install>\torch\fx\traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 49, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 10]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

In this, the “traceback of the forward call” shows us where the inplace
modification is being made. The parts relevant to the example script are:

File "<string>", line 47, in <module>

which is

    out = drn (inp)

where we call the forward pass of the model, and

File "<string>", line 18, in forward

which is

        y3 += sc   # skip connection

and is the inplace add that leads to the inplace-modification error.

So anomaly detection’s forward-pass traceback has led us rather directly
to the source of the error.

(In this case the error can be fixed, at the cost of memory, by replacing
y3 += sc with y3 = y3 + sc.)
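A minimal sketch of that fix, with the same layer sizes as DummyResNet but the skip connection written out-of-place (FixedResNet is just an illustrative name):

```python
import torch

class FixedResNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(5, 10)
        self.lin2 = torch.nn.Linear(10, 10)
        self.lin3 = torch.nn.Linear(10, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        sc = self.relu(self.lin1(x))
        y3 = self.relu(self.lin2(sc))
        y3 = y3 + sc   # out-of-place add: the ReLU output saved for backward is not modified
        return self.lin3(y3)

model = FixedResNet()
loss = model(torch.randn(3, 5)).sum()
loss.backward()        # no inplace-modification error
print(model.lin1.weight.grad.shape)
```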

Last, sometimes, even if you know what tensor is being modified, it can
be difficult to tell where the modification occurs, especially if you are
calling into third-party code or the tensor is being modified through an
alias or view.

PyTorch checks for inplace modifications by updating (at least in the current
PyTorch version, 1.13.1) a tensor’s _version property every time a modification
is made. You can track this _version property to see if and when a tensor is
being modified. This is also illustrated by the example script:

run forward pass with versions ...
versions:
y1: 0
sc: 0
y2: 0
y3: 1
y:  0

Here we print out the _version of all of the tensors at the end of the
model’s forward() method and see that only y3 is being modified.
We didn’t do it, but we could have printed out y3._version just before
and after y3 += sc and seen that, indeed, this is the line that modifies
y3 inplace.
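Printing the counter around the suspect line could look like the following sketch, which uses stand-in random tensors for sc and y3 rather than the model itself. The second half illustrates the alias/view case mentioned above: a view shares its base's version counter, so an inplace change made through the view shows up on the base too.

```python
import torch

# stand-ins for the model's sc and y3 (no autograd needed to watch the counter)
sc = torch.relu(torch.randn(3, 10))
y3 = torch.relu(torch.randn(3, 10))

before = y3._version
y3 += sc                          # the inplace add from the skip connection
after = y3._version
print('y3 version before / after +=:', before, after)

# a view shares its base's version counter
base = torch.zeros(4)
view = base[1:3]
view.add_(1.0)                    # modify through the view only
print('base version:', base._version, 'view version:', view._version)
```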

A rather different cause of inplace-modification errors can be the following:

optimizer.step() is also an inplace operation, in that the optimizer updates
the parameters it is training inplace. Normally these parameters are used
in a new forward pass that creates a new computation graph for which the
updated parameters are appropriate for the subsequent backward pass
(hence no inplace-modification error).

However, if retain_graph = True is used when performing a backward
pass, optimizer.step() is called, and then a second backward pass is
performed, the modified parameters can trigger an inplace-modification
error in the second backward pass.
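The following sketch reproduces that sequence with two hypothetical Linear layers (upstream and downstream stand in for, say, a generator and a discriminator, or an actor and a critic): a backward pass with retain_graph = True, an optimizer step on downstream, then a second backward pass through the same graph. It then shows the usual remedy of running a fresh forward pass after the step, which makes retain_graph = True unnecessary.

```python
import torch

torch.manual_seed(0)
upstream = torch.nn.Linear(4, 8)     # hypothetical stand-in, e.g. a generator or actor
downstream = torch.nn.Linear(8, 1)   # hypothetical stand-in, e.g. a discriminator or critic
opt = torch.optim.SGD(downstream.parameters(), lr=0.1)
inp = torch.randn(16, 4)

out = downstream(upstream(inp))

# first backward keeps the graph alive, then the step changes downstream.weight inplace
opt.zero_grad()
out.pow(2).mean().backward(retain_graph=True)
opt.step()

try:
    out.mean().backward()   # reuses the old graph, which saved the pre-step weight
    error = None
except RuntimeError as e:
    error = e
print('reusing the graph:', type(error).__name__)

# fix: run a fresh forward pass after the step, so no retained graph is needed
out = downstream(upstream(inp))
out.mean().backward()
print('fresh forward pass: ok')
```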

Note that although retain_graph = True has legitimate use cases, there is
a lot of code floating around on the internet that uses retain_graph = True
incorrectly or unnecessarily.
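The learn() method in the question appears to follow the same pattern: actions_for_actor_update is built once from all agents' actors, and then each agent's actor_loss.backward(retain_graph=True) is followed by that agent's actor.optimizer.step(), so the next agent's backward pass runs through an actor whose weights were already changed inplace. One common restructuring, sketched here under simplifying assumptions (tiny hypothetical Linear actors and critics, and a deterministic tanh action in place of SAC's sampled action and log-probability term), is to recompute the actions inside each agent's iteration and detach the other agents' actions, so every backward pass uses a fresh graph.

```python
import torch

torch.manual_seed(0)
n_agents, obs_dim, act_dim, batch = 2, 3, 1, 64

# hypothetical stand-ins: one actor and one centralized critic per agent
actors = [torch.nn.Linear(obs_dim, act_dim) for _ in range(n_agents)]
critics = [torch.nn.Linear(n_agents * (obs_dim + act_dim), 1) for _ in range(n_agents)]
actor_opts = [torch.optim.Adam(a.parameters(), lr=1e-3) for a in actors]

obs = [torch.randn(batch, obs_dim) for _ in range(n_agents)]   # per-agent observations
state = torch.cat(obs, dim=1)                                  # joint state for the critics
before = [a.weight.detach().clone() for a in actors]

for i in range(n_agents):
    # fresh forward pass through every actor, after earlier agents' optimizer steps
    acts = []
    for j in range(n_agents):
        a_j = torch.tanh(actors[j](obs[j]))
        acts.append(a_j if j == i else a_j.detach())   # only agent i's action gets gradients
    q = critics[i](torch.cat([state] + acts, dim=1))
    actor_loss = -q.mean()          # SAC would also include an alpha * log_prob term here
    actor_opts[i].zero_grad()
    actor_loss.backward()           # plain backward, no retain_graph=True
    actor_opts[i].step()

print([not torch.equal(b, a.weight) for b, a in zip(before, actors)])
```

Because actor_loss.backward() also writes gradients into the critic's parameters, the critic's optimizer should call zero_grad() before its own update, as the question's code already does.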

This script uses a toy generative adversarial network to illustrate how
retain_graph = True can cause inplace-modification errors and how to
track them down:

import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

nFeatures = 5   # number of "features" input to generator
length = 10     # length of 1d "image"

# not meant to illustrate how a GAN should actually be implemented

class DummyGen (torch.nn.Module):
    def __init__ (self):
        super().__init__()
        nHidden = 20   # number of hidden features
        self.lin1 = torch.nn.Linear (nFeatures, nHidden)
        self.relu = torch.nn.ReLU()
        self.lin2 = torch.nn.Linear (nHidden, length)
    def forward (self, x):
        y1 = self.lin1 (x)
        y2 = self.relu (y1)
        y3 = self.lin2 (y2)
        return y3

class DummyDisc (torch.nn.Module):
    def __init__ (self):
        super().__init__()
        nHidden = 20   # number of hidden features
        self.lin1 = torch.nn.Linear (length, nHidden)
        self.relu = torch.nn.ReLU()
        self.lin2 = torch.nn.Linear (nHidden, 1)
    def forward (self, x):
        y1 = self.lin1 (x)
        # y1 = torch.nn.functional.linear (x, self.lin1.weight.clone(), self.lin1.bias)
        y2 = self.relu (y1)
        y3 = self.lin2 (y2)
        # fix inplace error by using:
        # y3 = torch.nn.functional.linear (y2, self.lin2.weight.clone(), self.lin2.bias)
        return y3

gen = DummyGen()
dsc = DummyDisc()
g_opt = torch.optim.SGD (gen.parameters(), lr = 0.01)
d_opt = torch.optim.SGD (dsc.parameters(), lr = 0.01)

loss_fn = torch.nn.BCEWithLogitsLoss()

fake = gen (torch.randn (nFeatures))   # fake "image" from random noise
real = torch.randn (length)            # dummy data for "real" image

fake_pred = dsc (fake)
real_pred = dsc (real)

print ('\nrun discriminator backward pass and optimizer step ...')
d_loss = loss_fn (fake_pred, torch.tensor ([0.0])) + loss_fn (real_pred, torch.tensor ([1.0]))
d_opt.zero_grad()
d_loss.backward (retain_graph = True)
d_opt.step()

print ('\nrun generator backward pass and optimizer step ...')
try:
    g_loss = loss_fn (fake_pred, torch.tensor ([0.0])) + loss_fn (real_pred, torch.tensor ([1.0]))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
except Exception as e:
    print ('%s: %s' % (type (e).__name__, e))

print ('\nrun again with anomaly detection ...')
with torch.autograd.detect_anomaly():
    fake = gen (torch.randn (nFeatures))   # fake "image" from random noise
    real = torch.randn (length)            # dummy data for "real" image
    
    fake_pred = dsc (fake)
    real_pred = dsc (real)

    print ('\nrun discriminator backward pass and optimizer step ...')
    d_loss = loss_fn (fake_pred, torch.tensor ([0.0])) + loss_fn (real_pred, torch.tensor ([1.0]))
    d_opt.zero_grad()
    d_loss.backward (retain_graph = True)
    d_opt.step()
    
    print ('\nrun generator backward pass and optimizer step ...')
    g_loss = loss_fn (fake_pred, torch.tensor ([0.0])) + loss_fn (real_pred, torch.tensor ([1.0]))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

Here’s the output from the first part of the script:

1.13.1

run discriminator backward pass and optimizer step ...

run generator backward pass and optimizer step ...
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [20, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

The modified tensor has shape [20, 1]. The model does not obviously have any
tensors of this shape. However, DummyDisc.lin2.weight has shape [1, 20]
(and is indeed the tensor that is causing the inplace-modification error). I assume
that this shape comes from the backward pass using a transposed
view of the weight matrix.
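A small sketch that is at least consistent with this transposed-view explanation: the transpose of a Linear(20, 1) weight has the [20, 1] shape from the error message, and because it is a view, it shares the weight's version counter, so an inplace update of the weight (here a hand-written SGD-like step) also bumps the transpose's version.

```python
import torch

lin2 = torch.nn.Linear(20, 1)       # same shape as DummyDisc.lin2
wt = lin2.weight.t()                # transposed view of the weight
print(lin2.weight.shape, wt.shape)

before = wt._version
with torch.no_grad():
    lin2.weight.sub_(0.01 * torch.ones_like(lin2.weight))   # mimics an optimizer step
print('transpose version change:', wt._version - before)
```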

Anomaly detection confirms this:

run again with anomaly detection ...
<string>:69: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.

run discriminator backward pass and optimizer step ...

run generator backward pass and optimizer step ...
<path_to_pytorch_install>\torch\autograd\__init__.py:197: UserWarning: Error detected in MmBackward0. Traceback of forward call that caused the error:
  File "<stdin>", line 1, in <module>
  File "<string>", line 74, in <module>
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<string>", line 35, in forward
  File "<path_to_pytorch_install>\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<path_to_pytorch_install>\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "<path_to_pytorch_install>\torch\fx\traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\python_anomaly_mode.cpp:119.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 85, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [20, 1]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The relevant lines are:

  File "<string>", line 74, in <module>
  File "<string>", line 35, in forward
  File "<path_to_pytorch_install>\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

which are:

    real_pred = dsc (real)
        y3 = self.lin2 (y2)

and the forward() of pytorch’s Linear.

This again shows how autograd’s backward-pass RuntimeError, especially
together with anomaly detection, can help pin down the specific causes of
inplace-modification errors.

In this particular example DummyDisc.lin1 will also cause an inplace-modification
error – it just doesn’t show up in the output because autograd raises an error and
stops when the DummyDisc.lin2 error occurs.

There are a number of ways to fix these inplace-modification errors. If we
choose not to rework the logic of how the forward / backward passes and
optimization steps work with one another, probably the most direct fix is to
clone the weights of lin1 and lin2 and call the functional form of Linear,
that is:

        y1 = torch.nn.functional.linear (x, self.lin1.weight.clone(), self.lin1.bias)
        y3 = torch.nn.functional.linear (y2, self.lin2.weight.clone(), self.lin2.bias)

rather than:

        y1 = self.lin1 (x)
        y3 = self.lin2 (y2)

*) See the following post for some explanation of when inplace modifications
do and don’t lead to inplace-modification errors:

Best.

K. Frank


In the name of Allah

Hi, this explanation is very good, and using “.clone()” solved my problem.
Thank you.