One of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 2, 28, 2]], which is output 0 of SliceBackward, is at version 14; expected version 13 instead

Hello everyone, I am just a beginner in PyTorch. Recently I have been trying to adjust an STN so it works on fisheye images, but while debugging I am stuck on the error “one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 2, 28, 2]], which is output 0 of SliceBackward, is at version 14; expected version 13 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient.” I tried to fix it with some suggestions from here, but it still doesn’t work. Could someone point out the problem in my code for me? Thanks 🙂

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 14 per-strip transforms (2 parameters each)
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 2*14),
            # nn.ReLU(True),
            # nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[0].weight.data.zero_()
        self.fc_loc[0].bias.data.copy_(torch.tensor([1, 0] * 14, dtype=torch.float))  # 14 x (scale=1, shift=0)

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 14, 2)
        # print('theta size:',theta.size())
        output_2 = torch.zeros(x.size())
        print(type(output_2))

        #grid = torch.zeros(14, 1, 1, 2, 28,2)
        grid = torch.zeros(14, 1, 2, 28, 2)

        beta = torch.zeros(14,2,3)
        for i in range(14):
          beta[i,0,0] = theta[0,i,0]
          beta[i,0,2] = theta[0,i,1]

        for i in range(14):
          # print('grid_size:', grid[i,:,:,:,:].size())
          # slice assignments that write into grid and output_2 in place:
          grid[i,:,:,:,:] = F.affine_grid(beta[i,:,:].unsqueeze(0), [1, 1, 2, 28]).clone()
          output_2[:,:,2*i:2*i+2,:] = F.grid_sample(x[:,:,2*i:2*i+2,:], grid[i,:,:,:,:]).clone()

        return output_2

    def forward(self, x_temp):
        # transform the input
        x = self.stn(x_temp)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

Hi,

The error says that you are modifying a Tensor in place after slicing it, i.e. an op like your_tensor[ind] = xxx.
It fails because the original value of that Tensor is needed for the backward computation, so the in-place write is not valid.
You should replace it with an out-of-place operation.
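 
Here is a minimal sketch of the failure (made-up tensors, just for illustration): exp saves its output for the backward pass, so writing into a slice of that output invalidates it.

import torch

a = torch.rand(4, requires_grad=True)
b = a.exp()         # autograd saves b: the backward of exp is grad_out * b
b[0] = 0.0          # in-place write through a slice bumps b's version counter
b.sum().backward()  # RuntimeError: ... modified by an inplace operation

# Out-of-place fix: build a new Tensor instead of writing into b
b = a.exp()
b = torch.cat([torch.zeros(1), b[1:]])
b.sum().backward()  # works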

In particular, in your last for-loop you don’t need to do grid[i,:,:,:,:] = xxx (an in-place op); you can simply store the result in a temporary Tensor, grid_elem = xxx, and pass it to the next line.
If that does not fix the issue: for the grid_sample results that fill in output_2, I usually store them all in a list and then torch.cat() them to get the final result.
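 
As a concrete sketch of that list-and-cat pattern (reusing x, beta, and the 14 two-row strips from the stn above, so treat it as an illustration rather than tested code):

outputs = []
for i in range(14):
    # build each grid out of place; nothing is written into a shared buffer
    grid_elem = F.affine_grid(beta[i, :, :].unsqueeze(0), [1, 1, 2, 28])
    outputs.append(F.grid_sample(x[:, :, 2*i:2*i+2, :], grid_elem))
# each result is (1, 1, 2, 28); concatenate the strips along the height dim
output_2 = torch.cat(outputs, dim=2)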

Nice! I tried your first suggestion and changed my code to “output_2[:,:,2*i:2*i+2,:] = F.grid_sample(x[:,:,2*i:2*i+2,:], F.affine_grid(beta[i,:,:].unsqueeze(0), [1, 1, 2, 28]))”, which avoids the in-place writes into grid, and now I can continue my work. Thanks a lot!
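 
Spelled out, the loop in stn now reads like this (the writes into the freshly created output_2 are fine, since grid_sample’s backward does not need output_2’s old values):

for i in range(14):
    # affine_grid's result goes straight into grid_sample, so no grid
    # buffer is written in place and re-read on a later iteration
    output_2[:, :, 2*i:2*i+2, :] = F.grid_sample(
        x[:, :, 2*i:2*i+2, :],
        F.affine_grid(beta[i, :, :].unsqueeze(0), [1, 1, 2, 28]))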
