Implementing a custom convolution using conv2d_input and conv2d_weight

I don’t think there is a way around returning None as the gradient for constant (non-tensor) arguments.
The grad_bias in your code only works when grad_output has a shape of [B,C,1,1]. Because the bias is broadcast over the batch and spatial dimensions, you should probably replace it with grad_bias=grad_output.sum(dim=(0,2,3)) so that it works for every Conv2d output shape.
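A quick way to convince yourself of the bias formula is to compare it against autograd on an ordinary F.conv2d call. The sketch below uses arbitrary example shapes (batch 2, 3 to 4 channels, 5x5 input, 3x3 kernel) and runs on CPU.

```python
import torch
import torch.nn.functional as F

# Arbitrary example problem: batch 2, 3 -> 4 channels, 5x5 input, 3x3 kernel.
torch.manual_seed(0)
x = torch.randn(2, 3, 5, 5)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4, requires_grad=True)

out = F.conv2d(x, w, b, padding=1)  # output shape [2, 4, 5, 5], not [B, C, 1, 1]
grad_output = torch.randn_like(out)
out.backward(grad_output)           # autograd fills b.grad

# The bias is added at every (batch, h, w) position, so its gradient is
# grad_output summed over every dimension except the channel dimension.
grad_bias = grad_output.sum(dim=(0, 2, 3))
print(torch.allclose(b.grad, grad_bias, atol=1e-5))
```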


Thanks! I was struggling with that issue actually.

Thanks for all the other responses. Here is a memory-efficient, fast solution:

import torch
from torch.autograd import Function
from torch.utils.cpp_extension import load

# JIT-compile cudnn_convolution.cpp (below) into an importable module.
cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"], verbose=True)


class CustomConv2d(Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        # stride, padding and dilation are passed as sequences, e.g. (1, 1).
        ctx.save_for_backward(input, weight, bias)
        # Non-tensor hyperparameters are kept on ctx for use in backward.
        ctx.conf = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups
        }
        # Trailing flags: benchmark=False, deterministic=False.
        return cudnn_convolution.convolution(input, weight, bias, stride, padding, dilation, groups,
                                             False, False)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors  # ctx.saved_variables is deprecated
        conf = ctx.conf
        # stride, padding, dilation and groups are not tensors, so their gradients stay None.
        input_grad = weight_grad = bias_grad = stride_grad = padding_grad = dilation_grad = groups_grad = None
        if ctx.needs_input_grad[0]:
            # Trailing flags: benchmark, deterministic, allow_tf32.
            input_grad = cudnn_convolution.convolution_backward_input(input.shape, weight, grad_output, conf["stride"],
                                                                      conf["padding"], conf["dilation"], conf["groups"],
                                                                      False, False, False)
        if ctx.needs_input_grad[1]:
            weight_grad = cudnn_convolution.convolution_backward_weight(input, weight.shape, grad_output,
                                                                        conf["stride"], conf["padding"],
                                                                        conf["dilation"], conf["groups"],
                                                                        False, False, False)

        if bias is not None and ctx.needs_input_grad[2]:
            # The bias is broadcast over batch and spatial dims, so sum those out.
            bias_grad = grad_output.sum(dim=(0, 2, 3))

        # One gradient per forward() input, in the same order.
        return input_grad, weight_grad, bias_grad, stride_grad, padding_grad, dilation_grad, groups_grad
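If you cannot compile the extension, or want a reference to check it against, the same Function can be sketched with torch.nn.grad.conv2d_input and torch.nn.grad.conv2d_weight, the functions this thread started from. The class name GradConv2d is hypothetical, and this version makes no claim about speed or memory; it only shows the structure and verifies the gradients with torch.autograd.gradcheck on CPU in double precision.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.grad import conv2d_input, conv2d_weight


class GradConv2d(Function):
    """Hypothetical Conv2d whose backward uses torch.nn.grad instead of cuDNN."""

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        ctx.save_for_backward(input, weight, bias)  # bias may be None
        ctx.conf = (stride, padding, dilation, groups)
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.conf
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = conv2d_input(input.shape, weight, grad_output,
                                      stride, padding, dilation, groups)
        if ctx.needs_input_grad[1]:
            grad_weight = conv2d_weight(input, weight.shape, grad_output,
                                        stride, padding, dilation, groups)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))
        # stride, padding, dilation and groups are constants: gradient None.
        return grad_input, grad_weight, grad_bias, None, None, None, None


# Example: stride 2, padding 1, dilation 1, groups 2 (4 input, 6 output channels).
torch.manual_seed(0)
x = torch.randn(2, 4, 7, 7, dtype=torch.double, requires_grad=True)
w = torch.randn(6, 2, 3, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(6, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(
    lambda x, w, b: GradConv2d.apply(x, w, b, 2, 1, 1, 2), (x, w, b))
print(ok)
```

Once the extension builds, the same gradcheck call can be pointed at CustomConv2d on CUDA tensors to compare the two.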

For the above code to work, you need to create a file named “cudnn_convolution.cpp” containing the following code (copied from this repo: PyTorch cuDNN Convolution):

#include <torch/extension.h>
#include <vector>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>

/*
PyTorch extension enabling direct access to the following cuDNN-accelerated C++ functions
that are included in PyTorch:
    - cudnn_convolution
    - cudnn_convolution_backward_weight
    - cudnn_convolution_backward_input
The functions defined here can be called from Python in place of
torch.nn.functional.conv2d, torch.nn.grad.conv2d_weight and torch.nn.grad.conv2d_input,
and run significantly faster. See 'example.py' for how these functions
are called.
Note that the at:: functions take padding before stride, so the
wrappers below reorder those two arguments.
Adapted from code posted by hanspinckaers:
https://discuss.pytorch.org/t/cuda-error-with-cudnn-convolution-backward-weight-function/41214
*/

at::Tensor convolution(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

    return at::cudnn_convolution(
        input,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic); 
}

at::Tensor convolution_backward_weight(
    const at::Tensor& input,
    c10::ArrayRef<int64_t> weight_size,
    const at::Tensor& grad_output,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {

    return at::cudnn_convolution_backward_weight(
        weight_size,
        grad_output,
        input,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
}

at::Tensor convolution_backward_input(
    c10::ArrayRef<int64_t> input_size,
    const at::Tensor& weight,
    const at::Tensor& grad_output,
    c10::ArrayRef<int64_t> stride,
    c10::ArrayRef<int64_t> padding,
    c10::ArrayRef<int64_t> dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32) {

    return at::cudnn_convolution_backward_input(
        input_size,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        allow_tf32);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("convolution", &convolution, "convolution");
    m.def("convolution_backward_weight", &convolution_backward_weight, "convolution backward weight");
    m.def("convolution_backward_input", &convolution_backward_input, "convolution backward input");
}

Thanks for sharing this code.

I am trying to find the documentation or definition of this function at::cudnn_convolution. Is it available anywhere?

What is benchmark for?

If you are looking for the intuition behind using at::cudnn_convolution, you can look here. Its exact definition varies with the cuDNN version: for v7 look here, and for v8 here.

Here is a detailed explanation about benchmark.
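In short, benchmark=True lets cuDNN time several convolution algorithms the first time it sees a given input shape and cache the fastest one, which helps when shapes stay fixed across iterations and costs extra time whenever a new shape appears. As far as I know, a regular F.conv2d call fills the benchmark, deterministic and allow_tf32 arguments from PyTorch's global torch.backends.cudnn switches. The helper cudnn_flags below is hypothetical and assumes PyTorch 1.7 or newer (for torch.backends.cudnn.allow_tf32); it reads those switches so they could be forwarded to the extension instead of the hard-coded False values.

```python
import torch


def cudnn_flags():
    """Hypothetical helper: read PyTorch's global cuDNN switches."""
    return (torch.backends.cudnn.benchmark,      # autotune conv algorithms per input shape
            torch.backends.cudnn.deterministic,  # restrict to deterministic algorithms
            torch.backends.cudnn.allow_tf32)     # allow TF32 math on GPUs that support it


# Enable autotuning globally; worthwhile when input shapes do not change.
torch.backends.cudnn.benchmark = True
benchmark, deterministic, allow_tf32 = cudnn_flags()
print(benchmark, deterministic, allow_tf32)
```

The returned values could then be passed as the trailing arguments of cudnn_convolution.convolution and the two backward functions.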
