Is it possible to modify the low-level convolution operation?

Hello! I've made some progress on my project: I created an optimized depthwise convolution extension. The forward pass works well, and its runtime is faster than the forward pass of nn.Conv2d().
But I am still stuck on the backward pass.
I initially used PyTorch's native backward implementation from DepthwiseConv2d.cu, and it worked, but it was slow compared to the backward pass of nn.Conv2d(). Then I tried:

  1. the cuDNN convolution functions, but I got an error that cudnn_convolution_backward_weight and cudnn_convolution_backward_input are not members of 'at', because my PyTorch version is 1.12.
  2. then, since 'at' no longer exposes the cudnn_convolution_* functions, I tried at::convolution_backward() (a sketch of my call follows this list), but it still seems slow compared to nn.Conv2d().
  3. then I tried at::cudnn_convolution_backward, as used in pytorch/ConvShared.cpp at main · pytorch/pytorch · GitHub, but got the same "not a member of 'at'" error.
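
For reference, here is roughly how I called at::convolution_backward() in step 2. This is only a minimal sketch under my assumptions (2D convolution, the same padding/stride/dilation in both spatial dimensions, no bias term); the name depthwise_backward_v112 is just a placeholder of mine:

#include <torch/extension.h>

// PyTorch 1.12: at::convolution_backward() is the public API left after the
// at::cudnn_convolution_backward* overloads were removed. Note the argument
// order is stride, padding, dilation (the old cuDNN overloads took padding
// first), and the output mask has three entries.
std::tuple<at::Tensor, at::Tensor> depthwise_backward_v112(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& weight,
    int64_t padding,
    int64_t stride,
    int64_t dilation,
    int64_t groups) {

  auto grads = at::convolution_backward(
      grad_output,
      input,
      weight,
      /*bias_sizes=*/c10::nullopt,           // no bias in this layer
      /*stride=*/{stride, stride},
      /*padding=*/{padding, padding},
      /*dilation=*/{dilation, dilation},
      /*transposed=*/false,
      /*output_padding=*/{0, 0},
      groups,
      /*output_mask=*/{true, true, false});  // {grad_input, grad_weight, grad_bias}

  return std::make_tuple(std::get<0>(grads), std::get<1>(grads));
}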

I am not sure how to set up the backward pass with cuDNN. Could you please tell me a little bit more about how to use the cuDNN convolution backward functions? Thank you!

And here is the code:

#include <torch/extension.h>
#include <vector>
#include <ATen/NativeFunctions.h>
#include <ATen/Functions.h>
#include <ATen/Config.h>

// Backward pass for the optimized depthwise convolution. This is the call
// that fails to compile on PyTorch 1.12 with
// "'cudnn_convolution_backward' is not a member of 'at'".

std::tuple<at::Tensor, at::Tensor> optimizedDepthwise_backward(
    const at::Tensor& grad_output,
    const at::Tensor& input,
    const at::Tensor& weight,
    int64_t padding,
    int64_t stride,
    int64_t dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32,
    std::array<bool, 2> output_mask) {

  // at::cudnn_convolution_backward existed up to PyTorch 1.10. Its signature
  // took padding/stride/dilation (in that order, one value per spatial dim),
  // then groups, benchmark, deterministic, allow_tf32, and a 2-element
  // output mask {grad_input, grad_weight}.
  return at::cudnn_convolution_backward(
      input,
      grad_output,
      weight,
      /*padding=*/{padding, padding},
      /*stride=*/{stride, stride},
      /*dilation=*/{dilation, dilation},
      groups,
      benchmark,
      deterministic,
      allow_tf32,
      output_mask);
}
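
One more thing I noticed while reading ConvShared.cpp: the removed cuDNN API took benchmark / deterministic / allow_tf32 as explicit arguments, while at::convolution_backward() reads those settings from the global context. So I wonder if part of the slowdown in step 2 is just that cuDNN benchmarking was off. That is only a guess on my part, but enabling it globally would look like this:

#include <ATen/Context.h>

// C++ equivalents of torch.backends.cudnn.benchmark and friends.
void enable_cudnn_autotuning() {
  at::globalContext().setBenchmarkCuDNN(true);      // autotune conv algorithms
  at::globalContext().setDeterministicCuDNN(false); // allow non-deterministic algos
  at::globalContext().setAllowTF32CuDNN(true);      // allow TF32 math on Ampere+ GPUs
}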