Implementing a custom convolution using conv2d_input and conv2d_weight

Hi,
I have been trying to implement a custom convolutional layer.
In order to do that, I’m using torch.nn.functional.conv2d in the forward pass, and both torch.nn.grad.conv2d_weight and torch.nn.grad.conv2d_input in the backward pass.
I started getting OOM exceptions when entering torch.nn.grad.conv2d_weight.

My question is, what exactly is the difference between using:

torch.nn.functional.conv2d(x, w)

and

MyConv.apply(x, w)

when MyConv is implemented as follows:

import torch
import torch.nn.functional as F
from torch.autograd import Function

class MyConv(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return F.conv2d(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        x_grad = w_grad = None
        if ctx.needs_input_grad[0]:
            x_grad = torch.nn.grad.conv2d_input(x.shape, w, grad_output)
        if ctx.needs_input_grad[1]:
            w_grad = torch.nn.grad.conv2d_weight(x, w.shape, grad_output)
        return x_grad, w_grad
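
(For what it's worth, the backward itself can be sanity-checked on small double-precision inputs with torch.autograd.gradcheck:)

from torch.autograd import gradcheck

x = torch.randn(1, 2, 6, 6, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 2, 3, 3, dtype=torch.double, requires_grad=True)
print(gradcheck(MyConv.apply, (x, w)))  # True, so the gradients themselves are correct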

Why would torch.nn.grad.conv2d_weight raise an OOM exception when torch.nn.functional.conv2d (which I assumed also uses torch.nn.grad.conv2d_weight in its backward pass) did not?

Thanks.

Hi, have you solved your problem? I want to define a custom conv2d layer too. Could you share your code with me?

Yes.
I’ve avoided the problem by directly calling cudnn_convolution_backward_input and cudnn_convolution_backward_weight instead of torch.nn.grad.conv2d_input and torch.nn.grad.conv2d_weight (following this example https://github.com/pytorch/extension-cpp and adding two C++ functions that call the cudnn functions).
You might not have that problem though (it depends on your network and GPU model). You should first check whether torch.nn.grad.conv2d_weight and torch.nn.grad.conv2d_input work for your model without raising an out-of-memory exception.
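
For reference, a rough sketch of what the Python side of that can look like. It assumes a PyTorch build that still exposes the raw ATen ops as torch.cudnn_convolution_backward_input / torch.cudnn_convolution_backward_weight (they were undocumented and have since been removed, hence the C++ extension route above); the wrapper names here are made up:

import torch

def cudnn_input_grad(input_shape, weight, grad_output, padding=(0, 0), stride=(1, 1), dilation=(1, 1), groups=1):
    # Calls the cuDNN backward kernel directly, bypassing the
    # memory-hungry Python helpers in torch.nn.grad.
    return torch.cudnn_convolution_backward_input(
        input_shape, grad_output, weight, padding, stride, dilation,
        groups, False, False)  # benchmark=False, deterministic=False

def cudnn_weight_grad(input, weight_shape, grad_output, padding=(0, 0), stride=(1, 1), dilation=(1, 1), groups=1):
    return torch.cudnn_convolution_backward_weight(
        weight_shape, grad_output, input, padding, stride, dilation,
        groups, False, False)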


I used your code above to modify my code, but I got an error.
Does F.conv2d(x, w) have both forward and backward methods?

I just want to implement a custom conv2d layer similar to the built-in API one.
Can I just modify the PyTorch Python interface?

What exactly is the error?

It shows the following:

File "/home/lth/anaconda3/lib/python3.5/site-packages/torch/nn/functional.py", line 90, in conv2d
    return f(input, weight, bias)

TypeError: argument 0 is not a Variable

My code is like this:
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair

class Conv2dF(Function):

    @staticmethod
    def forward(cxt, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        cxt.save_for_backward(input, weight, bias)
        # stash the non-tensor hyperparameters for the backward pass
        cxt.stride, cxt.padding = stride, padding
        cxt.dilation, cxt.groups = dilation, groups
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(cxt, grad_output):
        input, weight, bias = cxt.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if cxt.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(
                input.shape, weight, grad_output,
                cxt.stride, cxt.padding, cxt.dilation, cxt.groups)
        if cxt.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(
                input, weight.shape, grad_output,
                cxt.stride, cxt.padding, cxt.dilation, cxt.groups)
        if bias is not None and cxt.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        # backward must return one gradient per argument of forward;
        # the non-tensor arguments (stride, padding, dilation, groups) get None
        return grad_input, grad_weight, grad_bias, None, None, None, None

conv2d = Conv2dF.apply

class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False, _pair(0), groups, bias)

    def forward(self, input):
        return conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
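
The layer is then meant to be a drop-in replacement for the stock nn.Conv2d, e.g.:

layer = Conv2d(3, 8, kernel_size=3, padding=1)
out = layer(torch.randn(2, 3, 16, 16))
out.sum().backward()  # exercises Conv2dF.backward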

Thanks so much!

It looks like you’re using an old version of PyTorch. Try moving to PyTorch 0.4.0 and see if it works.
To verify, run:

import torch
print(torch.__version__)  # should be 0.4.0

I have just updated PyTorch to 0.4.0, but I still get the same error.
Could you show me your demo code?

Could you please share the code where you call cudnn_convolution_backward_weight?

For other people googling this, I posted some code in this thread: Cuda error with cudnn convolution backward weight function


Hi, this OOM exception actually comes from the Python implementation of conv2d_weight.
In the backward weight calculation, the output gradients have to be expanded along the channel dimension (the Python code repeats them in_channels // groups times). The default cuDNN kernel implements this block by block with data prefetching, without allocating more memory, whereas the Python API uses repeat, which allocates a huge output-gradients tensor full of unnecessarily duplicated data.
You can easily fix this by converting the repeat into a loop in conv2d_weight.
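
For illustration, here is one way such a loop can look. This is a minimal sketch (the name conv2d_weight_lowmem is made up, and it handles groups=1 only) that accumulates the weight gradient sample by sample instead of materializing the repeated grad_output tensor:

import torch
import torch.nn.functional as F

def conv2d_weight_lowmem(input, weight_size, grad_output, stride=1, padding=0, dilation=1):
    # Same result as torch.nn.grad.conv2d_weight for groups=1, but loops
    # over the batch instead of repeating grad_output across channels.
    grad_weight = None
    for b in range(input.shape[0]):
        # Per sample: treat the input channels as a batch of 1-channel images
        # and the grad_output channels as 1-channel filters, so that
        # dL/dW[o, i] = input[b, i] cross-correlated with grad_output[b, o].
        x = input[b].unsqueeze(1)        # (C_in, 1, H, W)
        g = grad_output[b].unsqueeze(1)  # (C_out, 1, H', W')
        # The weight gradient of a strided conv is a dilated conv:
        # stride and dilation swap roles here.
        gw = F.conv2d(x, g, stride=dilation, padding=padding, dilation=stride)
        # Trim the overhang left when the stride doesn't divide evenly,
        # then reorder to (C_out, C_in, kH, kW).
        gw = gw[:, :, :weight_size[2], :weight_size[3]].transpose(0, 1)
        grad_weight = gw if grad_weight is None else grad_weight + gw
    return grad_weight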

I think the expression for grad_bias should be fixed to:

grad_bias = grad_output.sum((0,2,3)).squeeze(0)
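
That makes sense: the bias gradient is just grad_output summed over the batch and spatial dimensions (the trailing .squeeze(0) is a no-op unless there is a single output channel). A quick check against autograd:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 10, 10)
w = torch.randn(8, 3, 3, 3)
b = torch.randn(8, requires_grad=True)

out = F.conv2d(x, w, b)
out.backward(torch.ones_like(out))  # upstream gradient of ones

# Summing that upstream gradient over dims (0, 2, 3) reproduces b.grad:
print(torch.allclose(b.grad, torch.ones_like(out).sum((0, 2, 3))))  # True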