Custom autograd Function backward pass not called

So far, I've checked all the posts I could find on this issue, but none of them seems to fit my case.

Here is my code:

import torch
from torch import autograd, nn
import torch.nn.functional as F

class _Conv2dCustom(autograd.Function):

  # Kernel has to be of odd size
  @staticmethod
  def forward(ctx, input, weight, bias=None, padding=0):
    print('########## _Conv2dCustom.forward')
    input = input.clone() 
    ctx.save_for_backward(input, weight, bias)
    padding = (padding, padding) if type(padding) is int else padding
    # input = input.clone()
    output = _Conv2dCustom._conv2d(input, weight, padding)
    if bias is not None:
      output += bias.view(-1, 1, 1)
    return output
    
  @staticmethod
  def backward(ctx, grad_output):
    print('########## _Conv2dCustom.backward')
    # Kernel has to be of odd size
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight = grad_bias = None
    if ctx.needs_input_grad[0]:
      # With this padding, the conv produces output the same size as grad_output
      padding = weight.size(-1)//2
      # Adjust for the remaining difference so grad_input has the same size as input
      padding += (grad_output.size(-1)-input.size(-1))//2
      weight_t = torch.transpose(weight, -1, -2)
      grad_input = _Conv2dCustom._conv2d(grad_output, weight_t, padding)
    if ctx.needs_input_grad[1]:
        grad_weight = _Conv2dCustom._conv2d(input, grad_output, padding_used)
    if bias is not None and ctx.needs_input_grad[2]:
        grad_bias = grad_output.sum([0, 2, 3])

    return grad_input, grad_weight, grad_bias

  @staticmethod
  def _conv2d(input, weight, padding=(0, 0), stride=(1, 1)):
    output_h = (input.size(-2) - weight.size(-2) + 2*padding[0])//stride[0] + 1
    output_w = (input.size(-1) - weight.size(-1) + 2*padding[1])//stride[1] + 1

    Xunfold = F.unfold(input, weight.size()[2:], padding=padding, stride=stride)
    weight_flat = weight.view(weight.size(0), -1)
    conv = weight_flat@Xunfold
    output = conv.view(input.size(0), weight.size(0), output_h, output_w)
    return output

class Conv2dCustom(nn.Conv2d):
  # Torch master as of Feb 23rd, 2020.
  def _conv_forward(self, *args, **kwargs):
    raise NotImplementedError('This Method Is Not Implemented')

  # Torch v1.4.0  
  def conv2d_forward(self, input, weight):
    conv2d = _Conv2dCustom.apply
    conv = conv2d(input, weight, self.bias, self.padding)
    return conv

Output

==========
Started Training
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
########## _Conv2dCustom.forward
==========
Finished Training

In my network, I have two of these Conv2dCustom layers, and I also put hooks on each layer. I noticed that the backward hook fires for all my layers except the innermost Conv2dCustom layer.
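
For context, the setup looks roughly like this (a simplified sketch, not my actual network; the layer sizes, the classifier, and the hook body are made up):

class TinyNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = Conv2dCustom(1, 4, 5)
    self.conv2 = Conv2dCustom(4, 8, 5)
    self.fc = nn.Linear(8 * 24 * 24, 10)

  def forward(self, x):
    x = torch.relu(self.conv1(x))
    x = torch.relu(self.conv2(x))
    return self.fc(x.flatten(1))

def report_backward(module, grad_input, grad_output):
  print('backward hook fired on', module.__class__.__name__)

net = TinyNet()
for module in net.modules():
  module.register_backward_hook(report_backward)

out = net(torch.randn(1, 1, 32, 32))
out.sum().backward()  # the hook on the innermost Conv2dCustom never prints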

Can you share _Conv2dCustom._conv2d()?

Also, could you write a small code sample that uses Conv2dCustom and triggers this, so that we can reproduce it locally?

_Conv2dCustom._conv2d() is in the code above, in the question, but here it is again:

  @staticmethod
  def _conv2d(input, weight, padding=(0, 0), stride=(1, 1)):
    output_h = (input.size(-2) - weight.size(-2) + 2*padding[0])//stride[0] + 1
    output_w = (input.size(-1) - weight.size(-1) + 2*padding[1])//stride[1] + 1

    Xunfold = F.unfold(input, weight.size()[2:], padding=padding, stride=stride)
    weight_flat = weight.view(weight.size(0), -1)
    conv = weight_flat@Xunfold
    output = conv.view(input.size(0), weight.size(0), output_h, output_w)
    return output
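
For what it's worth, a quick way to sanity-check _conv2d against the built-in convolution (the shapes here are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16, 16)
w = torch.randn(4, 3, 5, 5)
ours = _Conv2dCustom._conv2d(x, w, padding=(2, 2))
ref = F.conv2d(x, w, padding=2)
print(torch.allclose(ours, ref, atol=1e-5))  # expected: True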

Reproducing with a small test

Code

image = np.random.randint(255, size=(1, 1, 32, 32)).astype('float32')
image = torch.from_numpy(image).requires_grad_(True)
layer = Conv2dCustom(1, 1, 5)
result = layer.forward(image)

or

image = np.random.randint(255, size=(1, 1, 32, 32)).astype('float32')
image = torch.autograd.Variable(torch.from_numpy(image)).requires_grad_(True)
layer = Conv2dCustom(1, 1, 5)
result = layer.forward(image)

Output

########## _Conv2dCustom.forward

Code

dout = np.random.randint(255, size=(1, 1, 28, 28)).astype('float32')
dout = torch.from_numpy(dout)
result.backward(dout)

Nothing is printed.

Sorry, I missed that your conv2d is a static method, my bad.

A few details:

  • Your nn.Module does not work for me; it throws the NotImplementedError from _conv_forward (see the snippet after this list for what I ran instead).
  • You should never call .forward() on an nn.Module directly. You should use its __call__ instead: result = layer(image).
  • You don’t need to wrap things in Variables anymore. You can just use regular Tensors: image = torch.from_numpy(image).requires_grad_().
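
For reference, here is roughly what I ran to reproduce (a sketch; the weight and bias values are just examples, and the custom Function is applied directly instead of going through the module):

import numpy as np
import torch

image = torch.from_numpy(
    np.random.randint(255, size=(1, 1, 32, 32)).astype('float32')
).requires_grad_()
weight = torch.randn(1, 1, 5, 5, requires_grad=True)
bias = torch.randn(1, requires_grad=True)
result = _Conv2dCustom.apply(image, weight, bias, 0)
# On v1.4.0 the custom backward is silently skipped; on a nightly build this warns.
result.backward(torch.ones_like(result))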

By removing the custom module, it runs.
This is a known issue, and if you use the nightly build, you will see the following warning:
"""
…/torch/csrc/autograd/variable.cpp:421: UserWarning: Output 0 of _Conv2dCustomBackward is a view and its base or another view of its base has been modified inplace. This view was created inside a custom Function (or because an input was returned as-is) and the autograd logic to handle view+inplace would override the custom backward associated with the custom Function, leading to incorrect gradients. This behavior is deprecated and will be forbidden starting version 1.6. You can remove this warning by cloning the output of the custom Function.
"""

As it states, the fact that your custom Function returns a view and that you then modify it inplace when adding the bias breaks some internal autograd assumptions.
You should either change _conv2d to return output.clone() so it does not return a view, or change your bias update to output = output + bias.view(-1, 1, 1) to avoid the inplace operation.
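
Concretely, the second option only changes the bias line of your forward (the first option would instead add the clone() at the end of _conv2d):

  @staticmethod
  def forward(ctx, input, weight, bias=None, padding=0):
    print('########## _Conv2dCustom.forward')
    input = input.clone()
    ctx.save_for_backward(input, weight, bias)
    padding = (padding, padding) if type(padding) is int else padding
    output = _Conv2dCustom._conv2d(input, weight, padding)
    if bias is not None:
      # out-of-place add: leaves the view returned by _conv2d untouched
      output = output + bias.view(-1, 1, 1)
    return output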


No worries, that happens to me a lot.


I am working on Colab, which uses v1.4.0. I put that exception there so I get an error in case they update the current stable version.


Oh, I see. Thanks. In my original code, the one where I first ran into this issue, I don't call forward() myself. I use the class from the standard torch examples of a neural network (LeNet) with a few changes.


I suspected as much. Thanks for the clarification. I tried both cases just to be safe, as a quick test.


Wow. I really did look for inplace operations. At first, I thought output += bias.view() could be inplace, but for some reason I stuck with "Well, that should be the same thing as output = output + bias". I totally should have tried that!


Thank you so much! You saved my week! I really appreciate this community!


Just one more question, about this note from the Extending PyTorch page:

  • mark_dirty() must be used to mark any input that is modified inplace by the forward function.

Shouldn't the inplace operation be an issue only for the input parameters in the forward pass?

This is a different thing. If an input is modified inplace, we need to handle it specifically, and so we need the user to tell us they did it.

But the logic that handles inplace on views needs to look at how the view was created, and unfortunately it can only handle a subset of these cases (they are fairly corner cases). We will now raise a nice error if you do so.
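
To illustrate that case, here is a minimal sketch of a Function that modifies its input inplace and declares it with mark_dirty() (a toy example, unrelated to the conv code above):

import torch

class InplaceDouble(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):
    x.mul_(2)          # the input tensor itself is modified inplace
    ctx.mark_dirty(x)  # so autograd has to be told about it
    return x

  @staticmethod
  def backward(ctx, grad_output):
    return grad_output * 2

x = torch.randn(3, requires_grad=True)
y = InplaceDouble.apply(x.clone())  # apply to a clone: a leaf cannot be modified inplace
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])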
