Encountering the RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I am defining my own layer. However, I encounter the RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation when running backward().

I found that if I comment out the second for loop (‘for j in range(self.number_person):’), or change the update to ‘u_i[:,j,:] = (1 - self.lumbda)*u_i[:,j,:]’, then backward() works fine.

I wonder where the in-place operation is and why it does not work. ‘p_rnn_feature’ and ‘u_sum’ have been computed before.

BTW, this code runs on PyTorch 0.19.7ad948f.

class myNet(nn.Module):
    def __init__(self):
        # do some init
        ...

    def forward(self, *inputs):  # actual forward arguments omitted
        # compute p_rnn_feature, u_sum (the other variables used below are set up here as well)
        p_rnn_feature = Variable(torch.ones(p_rnn_feature.size())).cuda()
        u_sum = Variable(torch.ones(u_sum.size())).cuda()
        for i in range(self.embedding_length):
            u_s = u_s.clone()
            u_i = u_i.clone()

            for j in range(self.number_person):
                alpha_i = Variable(torch.zeros(batch_size, self.number_person, 1)).cuda()
                comp_mask = Variable(j*torch.ones(valid_person_num.size())).cuda()
                comp_mask = torch.lt(comp_mask, valid_person_num)  # (batch_size, 1)
                comp_mask_ui = comp_mask.repeat(1, self.hyper_d)
                tmp_x = torch.cat((p_rnn_feature[:,j,:], u_sum[:,j,:], u_s), 1)  # size: (batch_size, 2*rnn_cell_size+hyper_d)

                u_i[:,j,:] = (1 - self.lumbda)*u_i[:,j,:] + self.lumbda*F.relu(self.u_i_linear(tmp_x))
                u_i[:,j,:] = u_i[:,j,:]*comp_mask_ui.float()

                alpha_i[:,j,:] = F.tanh(self.alpha_i_linear(torch.cat((u_i[:,j,:], u_s), 1)))
                alpha_i[:,j,:] = alpha_i[:,j,:]*comp_mask.float()

            alpha_sum = torch.sum(alpha_i, 1)
            alpha_sum = alpha_sum.repeat(1, self.number_person, 1)

            gate = alpha_i / Variable(torch.max(alpha_sum.data, torch.ones(alpha_sum.size()).cuda())).cuda()
            gate = gate.repeat(1, 1, self.hyper_d)

            gated_ui_sum = gate*u_i
            gated_ui_sum = torch.sum(gated_ui_sum, 1)
            gated_ui_sum = torch.squeeze(gated_ui_sum, dim=1)

            tmp_s = torch.cat((u_s, p_feature_sum, gated_ui_sum), 1)  # size: (batch_size, hyper_d+rnn_cell_size+hyper_d)
            u_s = (1 - self.lumbda) * u_s + self.lumbda * F.relu(self.u_s_linear(tmp_s))

        pred_tmp = torch.cat((torch.squeeze(torch.sum(u_i, 1), dim=1), u_s), 1)
        pred = self.pred_dropout(self.pred_linear(pred_tmp))
        pred = self.pred_linear_second(pred)
        return pred

Assignments to Variables are in-place operations, and you’re doing a lot of them (u_i[:,j,:]). You’re using that Variable in lots of other contexts, and some of those functions require it not to change. This might help (I added some calls to clone):

u_i[:,j,:] = (1 - self.lumbda)*u_i[:,j,:].clone() + self.lumbda*F.relu(self.u_i_linear(tmp_x))
u_i[:,j,:] = u_i[:,j,:].clone()*comp_mask_ui.float()

alpha_i[:,j,:] = F.tanh(self.alpha_i_linear(torch.cat((u_i[:,j,:], u_s),1)))
alpha_i[:,j,:] = alpha_i[:,j,:].clone()*comp_mask.float()


Thanks, it’s working. But what do you mean by ‘Assignments to Variables are in-place operations’? Is something like x = x + 1 an in-place operation? Or is it only because I am using indexing into a matrix?

x = x + 1 is not in-place: it takes the object pointed to by x, creates a new Variable, adds 1 and puts the result in the new Variable, and then rebinds the name x to point to that new Variable. There are no in-place modifications; you only change Python references (you can check that id(x) is different before and after that line).

On the other hand, doing x += 1 or x[0] = 1 will modify the data of the Variable in-place, so no copy is made. However, some functions (in your case *) require their inputs to never change after they have computed the output, or they wouldn’t be able to compute the gradient. That’s why an error is raised.
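
A minimal sketch of the difference (using plain tensors rather than the Variable wrapper above; the names are just for illustration):

import torch

# Rebinding a name is NOT in-place: a new tensor is created and the name moves.
x = torch.ones(3)
print(id(x))
x = x + 1           # builds a new tensor; the old one is untouched
print(id(x))        # different id

# An in-place write into a tensor that another op saved for backward fails.
b = torch.ones(3, requires_grad=True)
c = torch.ones(3, requires_grad=True) * 2   # non-leaf, so in-place writes on it are allowed
y = b * c           # '*' saves b and c, since dy/db = c and dy/dc = b
c[0] = 10           # in-place modification of a saved tensor
y.sum().backward()  # RuntimeError: ... modified by an inplace operation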


Thanks! That’s a great explanation!

Nice explanation!!
I encountered the same problem and solved it thanks to this explanation!
Thanks

@apaszke I am facing a similar problem in my code. Is there any way I can find which Variable is causing the problem because of in-place operations? I am kind of stuck at this point; any help would be appreciated.


I think you can just try to clone the variable before you use it.

I’m getting the same error in a different scenario. I am running LSTM code, and it works completely fine when I use MSELoss as the criterion, but gives this error when I change the criterion to CrossEntropyLoss (of course I am feeding the appropriate type of inputs to the criterion). I get the error when I call loss.backward(). Strangely, the code runs perfectly fine when I call loss.backward() at every time step inside the time loop instead of calling it once after the entire sequence has been processed.

Is there at least a way to get a pointer to the variable that is causing this trouble, or some other way to reason about the possible error?

Thank you in advance for your help.


Hi,

One thing you can do is check the line numbers in the traceback. If you look up the class that the line belongs to in torch/autograd/_functions/*.py, you can tell which operation it bails out on, which narrows it down.
I wonder whether it might be worth adding a “debug” mode that records the stack of the op in the forward pass and prints it on error in the backward. That way, it would point to the offending line of code directly.

Best regards

Thomas


x11[:,:,0:int(w_f/2),0:int(h_f/2)] = x11[:,:,0:int(w_f/2),0:int(h_f/2)]*xx1[0]

Suppose x11 is an autograd Variable. My issue is that when I run this code, it says:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

but if I do x11 = x11*xx1[0], there is no error and things work correctly.
I tried:

x11[:,:,0:int(w_f/2),0:int(h_f/2)] = x11[:,:,0:int(w_f/2),0:int(h_f/2)].clone()*xx1[0]

and it is still not working. @apaszke
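
(For anyone else hitting this pattern: cloning the right-hand side does not help, because the error comes from writing into x11 itself while some earlier or same-line operation still needs its old values. A fully out-of-place sketch, assuming x11 has four dimensions as in the snippet and xx1[0] broadcasts against it:)

mask = torch.zeros_like(x11)
mask[:, :, 0:int(w_f/2), 0:int(h_f/2)] = 1    # mask is a fresh constant tensor, so this write is harmless
x11 = x11*(1 - mask) + x11*mask*xx1[0]        # builds a new tensor; the old x11 stays intact for backward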

For PyTorch 0.4.0, .clone() does not work. Do you have any idea?

theta[:, 0] = 1 - theta[:, 0].clone()
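
(A note for later readers: cloning the right-hand side cannot help here either, since it is the write into theta that invalidates what autograd saved. One out-of-place alternative is to rebuild the tensor instead of assigning into a column — a sketch, assuming theta is 2-D:)

theta = torch.cat([1 - theta[:, 0:1], theta[:, 1:]], dim=1)   # new tensor, no in-place write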

Same problem. Have you found the solution?

Hi @YongyiTang92 @apaszke, I tried clone in the following code, but it still gives me the same error as @YongyiTang92’s. Thanks for your help.

        phi = x_out[:, :, 0]  # [B, C]
        phi.clone()           # note: the result of this clone is never assigned to anything
        Batch = phi.shape[0]

        for i in range(Batch):
            phi[i] = phi_constant * phi[i].clone() + phi_offset

Actually, using a list to append the values in the for loop solved this problem.
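
(A sketch of what that list-based version could look like, using the same assumed names as the snippet above; torch.stack then rebuilds the tensor out-of-place:)

phi = x_out[:, :, 0]  # [B, C]
Batch = phi.shape[0]

# Collect per-sample results in a Python list instead of writing into phi in-place.
phi_rows = [phi_constant * phi[i] + phi_offset for i in range(Batch)]
phi = torch.stack(phi_rows, dim=0)   # fresh tensor; nothing autograd saved gets modified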

Thank you for the explanation. It’s just quite weird that x += 1 is in-place while x = x + 1 is not.
That gives the two operations different semantics, while I’m used to considering them equivalent in Python.


For anyone else brought here by Google, you can use set_detect_anomaly to find the location of the in-place operation. See the issue I created on the topic.
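
(For reference, a minimal way to turn it on; with anomaly detection enabled, the backward error also prints the traceback of the forward operation whose result was later modified. ‘model’ and ‘inputs’ below are placeholders:)

import torch

torch.autograd.set_detect_anomaly(True)   # global switch; noticeably slows things down

# or restrict the (slow) checks to one region:
with torch.autograd.detect_anomaly():
    out = model(inputs)
    out.sum().backward()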


Hi apaszke,

I encountered this error in a CUDA extension.

import torch
import torch.nn as nn
from torch.autograd import Function
import sys, os
sys.path.insert(0, os.path.abspath("build/lib.linux-x86_64-3.6/"))
import reprelu_cuda, reprelu_cpp

class RepReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, pos_weight, neg_weight):

        pos_weight = pos_weight.contiguous()
        neg_weight = neg_weight.contiguous()

        output = reprelu_cuda.forward(input, pos_weight, neg_weight)[0]
        
        ctx.save_for_backward(input, pos_weight, neg_weight)

        return output

    @staticmethod
    def backward(ctx, grad_output):

        outputs = reprelu_cuda.backward(grad_output.contiguous(), *ctx.saved_variables)

        d_input, d_pos_weight, d_neg_weight = outputs

        return d_input, d_pos_weight, d_neg_weight

class RepReLU(nn.Module):

    def __init__(self, planes, neg_slope=0.25):
        
        super(RepReLU, self).__init__()


    
    def forward(self, x):

        w_pos = torch.nn.several_computations(x)   # placeholder for the real computation
        w_neg = torch.nn.other_computations(x)     # placeholder for the real computation

        return RepReLUFunction.apply(x, w_pos, w_neg)

Here, reprelu is a CUDA extension that strictly follows this tutorial.

When running the following test code, the error occurs:

data = torch.zeros(1, 2, 2, 2).cuda()
reprelu = RepReLU(2).cuda()

loss = reprelu(data).sum()
loss.backward()

The error message told me that

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

occurs in

outputs = reprelu_cuda.backward(grad_output.contiguous(), *ctx.saved_variables)

However, reprelu_cuda.backward is implemented in CUDA and wrapped in C++, and I didn’t make any changes to the inputs in the C++ functions or the CUDA kernel functions.

Hello, I am experiencing a similar issue; however, I do not fully understand the behavior, so I would be grateful for some insight.
When I run the code with these two lines, it causes the in-place operation error mentioned above:

attention = nn.functional.softmax(torch.matmul(key.transpose(1, 2), query) / math.sqrt(num_of_channels), dim=1)
attention[:, 0, 0] = torch.ones(attention.shape[0], requires_grad = True)

Error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

However, if I run the code like this:

attention = nn.functional.softmax(torch.matmul(key.transpose(1, 2), query), dim=1) / math.sqrt(num_of_channels)
attention[:, 0, 0] = torch.ones(attention.shape[0], requires_grad = True)

it works perfectly fine.

What is the reason for this behavior? The resulting tensor has requires_grad set to True in both cases; only the backward function differs. Actually, the funny thing is that I can just use a dummy division by 1.0 to avoid the error, but it does not seem elegant at all:

 attention = nn.functional.softmax(torch.matmul(key.transpose(1, 2), query) / math.sqrt(num_of_channels), dim=1)/1.0
attention[:, 0, 0] = torch.ones(attention.shape[0], requires_grad = True)
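
(A note for later readers, since this fits the pattern above: softmax saves its own output to compute its gradient, so overwriting entries of that output in-place invalidates it, whereas the extra division creates a new tensor and the write no longer touches what softmax saved. An explicit workaround, sketched with the same names as in the question:)

attention = nn.functional.softmax(torch.matmul(key.transpose(1, 2), query) / math.sqrt(num_of_channels), dim=1)
attention = attention.clone()   # fresh tensor; softmax's saved output stays intact
attention[:, 0, 0] = torch.ones(attention.shape[0], requires_grad=True)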

Super useful. It works for me. Thank you so much.