Is it possible to modify the low-level convolution operation?

I’m new to PyTorch’s low-level internals.

  1. When using nn.Conv2d, I assume there is some low-level CUDA code that performs the actual convolution, right? Where can I find it, in the PyTorch source or somewhere else?
  2. I have implemented some optimized convolution operations in CUDA (C++) and would like to drop my code into the original PyTorch source and recompile it. If this is possible, what is the correct step-by-step workflow?

Thank you for your time and your help would be greatly appreciated!

  1. It depends on the device and thus the backend you are using. E.g. if you are using a GPU, you can find the native convolution in aten/src/ATen/native/cuda/ConvolutionMM2d.cu and adapt it. cuDNN is used by default for better performance, but since it’s closed source you won’t be able to look at its code and modify it.

  2. Either change the definition in the linked file or create a custom extension, which might be easier (a minimal loading sketch follows below).
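To illustrate the custom-extension route: assuming you have a C++ binding file and a CUDA kernel file (the names my_conv.cpp / my_conv.cu and the exposed function faster_conv_forward below are placeholders for your own code, not real PyTorch APIs), torch.utils.cpp_extension.load can JIT-compile them into an importable module:

```python
# Hypothetical sketch: JIT-compile a custom CUDA convolution into a Python module.
# "my_conv.cpp"/"my_conv.cu" and "faster_conv_forward" are placeholder names.
import torch
from torch.utils.cpp_extension import load

my_conv = load(
    name="my_conv",
    sources=["my_conv.cpp", "my_conv.cu"],  # C++ binding + CUDA kernel
    verbose=True,
)

x = torch.randn(1, 3, 32, 32, device="cuda")
w = torch.randn(8, 3, 3, 3, device="cuda")
out = my_conv.faster_conv_forward(x, w)  # assumes the extension exposes this function
```

The same pair of files can also be built ahead of time via a setup.py using torch.utils.cpp_extension.CUDAExtension; the official C++/CUDA extension tutorial covers both options.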

Thank you for your reply!
Yes, I am using a GPU, so I looked through that linked file (ConvolutionMM2d.cu).

  1. In it, I found the slow_conv2d_forward function, which calls at::cuda::blas::gemm at the end. I think this is the place where the convolution is done by cuDNN, right?
  2. I implemented my fast convolution code by following the standard Visual Studio CUDA template (prepare the input data, allocate device memory, copy to the device, launch the convolution kernel, copy the results back, etc.).
    I believe my kernel functions (__global__ void fasterConv) are correct and that they control the GPU hardware at a low level.
    So, could you please elaborate on how to merge my kernel code into PyTorch to achieve a faster convolution? In my case, is this the correct tutorial for creating a custom extension? Or could I just copy my kernel code somewhere into the PyTorch source and fix the type errors and other minor issues?

My questions are pretty long, thank you for your patience and your time!

  1. No, the matmul approach via cuBLAS (im2col followed by a GEMM, sketched after this list) is the native implementation in PyTorch, and cuDNN is closed source as described before. The cuDNN calls are performed here for the v7 API and here for the newer v8 API.

  2. Yes, a custom extension should work.
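To make the matmul point above concrete, here is a small illustration of the same idea in Python (not the actual ATen code): a conv2d can be computed as unfold (im2col) followed by a single matrix multiplication, which is what the native slow_conv2d path hands off to the cuBLAS gemm call:

```python
# Conv2d expressed as im2col (unfold) + GEMM, i.e. the "matmul approach" used by
# the native slow_conv2d kernels before calling cuBLAS gemm.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, device="cuda")
w = torch.randn(5, 3, 3, 3, device="cuda")

ref = F.conv2d(x, w, padding=1)               # reference result

cols = F.unfold(x, kernel_size=3, padding=1)  # (N, C*kH*kW, L)
out = w.view(5, -1) @ cols                    # GEMM: (5, 27) x (N, 27, L)
out = out.view(2, 5, 8, 8)                    # reshape to (N, out_ch, H, W)

print(torch.allclose(ref, out, atol=1e-4))    # True up to floating-point tolerance
```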

Great! Thank you for your reply!
To be more precise, my project optimizes the depthwise and pointwise convolution operations used in MobileNets. So my questions should be: how are these operations implemented in the PyTorch backend? Where is the code for them? Can I drop my depthwise and pointwise convolution kernels into those places?
I had a look at the custom-extension link. It essentially creates a new module, which needs both a forward and a backward pass. For now, I have only implemented the optimized forward pass. Is there existing code for the backward pass of depthwise and pointwise convolutions that I can use directly, or is there another easy way to implement the backward pass for the custom extension module?
Thank you so much!

As explained before, PyTorch uses native (slow) kernels as well as optimized kernels from e.g. cuDNN (and from other libraries for other backends). You can use the first posted link to check the different conv implementations; e.g. DepthwiseConv2d.cu should contain the native (slow) code for depthwise convs.
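For reference, the depthwise/pointwise (separable) pattern used in MobileNets is expressed in PyTorch through the groups argument of nn.Conv2d; the depthwise case (groups == in_channels) is what gets routed to kernels such as DepthwiseConv2d.cu (or to cuDNN):

```python
# Depthwise-separable convolution block as used in MobileNet-style networks.
import torch
import torch.nn as nn

in_ch, out_ch = 32, 64
block = nn.Sequential(
    # Depthwise: one 3x3 filter per input channel (groups == in_channels).
    nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch, bias=False),
    # Pointwise: 1x1 convolution that mixes channels (an ordinary conv2d internally).
    nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
).cuda()

y = block(torch.randn(8, in_ch, 112, 112, device="cuda"))
print(y.shape)  # torch.Size([8, 64, 112, 112])
```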

Yes, you should be able to find the backward pass in the linked code and reuse it.
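One way to do that without touching the PyTorch source, assuming the placeholder my_conv extension from the earlier sketches provides your forward kernel, is to wrap it in a torch.autograd.Function and delegate the backward to PyTorch's own convolution backward. Recent releases expose it as torch.ops.aten.convolution_backward; double-check the exact signature against the version you build with, and make sure the stride/padding arguments match what your kernel actually computes:

```python
# Hypothetical sketch: custom forward kernel, backward delegated to PyTorch's
# built-in convolution backward. "my_conv" is the placeholder extension from above.
import torch

class FasterDepthwiseConv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, stride=1, padding=1, groups=1):
        ctx.save_for_backward(input, weight)
        ctx.conf = (stride, padding, groups)
        return my_conv.faster_conv_forward(input, weight)  # your optimized kernel

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        stride, padding, groups = ctx.conf
        grad_input, grad_weight, _ = torch.ops.aten.convolution_backward(
            grad_output, input, weight,
            None,                                          # no bias
            [stride, stride], [padding, padding], [1, 1],  # stride, padding, dilation
            False, [0, 0],                                 # not transposed, no output padding
            groups,
            [True, True, False],                           # grads for input and weight only
        )
        return grad_input, grad_weight, None, None, None
```

Autograd then uses this backward whenever FasterDepthwiseConv2d.apply(x, w, 1, 1, x.shape[1]) appears in the graph, so only the forward kernel needs to be your own code.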