I have been looking into PyTorch's C++/CUDA extension tutorial and at PyTorch's torch.utils.cpp_extension.load(...) function.
My understanding is that a CUDA kernel compiled for a specific shape will generally outperform a generic CUDA kernel (e.g., see the discussion in this link).
Having said that, my question is: is it possible to create a CUDA kernel with shapes passed as template parameters (so that it is compiled for a specific shape)? E.g., something like:
#include <torch/extension.h>

// Shape parameters as non-type template parameters, so the function is
// compiled (specialized) for one configuration. Note that c10::ArrayRef
// cannot be a non-type template parameter, so the sizes are spelled out
// as individual integers here.
template <int64_t kWeightH, int64_t kWeightW,
          int64_t kPadH, int64_t kPadW,
          int64_t kStrideH, int64_t kStrideW,
          int64_t kDilationH, int64_t kDilationW>
at::Tensor conv2d_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    int64_t groups,
    bool benchmark,
    bool deterministic) {
  // TODO: define kernel here
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // A template must be bound to one concrete instantiation at compile
  // time, e.g. a 3x3 kernel with padding 1, stride 1, dilation 1:
  m.def("conv2d_forward", &conv2d_forward<3, 3, 1, 1, 1, 1, 1, 1>,
        "Conv2d forward");
}
and then, when JIT-compiling, pass the shape and related parameters:
from torch.utils.cpp_extension import load

# Note: load() does not actually accept a `params` argument; this is the
# hypothetical API I have in mind.
conv2d_ours = load(name="conv2d_forward",
                   sources=["conv2d_forward.cpp"],
                   params={"weight_size": weight_size,
                           "padding": padding,
                           "stride": stride,
                           "dilation": dilation},
                   verbose=True)
Ideally, before invoking the convolution op, we would look at the shape parameters and check whether a version has already been JIT-compiled for the same parameters. If one exists, we simply invoke that precompiled kernel; otherwise we JIT-compile the kernel and then invoke it.
My understanding is that high-level compilers, like XLA and TVM, JIT-compile kernels for specific shapes to improve performance. It would be great if we could do something similar in PyTorch.