RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:

So I have the following code:

```python
import torch
import torch.nn as nn
import math

torch.autograd.set_detect_anomaly(True)

def sigmoid(x):
    return 1 / (1 + torch.exp(-x))

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on the GPU

# N = batch size
# D_in = input size
# H = hidden dimension
# D_out = output dimension
Time, N, D_in, H, D_out = 7, 3, 10, 5, 2

# Create random Tensors to hold input and outputs.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Tensors during the backward pass.
X = torch.randn(Time, N, D_in, device=device, dtype=dtype, requires_grad=False)
Y = torch.randn(Time, D_out, D_in, device=device, dtype=dtype, requires_grad=False)
Cell = torch.zeros(Time, H, D_in, device=device, dtype=dtype, requires_grad=False)
Hidden = torch.zeros(Time, H, D_in, device=device, dtype=dtype, requires_grad=False)
yHat = torch.zeros(Time, D_out, D_in, device=device, dtype=dtype, requires_grad=False)

# Create random Tensors for weights.
# Setting requires_grad=True indicates that we want to compute gradients with 
# respect to these Tensors during the backward pass.
WeightsForget = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
WeightsUpdate = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
WeightsCell = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
WeightsOutput = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)

# Create random Tensors for Bias.
BiasForget = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
BiasUpdate = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
BiasCell = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
BiasOutput = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)

# Weights and Bias relating the hidden state to the output.
Weights_HiddenToOutput = torch.randn(D_out, H, device=device, dtype=dtype, requires_grad=True)
Bias_HiddenToOutput = torch.randn(D_out, 1, device=device, dtype=dtype, requires_grad=True)

# Initialize hLast and cLast
hLast = Hidden[0,:,:]
cLast = torch.zeros(hLast.shape)

learning_rate = 1e-6
for t in range(Time):

    # Concat hLast and X(t)
    concat = torch.zeros(H+N, D_in, device=device, dtype=dtype)
    concat[: H, :] = hLast
    concat[H :, :] = X[t,:,:] 

    forget = sigmoid(torch.mm(WeightsForget, concat) + BiasForget) 
    update = sigmoid(torch.mm(WeightsUpdate, concat) + BiasUpdate)
    candCell = torch.tanh(torch.mm(WeightsCell, concat) + BiasCell)
    cNext = forget * cLast + update * candCell
    output = sigmoid(torch.mm(WeightsOutput, concat) + BiasOutput)
    hNext = output * torch.tanh(cNext)

    # Compute prediction of the LSTM Cell
    prediction = torch.mm(Weights_HiddenToOutput, hNext) + Bias_HiddenToOutput

    # Save the value of the next hidden state
    Hidden[t,:,:] = hNext

    # Save the value of the Prediction in yHat
    yHat[t,:,:] = prediction

    # Save the value of the next Cell State
    Cell[t,:,:] = cNext

    # Compute and print loss using operations on Tensors.
    # Now loss is a scalar (0-dim) Tensor
    loss = (-Y[t,:,:] * torch.log(prediction) - (1-Y[t,:,:]) * torch.log(1-prediction)).sum()
    print('t, loss.item = ', t, loss.item()) 
    loss.backward(retain_graph=True)

    with torch.no_grad():
        # Update hLast and cLast
        hLast = hNext
        cLast = cNext
        
        # Update Weights
        WeightsCell -= learning_rate * WeightsCell.grad
        WeightsOutput -= learning_rate * WeightsOutput.grad
        WeightsUpdate -= learning_rate * WeightsUpdate.grad
        Weights_HiddenToOutput -= learning_rate * Weights_HiddenToOutput.grad

        # Update Bias
        BiasCell -= learning_rate * BiasCell.grad
        BiasForget -= learning_rate * BiasForget.grad
        BiasOutput -= learning_rate * BiasOutput.grad
        BiasUpdate -= learning_rate * BiasUpdate.grad
        Bias_HiddenToOutput -= learning_rate * Bias_HiddenToOutput.grad
        
        # Zero out the gradients
        WeightsCell.grad.zero_()
        WeightsForget.grad.zero_()
        WeightsOutput.grad.zero_()
        WeightsUpdate.grad.zero_()
        Weights_HiddenToOutput.grad.zero_()
        BiasCell.grad.zero_()
        BiasForget.grad.zero_()
        BiasOutput.grad.zero_()
        BiasUpdate.grad.zero_()
        print(Bias_HiddenToOutput.dtype)
        Bias_HiddenToOutput.grad.zero_()
```

When I run it I get:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5, 8]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Commenting out the updates of the weights and biases lets the code run, but then the weights and biases are never updated. I'm new to PyTorch, and I understand that this error has something to do with the in-place operation incrementing the version of the Tensor during the loss.backward() call. But I thought the `with torch.no_grad():` block was supposed to let me update the Tensors without incrementing their version.

So I guess I have two questions. First, how do I get my code running? Second, what am I missing about `with torch.no_grad()`? It doesn't work the way I thought.

Thanks!

JP

Hi,

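`torch.no_grad()` hides these ops from the autograd engine: they are not recorded in the graph, so they are not taken into account when you call `.backward()`. So this is the right usage here.
The error you're seeing is a different issue. When the forward pass needs the value of some Tensor to compute the backward pass, that Tensor is saved. If you then overwrite the content of a saved Tensor, the backward pass can no longer be computed. (Whether you overwrite it in a differentiable manner or not does not matter here: the value that backward needs no longer exists.) Autograd detects this with a per-Tensor version counter that every in-place op increments; at backward time the version recorded when the Tensor was saved is compared with the current one, which is what "is at version 1; expected version 0" in your error means.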
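A minimal sketch of that check, assuming a recent PyTorch release (the tensors `w`, `x` and the helper function are hypothetical, chosen for illustration; `_version` is the internal attribute already printed later in this thread). `w * x` saves both inputs for backward, so updating `w` in place under `torch.no_grad()` still bumps its version, and backward refuses to run:

```python
import torch


def modify_saved_tensor_then_backward():
    """Return (version_before, version_after, error_message_or_None)."""
    w = torch.randn(3, requires_grad=True)
    x = torch.randn(3, requires_grad=True)

    # mul saves w (needed for the gradient w.r.t. x) and x (needed for the gradient w.r.t. w).
    y = (w * x).sum()

    version_before = w._version
    with torch.no_grad():
        w -= 0.1  # not recorded in the graph, but still increments w._version
    version_after = w._version

    try:
        y.backward()
        return version_before, version_after, None
    except RuntimeError as err:
        return version_before, version_after, str(err)


before, after, msg = modify_saved_tensor_then_backward()
print(before, after)
print(msg)
```

The message printed is the same "modified by an inplace operation" error as in the original post, here for a `[3]` tensor.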

I think the issue here is that you're missing a `.detach()` when you do `hLast = hNext`.
Unfortunately, because no operation is performed there, this does not interact nicely with `torch.no_grad()`.
In Python, this line only binds the same Python object to a new name. No PyTorch op is called, so we cannot detect it automatically, and `hLast` keeps pointing at a Tensor that is still attached to the previous step's graph.
You will need to call `.detach()` explicitly here, or run any op such as `hLast = hNext.view_as(hNext)` or `hLast = hNext.clone()`. Because those ops run inside the `torch.no_grad()` block, their results carry no autograd history.
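To see why the plain assignment fails while `.detach()` (or an op such as `.clone()` inside `torch.no_grad()`) works, here is a toy recurrence that mirrors the structure of the original loop: one `backward(retain_graph=True)` per time step, followed by an in-place weight update. The 4-unit `tanh` cell, the `train` and `outcome` helpers, and the `carry` argument are hypothetical simplifications, not the code from the post:

```python
import torch


def train(carry, steps=3, lr=1e-3):
    """Toy recurrence: one backward plus an in-place SGD update per time step.

    carry selects how the hidden state is passed to the next step:
      "assign" -> h_last = h_next           (same object, still attached to the old graph)
      "clone"  -> h_last = h_next.clone()   (an op run under no_grad, so no history)
      "detach" -> h_last = h_next.detach()  (explicitly cut from the graph)
    """
    torch.manual_seed(0)
    W = torch.randn(4, 4, requires_grad=True)
    h_last = torch.zeros(4, 1)

    for _ in range(steps):
        h_next = torch.tanh(W @ h_last + 1.0)
        loss = h_next.pow(2).sum()
        loss.backward(retain_graph=True)  # mirrors the original post

        with torch.no_grad():
            W -= lr * W.grad  # in-place update increments W._version
            W.grad.zero_()
            if carry == "assign":
                h_last = h_next
            elif carry == "clone":
                h_last = h_next.clone()
            elif carry == "detach":
                h_last = h_next.detach()
            else:
                raise ValueError(f"unknown carry mode: {carry!r}")
    return h_last


def outcome(carry):
    """Return "ok" if training runs, otherwise the RuntimeError message."""
    try:
        train(carry)
        return "ok"
    except RuntimeError as err:
        return str(err)


for mode in ("assign", "clone", "detach"):
    print(mode, "->", outcome(mode)[:80])
```

Only `"assign"` fails: the next step's backward walks back into the previous step's graph, which saved `W` before the in-place update changed it.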

Thanks!! That was correct for that one. But I still don't understand how `with torch.no_grad()` works. In the code below, the update `Weights_HiddenToOutput -= learning_rate * Weights_HiddenToOutput.grad` is an in-place operation under `no_grad()`, and it increments `Weights_HiddenToOutput._version` from 0 to 1. I thought it wasn't supposed to do that under `no_grad()`. Why does that happen??

(Sorry about the strange code formatting. I’m new at this and I can’t figure out how to put all the code in one block quote.)

```python
import torch
torch.autograd.set_detect_anomaly(True)
def sigmoid(x):
    return 1/(1+torch.exp(-x))

dtype = torch.float
device = torch.device("cpu")
Time, N, D_in, H, D_out = 7, 3, 10, 5, 2
X = torch.randn(Time, N, D_in, device=device, dtype=dtype, requires_grad=False)
Y = torch.randn(Time, D_out, D_in, device=device, dtype=dtype, requires_grad=False)
yHat = torch.randn(Time, D_out, D_in, device=device, dtype=dtype, requires_grad=False)
Hidden = 5
WeightsForget = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
WeightsUpdate = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
WeightsCell = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
WeightsOutput = torch.randn(H, H+N, device=device, dtype=dtype, requires_grad=True)
BiasForget = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
BiasUpdate = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
BiasCell = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
BiasOutput = torch.randn(H, 1, device=device, dtype=dtype, requires_grad=True)
Weights_HiddenToOutput = torch.randn(D_out, H, device=device, dtype=dtype, requires_grad=True)
Bias_HiddenToOutput = torch.randn(D_out, 1, device=device, dtype=dtype, requires_grad=True)
Hidden = torch.zeros(Time, H, D_in, device=device, dtype=dtype, requires_grad=False)
Cell = torch.zeros(Time, H, D_in, device=device, dtype=dtype, requires_grad=False)
hLast = torch.zeros(H, D_in, device=device, dtype=dtype, requires_grad=False)
Hidden[-1] = hLast
cLast = torch.zeros(hLast.shape)
Cell[-1] = cLast
learning_rate = 1e-6

for e in range(100):
    for t in range(len(X)):
        forget = sigmoid(torch.mm(WeightsForget, torch.cat((X[t,:,:],hLast),0)) + BiasForget)
        update = sigmoid(torch.mm(WeightsUpdate, torch.cat((X[t,:,:],hLast),0)) + BiasUpdate)
        candCell = torch.tanh(torch.mm(WeightsCell, torch.cat((X[t,:,:],hLast),0)) + BiasCell)
        cNext = forget * cLast + update * candCell
        output = sigmoid(torch.mm(WeightsOutput, torch.cat((X[t,:,:],hLast),0)) + BiasOutput)
        hNext = output * torch.tanh(cNext)

        prediction = torch.softmax((torch.mm(Weights_HiddenToOutput, hNext) + Bias_HiddenToOutput),0)

        Hidden[t] = hNext
        hLast = hNext.clone()
        Cell[t] = cNext
        cLast = cNext.clone()
        yHat[t] = prediction

    loss = (yHat - Y).pow(2).sum()
    loss.backward(retain_graph=True)

    with torch.no_grad():
        # Update Weights
        WeightsCell.data -= learning_rate * WeightsCell.grad
        WeightsOutput.data -= learning_rate * WeightsOutput.grad
        print('WeightsUpdate._version = ', WeightsUpdate._version)
        WeightsUpdate.data -= learning_rate * WeightsUpdate.grad
        print('WeightsUpdate._version = ', WeightsUpdate._version)
        WeightsForget.data -= learning_rate * WeightsForget.grad
        print('Weights_HiddenToOutput.shape = ', Weights_HiddenToOutput.shape)
        print('Weights_HiddenToOutput._version = ', Weights_HiddenToOutput._version)
        Weights_HiddenToOutput -= learning_rate * Weights_HiddenToOutput.grad
        print('Weights_HiddenToOutput._version = ', Weights_HiddenToOutput._version)
        
        # Update Bias
        BiasCell -= learning_rate * BiasCell.grad
        BiasForget -= learning_rate * BiasForget.grad
        BiasOutput -= learning_rate * BiasOutput.grad
        BiasUpdate -= learning_rate * BiasUpdate.grad
        Bias_HiddenToOutput -= learning_rate * Bias_HiddenToOutput.grad
        
        # Zero out the gradients
        WeightsCell.grad.zero_()
        WeightsForget.grad.zero_()
        WeightsOutput.grad.zero_()
        WeightsUpdate.grad.zero_()
        Weights_HiddenToOutput.grad.zero_()
        BiasCell.grad.zero_()
        BiasForget.grad.zero_()
        BiasOutput.grad.zero_()
        BiasUpdate.grad.zero_()
        Bias_HiddenToOutput.grad.zero_()
```

Hi,

For code formatting, you can put triple backticks ``` on their own line before and after the code. I edited your post; you can open it in the editor again to see what the raw text looks like.

> I thought it wasn't supposed to do that under the no_grad() function. Why does that happen??

It depends on what else you're doing. If you don't run any op that saves the value of that Tensor, it is fine to modify it in place (this is the case here).
If the original value is required (like in your original post), then the autograd engine will prevent you from doing that.

Note that `torch.no_grad()` is irrelevant here: in-place ops increment the version counter with or without it, and the check happens the same way in both cases (even though the recorded graph, and therefore the computed gradients, will of course differ).
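A small sketch of both points, assuming a recent PyTorch release (the helper names are hypothetical). The version counter increments whether or not `torch.no_grad()` is active, and whether backward complains depends only on whether some op saved the modified Tensor. Addition does not need its inputs to compute gradients, so it saves nothing; multiplication does save them:

```python
import torch


def version_bumps():
    """In-place ops increment _version with or without torch.no_grad()."""
    t = torch.zeros(3)       # plain tensor, so in-place ops are allowed outside no_grad
    t.add_(1.0)              # outside no_grad
    with torch.no_grad():
        t.add_(1.0)          # inside no_grad: still counted
    return t._version


def update_then_backward(op):
    """Build y = op(x, w), update w in place under no_grad, then call backward."""
    x = torch.randn(3, requires_grad=True)
    w = torch.randn(3, requires_grad=True)
    y = op(x, w).sum()
    with torch.no_grad():
        w -= 0.1
    try:
        y.backward()
        return "ok"
    except RuntimeError as err:
        return str(err)


print(version_bumps())
print(update_then_backward(torch.add))  # add saved nothing, so backward runs
print(update_then_backward(torch.mul))  # mul saved w, so the version check fails
```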

So if it's irrelevant here, why am I still getting this error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2, 5]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

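Because you modify in place a Tensor that is needed for backward.
You should use anomaly mode to get more information: with `torch.autograd.set_detect_anomaly(True)` enabled (as it already is at the top of your script), the error is preceded by a traceback of the forward op whose backward failed. Then remove the corresponding in-place op (or add a `.clone()` so that the saved Tensor itself is not modified).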
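A sketch of the `.clone()` remedy, assuming a recent PyTorch release. It uses `sigmoid` because that op saves its own *output* for backward; the helper name and the scale factor are hypothetical. `torch.autograd.detect_anomaly()` is the context-manager form of anomaly mode and emits the forward traceback as a warning when backward fails:

```python
import torch


def scale_sigmoid_output(use_clone):
    """Scale sigmoid(a) in place, then backprop; returns a.grad or the error message."""
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    with torch.autograd.detect_anomaly():
        out = a.sigmoid()          # sigmoid saves `out` itself for its backward
        if use_clone:
            scaled = out.clone()   # work on a copy; the saved tensor is left untouched
        else:
            scaled = out           # the very tensor autograd saved
        scaled.mul_(2.0)           # in-place op
        try:
            scaled.sum().backward()
        except RuntimeError as err:
            return str(err)
    return a.grad


print(scale_sigmoid_output(use_clone=False)[:80])
print(scale_sigmoid_output(use_clone=True))
```

With the clone, the gradient is `2 * s * (1 - s)` for `s = sigmoid(a)`, as expected; without it, backward raises the same "modified by an inplace operation" error.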

Never mind, I just figured out what my problem is. The update of Weights_HiddenToOutput should be:

```python
Weights_HiddenToOutput.data -= learning_rate * Weights_HiddenToOutput.grad
```

Not:

```python
Weights_HiddenToOutput -= learning_rate * Weights_HiddenToOutput.grad
```

Thanks for your help though.
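One caveat about the `.data` version, sketched below under the assumption of a recent PyTorch release (helper names are hypothetical). In-place updates through `.data` are not recorded by the original Tensor's version counter, so they skip the check rather than satisfy it: if a saved Tensor is changed this way, backward runs but silently uses the new values. In the second script, `hLast = hNext.clone()` runs outside `torch.no_grad()`, so the hidden state carried into the next epoch is still attached to graphs recorded with the old weights; detaching `hLast` between epochs may be the cleaner fix.

```python
import torch


def data_update_version():
    """In-place updates through .data leave the original tensor's _version unchanged."""
    w = torch.ones(2, requires_grad=True)
    w.data -= 0.5              # same pattern as the fix above: invisible to w's version counter
    v_data = w._version
    with torch.no_grad():
        w -= 0.5               # plain in-place update: increments w._version
    return v_data, w._version


def silent_wrong_gradient():
    """Zeroing a saved tensor through .data: backward runs but uses the zeroed values."""
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()          # sigmoid saves `out`; its backward is grad * out * (1 - out)
    out.data.zero_()           # bypasses the version check
    out.sum().backward()       # no error, but the gradient is computed from the zeroed output
    return a.grad


print(data_update_version())
print(silent_wrong_gradient())
```

Doing the same `zero_()` through `out.detach()` instead of `out.data` would raise the usual in-place error at backward, because `.detach()` shares the version counter with the original Tensor.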
