Is it possible to modify the low-level convolution operation?

I'm new to PyTorch's low-level internals.

  1. When using nn.Conv2d, I guess there should be some low-level CUDA code doing the actual convolution, right? Where can I find it, in the PyTorch source or somewhere else?
  2. I have implemented some optimized convolution operations in CUDA (C++), and I would like to drop my code into the original PyTorch source and recompile PyTorch. If this is possible, what is the correct step-by-step workflow?

Thank you for your time and your help would be greatly appreciated!

  1. It depends on the device and thus the backend you are using. E.g. if you are using a GPU, you could find the native convolution in aten/src/ATen/native/cuda/ConvolutionMM2d.cu and could adapt it. cuDNN is used by default for better performance, but since it's closed source you won't be able to take a look at their code and manipulate it.

  2. Either change the definition in the linked file or create a custom extension, which might be easier (a minimal loading sketch is below).
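If you go the extension route, the quickest way to get a custom CUDA kernel callable from Python is usually the JIT path via torch.utils.cpp_extension.load. A minimal sketch (the file names and the forward function here are placeholders for your own binding code and kernel):

import torch
from torch.utils.cpp_extension import load

# Compiles the C++/CUDA sources on the fly and imports them as a Python module.
# "faster_conv.cpp" and "faster_conv_kernel.cu" are placeholder names; point
# them at your own binding file and kernel source.
faster_conv = load(
    name="faster_conv",
    sources=["faster_conv.cpp", "faster_conv_kernel.cu"],
    verbose=True,
)

x = torch.randn(1, 32, 56, 56, device="cuda")
w = torch.randn(32, 32, 3, 3, device="cuda")
out = faster_conv.forward(x, w)  # whatever you exposed via PYBIND11_MODULE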

Thank you for your reply!
Yes, I am using a GPU, so I looked through that linked file (ConvolutionMM2d.cu).

  1. In it, I found the slow_conv2d_forward function, which calls at::cuda::blas::gemm at the end. I think this is the place where the convolution is done by cuDNN, right?
  2. I implemented my fast convolution code by following the standard Visual Studio CUDA template (prepare input data, prepare device memory, copy to device, invoke the convolution kernel, copy data back, etc.).
    And I think my kernel functions (__global__ void fasterConv) are correct and control the GPU hardware at a low level.
    So, could you please elaborate on how to merge my kernel code into PyTorch to achieve faster convolution? In my case, is this the correct tutorial for creating a custom extension? Or can I just copy my kernel code somewhere and fix the type mismatches and other minor errors?

My questions are pretty long; thank you for your patience and your time!

  1. No, the matmul approach via cuBLAS is the native implementation in PyTorch, and cuDNN is closed source, as described before. The cuDNN calls are performed here for the v7 API and here for the newer v8 API.

  2. Yes, a custom extension should work.

Great! Thank you for your reply!
More precisely, my project is about optimizing the depthwise and pointwise convolution operations for MobileNets. So my question should be: how are these operations implemented in the PyTorch backend? Where is the code for them? Can I put my depthwise and pointwise convolution kernels in those places?
I had a look at the link for the custom extension. It essentially creates a new module, which needs a forward pass and a backward pass. For now, I have only implemented the optimized forward pass. Is there any existing code for the backward pass of depthwise and pointwise convolution that I can use directly? Or is there another easy way to implement the backward pass for the custom extension module?
Thank you so much!

As explained before, PyTorch uses native (slow) kernels as well as optimized kernels from e.g. cuDNN (and other libs for other backends). You can use the first posted link to check for different conv implementations, e.g. DepthwiseConv2d.cu should contain the native (slow) code for depthwise convs.

Yes, you should be able to find the backward in the linked code and could reuse it.
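Once you wrap the reused backward in a custom autograd.Function, torch.autograd.gradcheck is a convenient way to verify it numerically. A small sketch against the reference F.conv2d (this is only an illustration; swap in your own Function.apply once you have it):

import torch
import torch.nn.functional as F

# gradcheck compares the analytical backward of a function against numerically
# estimated gradients. Use double precision and tiny inputs, otherwise the
# default tolerances are hard to meet.
x = torch.randn(1, 4, 8, 8, dtype=torch.double, device="cuda", requires_grad=True)
w = torch.randn(4, 1, 3, 3, dtype=torch.double, device="cuda", requires_grad=True)

# Depthwise reference: groups == in_channels. Replace the lambda body with
# yourFunction.apply(inp, weight, ...) to test your own backward.
print(torch.autograd.gradcheck(
    lambda inp, weight: F.conv2d(inp, weight, padding=1, groups=4),
    (x, w),
))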

Hello! My project has made some progress. I created my optimized depthwise convolution extension. The forward pass is working well; it is faster than the nn.Conv2d() forward.
But I am still stuck on the backward pass.
I initially used PyTorch's native backward implementation from DepthwiseConv2d.cu, and it worked, but it was slow compared to the backward pass of nn.Conv2d. Then I tried:

  1. the cuDNN convolution, but I got an error that cudnn_convolution_backward_weight and cudnn_convolution_backward_input are not members of ‘at’, because my PyTorch is 1.12.
  2. since ‘at’ has no cudnn_convolution_* members, at::convolution_backward(), but it still seems slow compared to nn.Conv2d().
  3. at::cudnn_convolution_backward from pytorch/ConvShared.cpp at main · pytorch/pytorch · GitHub, but I got the same “not a member of ‘at’” error.

I am not sure how to set up the backward pass with cuDNN. Could you please tell me a bit more about how to use the cuDNN convolution backward? Thank you!

And here is the code:

#include <torch/extension.h>
#include <vector>
#include <ATen/NativeFunctions.h>
#include <ATen/Functions.h>
#include <ATen/Config.h>

// CUDA backward declarations

std::tuple<at::Tensor, at::Tensor> optimizedDepthwise_backward(
    const at::Tensor& grad_output, 
    const at::Tensor& input, 
    const at::Tensor& weight,
    int64_t padding, 
    int64_t stride, 
    int64_t dilation, 
    int64_t groups,
    bool benchmark, 
    bool deterministic, 
    bool allow_tf32, 
    std::array<bool,2> output_mask);

std::tuple<at::Tensor, at::Tensor> optimizedDepthwise_backward(
  const at::Tensor& grad_output, 
  const at::Tensor& input, 
  const at::Tensor& weight,
  int64_t padding, 
  int64_t stride, 
  int64_t dilation, 
  int64_t groups,
  bool benchmark, 
  bool deterministic, 
  bool allow_tf32, 
  std::array<bool,2> output_mask){

    // This is the call that fails to compile: at::cudnn_convolution_backward
    // is not exposed as a member of 'at' to extensions in this PyTorch version.
    return at::cudnn_convolution_backward(
        input,
        grad_output, 
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        output_mask);
}

at::convolution_backward should be the right approach.
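From Python it is exposed as torch.ops.aten.convolution_backward. A rough sketch of the call for a depthwise case (the shapes and parameter values here are only illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 32, 56, 56, device="cuda")
w = torch.randn(32, 1, 3, 3, device="cuda")
out = F.conv2d(x, w, stride=1, padding=1, groups=32)
grad_out = torch.randn_like(out)

# Arguments: (grad_output, input, weight, bias_sizes, stride, padding, dilation,
#             transposed, output_padding, groups, output_mask); output_mask
# selects which of (grad_input, grad_weight, grad_bias) are computed.
grad_input, grad_weight, _ = torch.ops.aten.convolution_backward(
    grad_out, x, w, None,
    [1, 1], [1, 1], [1, 1],
    False, [0, 0], 32,
    [True, True, False],
)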

Thank you! I see… Then, I guess, my convolution backward pass should have performance comparable to nn.Conv2d, right? But that's not the case for me. Do you have any hint on this? My forward kernels have no effect on the backward pass, and I do have a CUDA-capable GPU. Is it possible that, for some reason, at::convolution_backward didn't use cuDNN? How can I check this?

It might be possible that native kernels are used and you could check it by profiling your actual workload to see if cuDNN is used or not. To do so you could use the built-in profiler or e.g. Nsight Systems.
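For example, with the built-in profiler something along these lines should show the kernel names that actually run (cuDNN kernels typically contain "cudnn" in their names, while the native fallbacks show up as at::native::... functions):

import torch
from torch.profiler import profile, ProfilerActivity

conv = torch.nn.Conv2d(32, 32, 3, padding=1, groups=32).cuda()
x = torch.randn(8, 32, 56, 56, device="cuda", requires_grad=True)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = conv(x)
    out.sum().backward()

# Sort by GPU time and inspect the kernel names to see which backend ran.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))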

Thank you for your help! I used the PyTorch profiler to profile both my convolution module and nn.Conv2d.

It shows that, during the .backward():

  • nn.conv2d called at::native::(anonymous namespace)::conv_depthwi... once

  • my module called the above function twice.

I think the reason is that my module is calculating both “input_grad” and “weight_grad”, as shown here:

def backward(ctx, grad_output):
    input, filter = ctx.saved_tensors
    conf = ctx.conf
    output = optimizedDepthwise_cuda.backward(
        grad_output, input, filter, conf["stride"], conf["padding"],
        conf["dilation"], False, 0, conf["groups"], [True, True, False])
    input_grad, weight_grad, _ = output
    return input_grad, weight_grad, None, None, None, None, None

So my new questions are:

  • Does nn.Conv2d calculate input_grad or weight_grad, or both?
  • Should I calculate both of them for the correct backward pass?

I just followed some tutorials to finish the backward call, and they calculated both input_grad and weight_grad, but I'm not sure whether that is correct.
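For context, a tiny standalone check like the following (plain PyTorch only, nothing from my extension) should show which gradients a normal depthwise conv produces during backward:

import torch

conv = torch.nn.Conv2d(8, 8, 3, padding=1, groups=8, bias=False).cuda()
x = torch.randn(2, 8, 16, 16, device="cuda", requires_grad=True)

conv(x).sum().backward()

# If both the input and the weight require grad, autograd asks the convolution
# backward for both gradients, so a custom Function has to provide both.
print(x.grad.shape)            # gradient w.r.t. the input
print(conv.weight.grad.shape)  # gradient w.r.t. the weight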

Thank you so much! Your help is highly appreciated!!!

It seems that this should be the correct solution:

To implement my optimized module in Python:

import math
import torch


class optimizedDepthwiseFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, filter, filterHeight, stride, padding, dilation, groups):
        ctx.save_for_backward(input, filter)
        ctx.conf = {
            "filterHeight": filterHeight,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "groups": groups
        }

        output = optimizedDepthwise_cuda.forward(input, filter, filterHeight, stride)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors

        conf = ctx.conf
        grad_input = grad_weight = None

        if ctx.needs_input_grad[0]:
            input_ = grad_output.new_empty(1).expand(input.shape)
            grad_input = torch.ops.aten.convolution_backward(
                grad_output, input_, filter, None,
                (conf["stride"], conf["stride"]),
                (conf["padding"], conf["padding"]),
                (conf["dilation"], conf["dilation"]),
                False, [0], conf["groups"],
                (True, False, False))[0]

        if ctx.needs_input_grad[1]:
            filter_ = grad_output.new_empty(1).expand(filter.shape)
            grad_weight = torch.ops.aten.convolution_backward(
                grad_output, input, filter_, None,
                (conf["stride"], conf["stride"]),
                (conf["padding"], conf["padding"]),
                (conf["dilation"], conf["dilation"]),
                False, [0], conf["groups"],
                (False, True, False))[1]

        return grad_input, grad_weight, None, None, None, None, None


class optimizedDepthwiseLayer(torch.nn.Module):
    def __init__(self, inputChannel, outputChannel, filterHeight, stride):
        super(optimizedDepthwiseLayer, self).__init__()
        self.inputChannel = inputChannel
        self.outputChannel = outputChannel
        self.filterHeight = filterHeight
        self.stride = stride
        if(self.filterHeight == 3):
            self.padding = 1
        elif(self.filterHeight == 5):
            self.padding = 2
        self.dilation = 1
        self.groups = inputChannel
        
        self.filter = torch.nn.Parameter(torch.empty((self.inputChannel, 1, self.filterHeight, self.filterHeight), dtype=torch.float))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.inputChannel * self.filterHeight * self.filterHeight)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input):
        return optimizedDepthwiseFunction.apply(
            input, 
            self.filter, 
            self.filterHeight,
            self.stride,
            self.padding,
            self.dilation,
            self.groups)

And I actually do not need at::convolution_backward in the cpp file, because I can directly use torch.ops.aten.convolution_backward in my Python module.

#include <torch/extension.h>
#include <vector>
#include <array>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor optimizedDepthwise_cuda_forward(
  torch::Tensor input, 
  torch::Tensor filter,
  int filterHeight,
  int stride);

torch::Tensor optimizedDepthwise_forward(
    torch::Tensor input,
    torch::Tensor filter,
    int filterHeight,
    int stride) {
    
    CHECK_INPUT(input);
    CHECK_INPUT(filter);

    return optimizedDepthwise_cuda_forward(
      input,
      filter,
      filterHeight,
      stride);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &optimizedDepthwise_forward, "Optimized Depthwise forward (CUDA)");
}
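
And a quick numerical sanity check against nn.Conv2d could look like this (just a sketch, assuming the extension above is built and importable as optimizedDepthwise_cuda and the Python classes above are defined; the tolerances are arbitrary):

import torch

# Compare the custom layer against an equivalent depthwise nn.Conv2d with the
# same weights, for both the forward output and the gradients.
layer = optimizedDepthwiseLayer(32, 32, 3, 1).cuda()
ref = torch.nn.Conv2d(32, 32, 3, stride=1, padding=1, groups=32, bias=False).cuda()
with torch.no_grad():
    ref.weight.copy_(layer.filter)

x = torch.randn(4, 32, 56, 56, device="cuda", requires_grad=True)
x_ref = x.detach().clone().requires_grad_(True)

out = layer(x)
out_ref = ref(x_ref)
print(torch.allclose(out, out_ref, atol=1e-4))

out.sum().backward()
out_ref.sum().backward()
print(torch.allclose(x.grad, x_ref.grad, atol=1e-4))
print(torch.allclose(layer.filter.grad, ref.weight.grad, atol=1e-4))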

Could you please help me double-check whether my implementation logic contains any mistakes? Thank you!

Your code looks good. One minor suggestion: don't use the deprecated .data attribute; instead, initialize the parameter inside a torch.no_grad() guard using methods from the torch.nn.init namespace.
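E.g. a sketch of the same uniform init without .data (a drop-in replacement for the reset_parameters method above):

import math
import torch

def reset_parameters(self):
    stdv = 1.0 / math.sqrt(self.inputChannel * self.filterHeight * self.filterHeight)
    # Same uniform initialization as before, but via torch.nn.init inside a
    # no_grad guard instead of touching the deprecated .data attribute.
    with torch.no_grad():
        for weight in self.parameters():
            torch.nn.init.uniform_(weight, -stdv, stdv)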

Thank you! It worked! The layer can be placed into the model and it runs correctly!
But now I have another problem. I am trying to run my code on a ROCm device, and it seems that the backward part of my custom extension is wrong there; it is not computing the correct results…
Could you please tell me which of the following options is the right one for PyTorch 1.10 + ROCm?

  1. In the Python backward static function, call torch.ops.aten.miopen_depthwise_convolution_backward (I thought this is the backward for depthwise convolution).
  2. In the Python backward static function, call torch.ops.aten.miopen_convolution_backward.
  3. In the extension cpp file, use aten::miopen_depthwise_convolution_backward_weight and aten::miopen_depthwise_convolution_backward_input.

Is there a correct way to do the backward on a ROCm device?

Sorry, I’m not familiar with ROCm and MIOpen, but the function names look correct (1 and 3).
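One thing you could check before choosing between them is which of these ops your ROCm build actually registers; attribute lookup on torch.ops.aten should fail for ops that don't exist in your build (the names below just mirror the ones from your post):

import torch

for name in [
    "miopen_depthwise_convolution_backward",
    "miopen_convolution_backward",
    "convolution_backward",
]:
    print(name, hasattr(torch.ops.aten, name))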