Implementing a custom convolution using conv2d_input and conv2d_weight

What exactly is the error?

It shows up as follows:

File "/home/lth/anaconda3/lib/python3.5/site-packages/torch/nn/functional.py", line 90, in conv2d
return f(input, weight, bias)

TypeError: argument 0 is not a Variable

My code is as follows:
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair


class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, False, _pair(0), groups, bias)

    def forward(self, input):
        return conv2d(input, self.weight, self.bias, self.stride,
                      self.padding, self.dilation, self.groups)


class Conv2dF(Function):

    @staticmethod
    def forward(cxt, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        cxt.save_for_backward(input, weight, bias)
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(cxt, grad_output):
        input, weight, bias = cxt.saved_variables
        grad_input = grad_weight = grad_bias = None

        if cxt.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output)

        if cxt.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)

        if bias is not None and cxt.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        if bias is not None:
            return grad_input, grad_weight, grad_bias
        else:
            return grad_input, grad_weight


conv2d = Conv2dF.apply

Thanks so much!


It looks like you’re using an old version of PyTorch. Try moving to PyTorch 0.4.0 and see if it works.
To verify that, use:

import torch
print(torch.__version__)  # should be 0.4.0

I have just updated PyTorch to 0.4.0, but I still get the same error.
Could you show me your demo code?

Could you please share the code where you are calling cudnn_convolution_weight?

For other people googling this, I posted some code in this thread: Cuda error with cudnn convolution backward weight function


Hi, this OOM exception actually comes from the Python implementation of conv2d_weight.
In the backward weight calculation, the output gradients need to be expanded by the number of input channels per group. While the default cuDNN path does this block by block with data prefetching (without allocating more memory), the Python API uses a repeat that allocates a huge amount of memory for the output-gradients tensor, with unnecessary duplication of data.
You can easily fix this by converting the repeat in conv2d_weight into a loop.
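
To illustrate the idea (a minimal sketch, assuming groups=1; the helper name loop_conv2d_weight is just for this example): instead of letting conv2d_weight repeat the whole grad_output across all input channels at once, loop over the input channels and compute each slice of the weight gradient separately, so only a single-channel problem is materialized at a time.

import torch
from torch.nn.grad import conv2d_weight

def loop_conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1):
    # Sketch for groups=1: grad_weight[o, i] only depends on input channel i and
    # grad_output channel o, so each input-channel slice can be handled on its own.
    out_channels, _, kh, kw = weight_size
    slices = []
    for i in range(input.shape[1]):
        slices.append(conv2d_weight(
            input[:, i:i + 1].contiguous(),   # single input channel
            (out_channels, 1, kh, kw),        # corresponding weight slice
            grad_output, stride, padding, dilation))
    return torch.cat(slices, dim=1)           # [out_channels, in_channels, kh, kw]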


I think the expression for grad_bias should be fixed to:

grad_bias = grad_output.sum((0,2,3)).squeeze(0)
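
As a quick sanity check of that reduction (a minimal sketch with arbitrary sizes), you can compare it against the bias gradient autograd computes for a plain nn.Conv2d:

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True)
x = torch.randn(4, 3, 16, 16)
out = conv(x)
grad_output = torch.randn_like(out)
out.backward(grad_output)

# The bias gradient is just grad_output reduced over the batch and spatial dims
manual_grad_bias = grad_output.sum((0, 2, 3))
print(torch.allclose(conv.bias.grad, manual_grad_bias, atol=1e-5))  # expect True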

Thanks to @hanspinckaers for sharing this. I managed to run the code, however it is still slow. Much slower than using PyTorch’s conv2d and letting autograd do the work.

Hi Mostafa, in my benchmarks it does seem to perform at equal speed. How did you benchmark it?

If you mean that running cudnn_convolution_backward_input() and cudnn_convolution_backward_weight() is slower than calling conv2d and letting autograd do the back-propagation, then it makes sense, because you’re now calling two functions separately and, in addition (based on your previous comment), calculating the grad_bias.
If you want to override the whole back-propagation process of Conv2d and still get the same processing time, you should use the combined cudnn_convolution_backward(), which returns the gradients w.r.t. the input, the weights and the biases, in that order.

The question here and @hanspinckaers’ solution refer to overriding only torch.nn.grad.conv2d_weight, which is very expensive in memory, with cudnn_convolution_backward_weight().

Thanks @fsds and @hanspinckaers!

@hanspinckaers: Thank you for following up. It has been a while since I did the benchmark. I recall I was training ResNet18 on ImageNet. Using PyTorch’s torch.nn.grad.conv2d_input(...) and torch.nn.grad.conv2d_weight(...) was probably twice as slow and used twice as much memory as letting PyTorch derive the backward pass of Conv2d automatically.
When I tried the method you provided in this link, things got a bit faster, but still much slower than PyTorch’s automatic backward pass.

@fsds: Thanks for your answer. Are you referring to std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(...) in:


So I just need to create a Python wrapper for it and invoke it in our backward pass?

Yes, that is what I’m referring to, although something is a little weird here: in PyTorch 1.2 (the version I’m currently using) this function returns 3 tensors (std::tuple<at::Tensor,at::Tensor,at::Tensor>: grad_input, grad_weight, grad_bias), while in the current state of the repository it returns only 2 (std::tuple<at::Tensor,at::Tensor>: grad_input and grad_weight).
You can see the 1.2 implementation here:

Either way, using this should give you better performance.
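
To make that concrete, here is a rough sketch of how the combined call could be used from an autograd Function. The convolution_backward binding is hypothetical (you would have to expose at::cudnn_convolution_backward yourself, e.g. through a C++ extension like the one posted later in this thread), and the exact argument order and number of returned tensors depend on the PyTorch version, as noted above:

import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# Hypothetical: assumes you added a "convolution_backward" binding (wrapping
# at::cudnn_convolution_backward as it exists in PyTorch 1.2, which returns
# grad_input, grad_weight, grad_bias) to the extension file.
cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"], verbose=True)


class CombinedBackwardConv2d(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        ctx.save_for_backward(input, weight, bias)
        ctx.conf = (stride, padding, dilation, groups)
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.conf
        # Ask cuDNN only for the gradients that are actually needed
        output_mask = (ctx.needs_input_grad[0], ctx.needs_input_grad[1],
                       bias is not None and ctx.needs_input_grad[2])
        grad_input, grad_weight, grad_bias = cudnn_convolution.convolution_backward(
            input, grad_output, weight, padding, stride, dilation, groups,
            False, False,  # benchmark, deterministic
            output_mask)
        return (grad_input if output_mask[0] else None,
                grad_weight if output_mask[1] else None,
                grad_bias if output_mask[2] else None,
                None, None, None, None)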

We removed the bias, as the backward pass was faster with native PyTorch ops and had some other advantages as seen in this PR.


This doesn’t work if the stride and padding are different from the default values, so I’ve edited it a bit.

ValueError: requested an input grad size of [4, 4], but valid sizes range from [6, 6] to [6, 6] (for a grad_output of torch.Size([4, 4]))

So I saved some arguments with save_for_backward to make it work.

import numpy as np
import torch
import torch.nn.functional as F


class Conv2dFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # Save arguments to context to use on backward
        # WARNING : if stride, padding, dilation etc is array, this will not work properly!!!!
        confs = torch.from_numpy(np.array([stride, padding, dilation, groups]))
        ctx.save_for_backward(input, weight, bias, confs)

        # Compute Convolution
        return F.conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    
    @staticmethod
    def backward(ctx, grad_output):
        # Load saved tensors
        input, weight, bias, confs = ctx.saved_variables
        confs = confs.numpy()
        stride, padding, dilation, groups= confs[0], confs[1], confs[2], confs[3]

        # Calculate Gradient
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
            
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups)
                
        # WARNING : Bias maybe buggy, remove if it is buggy
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)


        # WARNING : Bias maybe buggy, remove if it is buggy
        if bias is not None:
            return grad_input, grad_weight, grad_bias, None, None, None, None
        else:
            return grad_input, grad_weight, None, None, None, None, None

Since there are additional inputs (stride, padding, etc.) to forward, backward needs to return additional Nones for them.

It works fine, but is there any elegant way to do this without returning the useless Nones?

Also, I think the conv with bias is buggy. I’ll fix it if I have more spare time.
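
As a side note, here is a minimal sketch of an alternative to packing stride/padding into a tensor (which, as the warning in the code says, breaks when they are tuples): non-tensor arguments can simply be stashed as plain attributes on ctx, since save_for_backward is only needed for tensors that autograd has to track. The class name Conv2dFunctionAlt is just for this example.

import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_input, conv2d_weight


class Conv2dFunctionAlt(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        ctx.save_for_backward(input, weight, bias)
        # Non-tensor hyperparameters are stored as plain attributes, so tuples
        # like stride=(2, 1) work without a numpy round-trip.
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = conv2d_input(input.shape, weight, grad_output,
                                      ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
        if ctx.needs_input_grad[1]:
            grad_weight = conv2d_weight(input, weight.shape, grad_output,
                                        ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))
        return grad_input, grad_weight, grad_bias, None, None, None, None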

I don’t think there is a way around returning None gradients for the constant inputs.
The grad_bias in your code works only in cases where grad_output has a shape of [B, C, 1, 1]. You should probably replace it with grad_bias = grad_output.sum(dim=(0, 2, 3)) for it to work properly for every Conv2d shape.


Thanks! I was struggling with that issue actually.

Thanks for all the other responses. Here is the memory-efficient, fast solution:

from torch.autograd import Function
from torch.utils.cpp_extension import load

cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"], verbose=True)


class CustomConv2d(Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        ctx.save_for_backward(input, weight, bias)
        ctx.conf = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups
        }
        return cudnn_convolution.convolution(input, weight, bias, stride, padding, dilation, groups,
                                             False, False)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        conf = ctx.conf
        input_grad = weight_grad = bias_grad = stride_grad = padding_grad = dilation_grad = groups_grad =  None
        if ctx.needs_input_grad[0]:
            input_grad = cudnn_convolution.convolution_backward_input(input.shape, weight, grad_output, conf["stride"],
                                                                      conf["padding"], conf["dilation"], conf["groups"],
                                                                      False, False, False)
        if ctx.needs_input_grad[1]:
            weight_grad = cudnn_convolution.convolution_backward_weight(input, weight.shape, grad_output,
                                                                        conf["stride"], conf["padding"],
                                                                        conf["dilation"], conf["groups"],
                                                                        False, False, False)

        if bias is not None and ctx.needs_input_grad[2]:
            bias_grad = grad_output.sum(dim=(0, 2, 3))

        return input_grad, weight_grad, bias_grad, stride_grad, padding_grad, dilation_grad, groups_grad

For the above code to work, you need to create a file named “cudnn_convolution.cpp” which contains the following code (copied from this repo: PyTorch cuDNN Convolution):

#include <torch/extension.h>
#include <vector>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>

/*
PyTorch extension enabling direct access to the following cuDNN-accelerated C++ functions
that are included in PyTorch:
    - cudnn_convolution
    - cudnn_convolution_backward_weight
    - cudnn_convolution_backward_input
The functions defined here can be called from Python in replacement of
torch.nn.conv2d, torch.nn.grad.conv2d_weight and torch.nn.grad.conv2d_input,
and run significantly faster. See 'example.py' for how these functions
are called.
Adapted from code posted by hanspinckaers:
https://discuss.pytorch.org/t/cuda-error-with-cudnn-convolution-backward-weight-function/41214
*/

at::Tensor convolution(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

    return at::cudnn_convolution(
        input,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic); 
}

at::Tensor convolution_backward_weight(
    const at::Tensor& input,
    c10::ArrayRef<int64_t> weight_size,
    const at::Tensor& grad_output,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {

    return at::cudnn_convolution_backward_weight(
        weight_size,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
}

at::Tensor convolution_backward_input(
    c10::ArrayRef<int64_t> input_size,
    const at::Tensor& weight,
    const at::Tensor& grad_output,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {

    return at::cudnn_convolution_backward_input(
        input_size,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("convolution", &convolution, "convolution");
    m.def("convolution_backward_weight", &convolution_backward_weight, "convolution backward weight");
    m.def("convolution_backward_input", &convolution_backward_input, "convolution backward input");
}
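
Once the extension builds, a quick way to check it (a minimal sketch, assuming a CUDA device, the two files above, and a PyTorch version whose at::cudnn_convolution* signatures match; the sizes are arbitrary) is to compare CustomConv2d against F.conv2d:

import torch
import torch.nn.functional as F

device = torch.device("cuda")
x = torch.randn(2, 3, 32, 32, device=device, requires_grad=True)
w = torch.randn(8, 3, 3, 3, device=device, requires_grad=True)
b = torch.randn(8, device=device, requires_grad=True)

# Forward/backward through the custom cuDNN-backed Function
out = CustomConv2d.apply(x, w, b, (1, 1), (1, 1), (1, 1), 1)
out.sum().backward()

# Reference forward/backward through autograd
x_ref, w_ref, b_ref = (t.detach().clone().requires_grad_(True) for t in (x, w, b))
out_ref = F.conv2d(x_ref, w_ref, b_ref, stride=1, padding=1)
out_ref.sum().backward()

# Expect four True values (up to float32 tolerance)
print(torch.allclose(out, out_ref, atol=1e-4),
      torch.allclose(x.grad, x_ref.grad, atol=1e-4),
      torch.allclose(w.grad, w_ref.grad, atol=1e-4),
      torch.allclose(b.grad, b_ref.grad, atol=1e-4))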

Thanks for sharing this code.

I am trying to find the documentation or definition of this function at::cudnn_convolution. Is it available anywhere?

What is benchmark for?

If you are looking for the intuition behind using at::cudnn_convolution, you can look here. The exact definition of it may vary based on the version of cuDNN; for v7 look here, and for v8 here.

Here is a detailed explanation about benchmark.
