cuDNN error with the cuDNN convolution backward function

Hi everyone,
When writing a C++ extension that uses cudnn_convolution_backward, I get a cuDNN error. My C++ extension looks like this:

#include <torch/extension.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <cstdio>
#include <iostream>
#include <array>


at::Tensor conv2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    int64_t stride_width,
    int64_t stride_height,
    int64_t padding_width,
    int64_t padding_height,
    int64_t dilation_width,
    int64_t dilation_height,
    int64_t groups,
    bool benchmark,
    bool deterministic) {

    std::array<int64_t, 2> arr_stride{stride_width, stride_height};
    std::array<int64_t, 2> arr_padding{padding_width, padding_height};
    std::array<int64_t, 2> arr_dilation{dilation_width, dilation_height};

    c10::IntArrayRef stride(arr_stride);
    c10::IntArrayRef padding(arr_padding);
    c10::IntArrayRef dilation(arr_dilation);

    // printf("--------------------------------hello conv2d----------------------------------------\n");

    // ATen takes padding before stride, whereas this function's parameters list stride first.
    return at::cudnn_convolution(
        input,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> conv2d_backward(
    const at::Tensor& self,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    int64_t stride_width,
    int64_t stride_height,
    int64_t padding_width,
    int64_t padding_height,
    int64_t dilation_width,
    int64_t dilation_height,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    std::array<bool, 3> output_mask) {

    // printf("--------------------------------hello conv2d backward----------------------------------------\n");
    
    std::array<int64_t, 2> arr_stride{stride_width, stride_height};
    std::array<int64_t, 2> arr_padding{padding_width, padding_height};
    std::array<int64_t, 2> arr_dilation{dilation_width, dilation_height};

    c10::IntArrayRef stride(arr_stride);
    c10::IntArrayRef padding(arr_padding);
    c10::IntArrayRef dilation(arr_dilation);

    // ATen takes padding before stride, whereas this function's parameters list stride first.
    return at::cudnn_convolution_backward(
        self,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        output_mask);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("conv2d", &conv2d, "2d convolution");
    m.def("conv2d_backward", &conv2d_backward, "2d convolution backward");
}

I test the extension like this:

import torch
import torch.nn.functional as F
from torch.autograd import Function
import math
import cudnn_conv2d
from torch.utils.cpp_extension import load

cudnn_convolution = load(name="cudnn_convolution", sources=["cudnn_convolution.cpp"], verbose=True)

input  = torch.zeros(1, 1, 4, 4).to('cuda')
weight = torch.zeros(1, 1, 3, 3).to('cuda')
bias   = torch.zeros(1).to('cuda')

stride   = (1, 1)
padding  = (0, 0)
dilation = (1, 1)
groups   = 1

# compute the forward result of the convolution
output = cudnn_convolution.conv2d(input, weight, bias, stride[0], stride[1], padding[0], padding[1], dilation[0], dilation[1], groups, False, False)

grad_output = torch.zeros(1, 1, 2, 2).to('cuda')

grad_i, grad_w = cudnn_convolution.conv2d_backward(input, grad_output, weight, padding[0], padding[1], stride[0], stride[1], dilation[0], dilation[1], groups, False, False, [True, True, False])

print(grad_w.shape)
print(grad_i.shape)

I receive the following error:

Loading extension module cudnn_convolution...
Traceback (most recent call last):
  File "example.py", line 30, in <module>
    grad_i, grad_w = cudnn_convolution.conv2d_backward(input, grad_output, weight, padding[0], padding[1], stride[0], stride[1], dilation[0], dilation[1], groups, False, False, [True, True, False])
RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM (set at /pytorch/aten/src/ATen/cudnn/Descriptors.h:157)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7ff9e40fe813 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x4151585 (0x7ff983527585 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #2: <unknown function> + 0x414a4b5 (0x7ff9835204b5 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x414b5bb (0x7ff9835215bb in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #4: at::native::cudnn_convolution_backward_input(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool) + 0xf7 (0x7ff983521ba7 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #5: <unknown function> + 0x41a9805 (0x7ff98357f805 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #6: <unknown function> + 0x3918558 (0x7ff982cee558 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #7: at::native::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) + 0x579 (0x7ff98351da99 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #8: <unknown function> + 0x41a9b0c (0x7ff98357fb0c in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #9: <unknown function> + 0x3918505 (0x7ff982cee505 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #10: <unknown function> + 0x3bd6e11 (0x7ff982face11 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #11: <unknown function> + 0x3918505 (0x7ff982cee505 in /home/yangkan/.local/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #12: std::tuple<at::Tensor, at::Tensor, at::Tensor> c10::KernelFunction::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const + 0x1ac (0x7ff977ad1d74 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #13: c10::impl::OperatorEntry::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(c10::TensorTypeId, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const::{lambda(c10::DispatchTable const&)#1}::operator()(c10::DispatchTable const&) const + 0x162 (0x7ff977acdd52 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #14: std::result_of<c10::impl::OperatorEntry::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(c10::TensorTypeId, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const::{lambda(c10::DispatchTable const&)#1} (c10::DispatchTable const&)>::type c10::LeftRight<c10::DispatchTable>::read<c10::impl::OperatorEntry::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(c10::TensorTypeId, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const::{lambda(c10::DispatchTable const&)#1}>(c10::impl::OperatorEntry::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(c10::TensorTypeId, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const::{lambda(c10::DispatchTable const&)#1}&&) const + 0x128 (0x7ff977ad1fdc in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #15: std::tuple<at::Tensor, at::Tensor, at::Tensor> c10::impl::OperatorEntry::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(c10::TensorTypeId, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const + 0xc6 (0x7ff977acde44 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #16: std::tuple<at::Tensor, at::Tensor, at::Tensor> c10::Dispatcher::callUnboxedOnly<std::tuple<at::Tensor, at::Tensor, at::Tensor>, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul> >(c10::OperatorHandle const&, c10::TensorTypeId, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, std::array<bool, 3ul>) const + 0x188 (0x7ff977ac9f1c in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #17: <unknown function> + 0x2c1be (0x7ff977abb1be in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #18: conv2d_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long, long, long, long, long, long, long, bool, bool, std::array<bool, 3ul>) + 0x133 (0x7ff977abb524 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #19: <unknown function> + 0x4f085 (0x7ff977ade085 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #20: <unknown function> + 0x4b41d (0x7ff977ada41d in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #21: <unknown function> + 0x46fbb (0x7ff977ad5fbb in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #22: <unknown function> + 0x473d4 (0x7ff977ad63d4 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #23: <unknown function> + 0x37b97 (0x7ff977ac6b97 in /tmp/torch_extensions/cudnn_convolution/cudnn_convolution.so)
frame #24: python3() [0x50ac25]
<omitting python frames>
frame #26: python3() [0x508245]
frame #28: python3() [0x635222]
frame #33: __libc_start_main + 0xe7 (0x7ff9ea18cb97 in /lib/x86_64-linux-gnu/libc.so.6)

I cannot find the bug in the code. Does anybody know how to debug this?

The forward function, cudnn_convolution, works correctly: I can use it to compute the result of the convolution. The error only occurs in the backward call.
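One way to debug the backward call independently of cuDNN is to compare it against a slow reference on tiny inputs. The sketch below is a hypothetical pure-Python reference (conv2d_ref and conv2d_backward_ref are illustrative names, not PyTorch APIs), assuming a single batch, a single channel, groups = 1 and zero padding. It uses the output-size formula from the torch.nn.Conv2d documentation, and the backward pass scatters each output gradient back along exactly the taps the forward pass read.

```python
def conv_out_size(n, k, stride, padding, dilation):
    """Output length along one spatial dimension (torch.nn.Conv2d formula)."""
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1


def _taps(x, w, stride, padding, dilation):
    """Yield (i, j, a, b, r, c): output (i, j) reads input (r, c) via kernel tap (a, b).

    Taps that land in the zero padding are skipped, since they contribute nothing.
    """
    H, W = len(x), len(x[0])
    kh, kw = len(w), len(w[0])
    out_h = conv_out_size(H, kh, stride[0], padding[0], dilation[0])
    out_w = conv_out_size(W, kw, stride[1], padding[1], dilation[1])
    for i in range(out_h):
        for j in range(out_w):
            for a in range(kh):
                for b in range(kw):
                    r = i * stride[0] - padding[0] + a * dilation[0]
                    c = j * stride[1] - padding[1] + b * dilation[1]
                    if 0 <= r < H and 0 <= c < W:
                        yield i, j, a, b, r, c


def conv2d_ref(x, w, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    """Forward pass (cross-correlation): y[i][j] = sum of x[r][c] * w[a][b] over taps."""
    out_h = conv_out_size(len(x), len(w), stride[0], padding[0], dilation[0])
    out_w = conv_out_size(len(x[0]), len(w[0]), stride[1], padding[1], dilation[1])
    y = [[0.0] * out_w for _ in range(out_h)]
    for i, j, a, b, r, c in _taps(x, w, stride, padding, dilation):
        y[i][j] += x[r][c] * w[a][b]
    return y


def conv2d_backward_ref(x, w, grad_y, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    """Backward pass: return (grad_x, grad_w) by scattering grad_y along the same taps."""
    grad_x = [[0.0] * len(x[0]) for _ in range(len(x))]
    grad_w = [[0.0] * len(w[0]) for _ in range(len(w))]
    for i, j, a, b, r, c in _taps(x, w, stride, padding, dilation):
        grad_x[r][c] += grad_y[i][j] * w[a][b]
        grad_w[a][b] += grad_y[i][j] * x[r][c]
    return grad_x, grad_w
```

To compare with the extension, one could pass input[0, 0].tolist(), weight[0, 0].tolist() and grad_output[0, 0].tolist() to conv2d_backward_ref and check the results against the first channel of the extension's input and weight gradients. Using non-zero test tensors instead of torch.zeros makes such a comparison meaningful.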

There are a couple of issues in the code:

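  • The CUDNN_STATUS_BAD_PARAM error is raised because you are passing padding[0] and padding[1] before the strides, while your conv2d_backward expects the strides first. With padding = (0, 0) and stride = (1, 1), the extension therefore receives a stride of (0, 0) and a padding of (1, 1), and a zero stride is not a valid convolution stride.

  • You are returning a std::tuple of three at::Tensors, while at::cudnn_convolution_backward returns two

  • output_mask is defined as std::array<bool, 2>, while you are passing three bool values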
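One way to rule out this kind of positional mix-up on the Python side is to build the six integers from keyword-only arguments and validate them before they reach cuDNN. The helper below is a hypothetical sketch (conv2d_backward_positional_args is an illustrative name, not part of PyTorch or the extension), assuming the stride, padding, dilation parameter order of the conv2d_backward signature in the question.

```python
def conv2d_backward_positional_args(*, stride, padding, dilation):
    """Flatten stride/padding/dilation into the order conv2d_backward expects:
    (stride_width, stride_height, padding_width, padding_height,
     dilation_width, dilation_height).

    The parameters are keyword-only, so they cannot be swapped by position.
    """
    for name, values, minimum in (
        ("stride", stride, 1),
        ("padding", padding, 0),
        ("dilation", dilation, 1),
    ):
        if len(values) != 2 or any(v < minimum for v in values):
            raise ValueError(f"{name} must be two values >= {minimum}, got {tuple(values)}")
    return (*stride, *padding, *dilation)


# Usage with the extension from the question (not executed here):
#   args = conv2d_backward_positional_args(stride=stride, padding=padding, dilation=dilation)
#   grads = cudnn_convolution.conv2d_backward(
#       input, grad_output, weight, *args, groups, False, False, [True, True, False])
```

If the values are swapped (stride=(0, 0), padding=(1, 1)), the helper raises a readable ValueError in Python instead of an opaque cuDNN status code.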


Thanks! When I pass padding[0] and padding[1] after the strides, the program works correctly.


I use cudnn_convolution_backward from ATen/NativeFunctions.h, which returns a std::tuple of three at::Tensors, and there output_mask is defined as std::array<bool, 3>.
Is my PyTorch version different from yours? I am using PyTorch 1.4.

Not sure, but I used the definition in cudnn/Conv.cpp, which seemed to work.
If your code is working fine, just ignore my comment.
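Since the two posts describe different return arities (three tensors in the asker's build, matching the std::array<bool, 3ul> in the traceback, and two in the Conv.cpp definition the answer used), a caller-side helper can make the Python unpacking tolerant of either. This is a hypothetical sketch (unpack_conv_backward is an illustrative name); it only smooths over the unpacking, and the output_mask passed in must still have the length the C++ signature of the installed version expects.

```python
def unpack_conv_backward(result):
    """Normalize a conv2d_backward result to (grad_input, grad_weight, grad_bias).

    Accepts three gradients (the signature in the asker's build) or two (the
    signature quoted from Conv.cpp); grad_bias is None in the two-gradient case.
    """
    result = tuple(result)
    if len(result) == 3:
        return result
    if len(result) == 2:
        grad_input, grad_weight = result
        return grad_input, grad_weight, None
    raise ValueError(f"expected 2 or 3 gradients, got {len(result)}")


# Usage (not executed here):
#   grad_i, grad_w, grad_b = unpack_conv_backward(
#       cudnn_convolution.conv2d_backward(...))
```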

My code is flawed. When I use the code, I must pass bias, so I don't set bias=False. I think your comment is exactly what solves this problem. I rechecked, and my PyTorch version is actually 1.3.1. But I cannot find ATen/native/cudnn/Conv.cpp in the PyTorch installation path (I installed it with pip3 install torch torchvision --user). How do I solve this problem? Meanwhile, I have run into a new error: RuntimeError: Expected tensor's dynamic type to be Variable, not Tensor. I am confused and looking for a solution. @ptrblck

Why is it flawed? It seems your other thread was answered already.