One of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768, 64]] is at version 3; expected version 2 instead

Hi! I’m pretty new to PyTorch and I’ve come across this error in my network:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768, 64]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I have already read several forum threads on this topic. They all seem to say that there is an in-place operation somewhere in the code that is breaking the backward pass, but I can’t find any in-place operations in mine.
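For reference, the pattern those threads describe can be reproduced in a few lines. This is only a minimal sketch that is unrelated to the network in this post: the function name backward_after_inplace_edit is made up for illustration, and it runs on the CPU. exp() keeps its output for the backward pass, so editing that output in place afterwards makes backward() fail with the same kind of message.

```python
import torch

def backward_after_inplace_edit():
    """Return the RuntimeError message raised by backward(), or None."""
    w = torch.ones(3, requires_grad=True)
    y = w.exp()        # exp() saves its output y for the backward pass
    loss = y.sum()
    y.add_(1)          # in-place edit: y no longer matches the saved version
    try:
        loss.backward()
    except RuntimeError as err:
        return str(err)
    return None

message = backward_after_inplace_edit()
print(message)
```

Replacing y.add_(1) with the out-of-place y = y + 1 leaves the saved tensor untouched, and backward() then runs without complaint.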

Here is my learn function where I call backward:

#updating the q-network. IE, learning from past experiences
def learn(batch, optim, optimGru,optimLilGru, net, target_net, gamma, global_step, target_update):
    #separate batch
    s,a,s_,r,d,a_ = batch #in addition to the batch I will need to return the next action for the next state
    #zero the grad
    optim.zero_grad()
    optimLilGru.zero_grad()
    optimGru.zero_grad()
    
    ##calculate loss
    x = torch.tensor(s_, dtype = torch.float).cuda()

    Q_hat_mean = target_net.forward(x).mean(axis=1)[0] #We don't want the max, we want something different. I picked mean() since it isn't the max and would likely give a middle-range result
    true = r + gamma * (1-d) * Q_hat_mean
    
    est2 = net(s)[range(len(a_)),a_] 

    loss = torch.mean((true-est2)**2) #found the MSE
    loss.backward()
    optim.step()
    optimGru.step()
    optimLilGru.step()
    torch.autograd.set_detect_anomaly(True)

    if global_step % target_update == 0:
        target_net.load_state_dict(net.state_dict())

    return loss.item()

I am using two GRUs and a DQN in this code. If it is helpful, I can post more code snippets, but I’m not sure which parts will be useful in solving the problem.

I’m looking for a way to resolve the error, but I am also very interested in understanding more about what this error means and how I can avoid it in future projects I work on. Any advice on either of those topics will be greatly appreciated.
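As background on the "is at version 3; expected version 2" part of the message: every tensor carries a version counter that is bumped by each in-place write, and autograd compares it with the version it recorded during the forward pass. The sketch below is an assumption-laden illustration, not a diagnosis of the learn function above. The parameter w and the two losses are hypothetical stand-ins, and _version is a semi-private attribute used here only for inspection. It shows that optimizer.step() itself counts as an in-place write, so a graph built before the step can no longer be backpropagated after it.

```python
import torch

w = torch.nn.Parameter(torch.ones(2, 2))  # hypothetical stand-in for a weight
opt = torch.optim.SGD([w], lr=0.1)

# Both graphs are built before any update and save w at its current version.
loss_a = (w * w).sum()
loss_b = (2 * w * w).sum()

version_before = w._version
loss_a.backward()
opt.step()                    # the update is written into w in place
version_after = w._version

try:
    loss_b.backward()         # this graph still expects the pre-step w
    stale_error = None
except RuntimeError as err:
    stale_error = str(err)

print(version_before, version_after)
print(stale_error)
```

If something like this were the cause, the usual remedies would be to run every backward() before any step(), or to recompute the forward pass after the parameters change.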

Thanks!