Why does an in-place operation on a Variable's data have no effect on backward?

Here is the test snippet (PyTorch v0.3.1):

import torch
import torch.nn as nn
from torch.autograd import Variable
m = nn.Conv2d(1,1,kernel_size=3, stride=1, padding=1, bias=False)
m.weight.data.fill_(1)
a = torch.ones(1,1,5,5)
input_var = Variable(a, requires_grad=True)
output = m(input_var)
m.zero_grad()
input_var.data.fill_(0)
output.sum().backward()
print(m.weight.grad.data, input_var.grad.data)

The output:

(0 ,0 ,.,.) = 
  16  20  16
  20  25  20
  16  20  16
[torch.FloatTensor of size 1x1x3x3]
 
(0 ,0 ,.,.) = 
  4  6  6  6  4
  6  9  9  9  6
  6  9  9  9  6
  6  9  9  9  6
  4  6  6  6  4
[torch.FloatTensor of size 1x1x5x5]

What puzzles me is that the in-place operation input_var.data.fill_(0) does not affect the gradient of the convolution weights. So what data is used during backward?
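To see which input the backward actually used, the gradients for this exact setup (one channel, 5x5 input, all-ones 3x3 weight, stride 1, padding 1, loss = output.sum()) can be worked out by hand. The plain-Python sketch below (no PyTorch; the helper conv2d_grads is hypothetical) suggests that the printed weight gradient is exactly what the original all-ones input produces, while a zeroed input would give all zeros. The input gradient comes out the same either way, because it depends only on the weight and the output gradient.

```python
def conv2d_grads(x, k=3, pad=1):
    """Gradients of sum(conv2d(x, w)) for a single-channel, stride-1
    cross-correlation whose k x k weight is all ones (hypothetical helper).

    Returns (grad_w, grad_x). grad_w is computed from x; grad_x uses only
    the all-ones weight and the all-ones output gradient, never x.
    """
    n = len(x)
    size = n + 2 * pad
    # Zero-padded copy of the input.
    padded = [[0] * size for _ in range(size)]
    for r in range(n):
        for c in range(n):
            padded[r + pad][c + pad] = x[r][c]
    out_n = size - k + 1  # output height/width for stride 1

    # dL/dw[p][q] = sum over output positions (i, j) of padded[i + p][j + q],
    # since dL/dout is 1 everywhere when L = out.sum().
    grad_w = [[sum(padded[i + p][j + q] for i in range(out_n) for j in range(out_n))
               for q in range(k)] for p in range(k)]

    # dL/dx[r][c] = number of (output position, weight tap) pairs reading x[r][c],
    # each multiplied by the weight value 1.
    grad_x = [[0] * n for _ in range(n)]
    for i in range(out_n):
        for j in range(out_n):
            for p in range(k):
                for q in range(k):
                    r, c = i + p - pad, j + q - pad
                    if 0 <= r < n and 0 <= c < n:
                        grad_x[r][c] += 1
    return grad_w, grad_x


ones = [[1] * 5 for _ in range(5)]
zeros = [[0] * 5 for _ in range(5)]
print(conv2d_grads(ones)[0])   # [[16, 20, 16], [20, 25, 20], [16, 20, 16]]
print(conv2d_grads(zeros)[0])  # [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
print(conv2d_grads(zeros)[1])  # the same 4/6/9 pattern as the printed input gradient
```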

As I understand it, computing the gradient of the conv weights requires the input. Also, an in-place operation on a leaf Variable (input_var.zero_()) is not permitted. This makes me wonder how autograd actually works here. Could anyone explain? Thanks a lot.

Update: when running on CUDA, changing the input data does affect the gradients.

When you do operations on Variables, PyTorch keeps track of the computation graph in order to be able to backpropagate.

In-place operations on a leaf Variable that requires grad are not allowed. That is why var.zero_() gives an error.

For a variable var, the underlying data is stored in a tensor that is accessible via var.data. If you operate on var.data, PyTorch does not record the operation in the computation graph.

I think what happened in your case is that the forward pass of nn.Conv2d saved some intermediate results for the backward pass, which is why the backward did not see the updated value of input_var.


Thanks. I agree that nn.Conv2d very likely works that way, though my example may not generalize to other autograd functions. Still, I hope someone can give a more definitive explanation.

Hi,

Yes, the point here is that the convolution does not save the input itself but an intermediate result: a copy of the input that has been unfolded (im2col) so that the convolution becomes a matrix multiplication (MM).
By contrast, if you compute a * b and try the same thing, you will see that changing a.data does affect the gradient of b, because the multiplication saves its inputs directly.
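The contrast between the two cases can be sketched as a toy model in plain Python. This is not PyTorch's implementation; the function names are hypothetical, and a 1-D convolution stands in for Conv2d. The conv forward saves copies of the input windows (an im2col-style buffer), so zeroing the input afterwards changes nothing; the multiplication saves references to its inputs, so the same in-place change silently alters the gradient of b.

```python
def conv1d_forward(x, w):
    """1-D stride-1 'valid' cross-correlation via an im2col-style buffer.

    The saved context holds copies of the input windows, not x itself.
    """
    k = len(w)
    columns = [list(x[i:i + k]) for i in range(len(x) - k + 1)]  # copies
    out = [sum(c * wi for c, wi in zip(col, w)) for col in columns]
    return out, {"columns": columns, "w": list(w), "n": len(x)}


def conv1d_backward(ctx, grad_out):
    columns, w = ctx["columns"], ctx["w"]
    # The weight gradient reads the saved columns, not the live input.
    grad_w = [sum(g * col[j] for g, col in zip(grad_out, columns)) for j in range(len(w))]
    # The input gradient needs only the weight and grad_out.
    grad_x = [0.0] * ctx["n"]
    for i, g in enumerate(grad_out):
        for j, wj in enumerate(w):
            grad_x[i + j] += g * wj
    return grad_w, grad_x


def mul_forward(a, b):
    """Elementwise product; the saved context keeps references to a and b."""
    out = [ai * bi for ai, bi in zip(a, b)]
    return out, {"a": a, "b": b}


def mul_backward(ctx, grad_out):
    a, b = ctx["a"], ctx["b"]
    grad_a = [g * bi for g, bi in zip(grad_out, b)]
    grad_b = [g * ai for g, ai in zip(grad_out, a)]
    return grad_a, grad_b


# Conv: zeroing the input after the forward pass leaves the weight gradient alone.
x, w = [1.0] * 5, [1.0] * 3
conv_out, conv_ctx = conv1d_forward(x, w)
x[:] = [0.0] * len(x)  # in-place, like x.data.fill_(0)
print(conv1d_backward(conv_ctx, [1.0] * len(conv_out)))
# ([3.0, 3.0, 3.0], [1.0, 2.0, 3.0, 2.0, 1.0])

# Mul: the same in-place change silently alters the gradient of b.
a, b = [1.0, 1.0], [2.0, 2.0]
mul_out, mul_ctx = mul_forward(a, b)
a[:] = [0.0, 0.0]
print(mul_backward(mul_ctx, [1.0, 1.0]))
# ([2.0, 2.0], [0.0, 0.0]); the correct grad_b for the original a is [1.0, 1.0]
```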

As a general rule, changing the .data of a Variable after it has been used leads to weird or wrong results, so it should not be done! Do you have a use case where you need it?
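Part of why .data writes are dangerous is that they bypass PyTorch's in-place check. Recent PyTorch releases keep a per-tensor version counter (visible as `tensor._version`) and raise a RuntimeError during backward if a tensor saved for backward was modified by a tracked in-place op; writes through `.data` do not bump that counter, so nothing is detected. The toy model below imitates that idea in plain Python. ToyTensor, SavedTensor, mul and mul_backward are hypothetical names, not PyTorch classes.

```python
class ToyTensor:
    """Hypothetical tensor with a version counter (not PyTorch's real class)."""

    def __init__(self, values):
        self._values = list(values)
        self.version = 0

    @property
    def data(self):
        # Raw access: writes through this list bypass the version counter.
        return self._values

    def fill_(self, value):
        # Tracked in-place op: overwrites the values and bumps the version.
        self._values[:] = [value] * len(self._values)
        self.version += 1
        return self


class SavedTensor:
    """Remembers a tensor's version at the moment it was saved for backward."""

    def __init__(self, tensor):
        self.tensor = tensor
        self.saved_version = tensor.version

    def unpack(self):
        if self.tensor.version != self.saved_version:
            raise RuntimeError("a tensor needed for backward was modified in place")
        return self.tensor.data


def mul(a, b):
    out = [x * y for x, y in zip(a.data, b.data)]
    return out, (SavedTensor(a), SavedTensor(b))


def mul_backward(saved, grad_out):
    a_vals, b_vals = saved[0].unpack(), saved[1].unpack()
    return ([g * y for g, y in zip(grad_out, b_vals)],
            [g * x for g, x in zip(grad_out, a_vals)])


b = ToyTensor([2.0, 2.0])

# Tracked in-place op: the stale saved tensor is detected at backward time.
a = ToyTensor([1.0, 1.0])
_, saved = mul(a, b)
a.fill_(0.0)
try:
    mul_backward(saved, [1.0, 1.0])
except RuntimeError as err:
    print("error:", err)

# Write through .data: nothing is detected and grad_b comes out wrong.
a = ToyTensor([1.0, 1.0])
_, saved = mul(a, b)
a.data[:] = [0.0, 0.0]
print(mul_backward(saved, [1.0, 1.0]))  # ([2.0, 2.0], [0.0, 0.0])
```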

Thanks, I understand now. Yes, I want to use it for a BN-ReLU-Conv sequence. Since the backward of the convolutional layer does not need the input data, I can store that input in shared memory to reduce memory usage. For the BN-ReLU backward, I just need to recompute its output first. By the way, does the backward of grouped convolution (groups>1) also not need the input?
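The recompute-in-backward plan described here can be sketched as a toy in plain Python. Assumptions: BN-ReLU-Conv is collapsed to a ReLU followed by a 1x1 "conv" with a single scalar weight, and forward, backward and shared_buffer are hypothetical names. The forward writes the ReLU output into a buffer that other layers may reuse and keeps only the input; the backward recomputes the ReLU output instead of reading the buffer. Later PyTorch releases ship `torch.utils.checkpoint`, which implements this same compute-for-memory trade-off for whole segments of a model.

```python
def relu(xs):
    return [max(0.0, v) for v in xs]


def forward(x, w, shared_buffer):
    """ReLU followed by a scalar-weight 1x1 'conv' (a toy BN-ReLU-Conv)."""
    shared_buffer[:] = relu(x)            # h lives in a buffer other layers reuse
    out = [w * h for h in shared_buffer]
    ctx = {"x": list(x), "w": w}          # keep the input (copied), not h
    return out, ctx


def backward(ctx, grad_out):
    x, w = ctx["x"], ctx["w"]
    h = relu(x)                           # recompute h instead of reading the buffer
    grad_w = sum(g * hi for g, hi in zip(grad_out, h))
    grad_x = [g * w * (1.0 if xi > 0 else 0.0) for g, xi in zip(grad_out, x)]
    return grad_w, grad_x


shared = [0.0] * 4
x = [-1.0, 2.0, 3.0, -4.0]
out, ctx = forward(x, 0.5, shared)
shared[:] = [9.0] * 4                     # another layer overwrites the shared buffer
grad_w, grad_x = backward(ctx, [1.0] * 4)
print(out, grad_w, grad_x)  # [0.0, 1.0, 1.5, 0.0] 5.0 [0.0, 0.5, 0.5, 0.0]
```

Overwriting the shared buffer after the forward pass does not change the gradients, because the backward never reads it.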

Wouldn’t it be simpler to let the Python variable that points to the input go out of scope and wait for it to be garbage collected?

If you try to manage memory in tricksy ways, you can expect lots of headaches in the future. For example, any PyTorch update could break your code if you rely on undocumented behavior such as the input data not being used to compute gradients.

You mean that since Conv backward doesn’t need the input data, the autograd graph will automatically release the input data after forward?

Yes. Once the input variable goes out of scope, whatever is no longer referenced, either by your code or by the state the graph saved for backward, will be freed by Python's garbage collection.
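The point about scope can be illustrated in plain Python with `weakref`, which observes an object without keeping it alive. Storage, NodeSavingInput, NodeSavingShapeOnly and input_survives are hypothetical stand-ins, not PyTorch classes: a graph node that saves the input keeps it alive after the Python name is deleted, while a node that keeps only metadata lets it be freed.

```python
import gc
import weakref


class Storage:
    """Hypothetical stand-in for a tensor's memory."""

    def __init__(self, values):
        self.values = list(values)


class NodeSavingInput:
    """Graph node whose backward needs the input: it holds a strong reference."""

    def __init__(self, inp):
        self.saved_input = inp


class NodeSavingShapeOnly:
    """Graph node whose backward needs only metadata (here, the input length)."""

    def __init__(self, inp):
        self.input_len = len(inp.values)


def input_survives(node_cls):
    x = Storage([1.0, 2.0, 3.0])
    probe = weakref.ref(x)   # observes x without keeping it alive
    node = node_cls(x)       # the "graph" stays alive until the end of this function
    del x                    # the Python name goes out of scope
    gc.collect()             # not needed here in CPython (no cycles), but explicit
    alive = probe() is not None
    del node
    return alive


print(input_survives(NodeSavingInput))      # True
print(input_survives(NodeSavingShapeOnly))  # False
```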