JIT Compile with Shape Parameters Passed as Templates

I have been looking into PyTorch’s C++/CUDA extension tutorial and at PyTorch’s torch.utils.cpp_extension.load(..) function.

My understanding is that a CUDA kernel compiled for a specific shape can perform better than a generic CUDA kernel, for example because the compiler can unroll loops and fold index arithmetic whose bounds are known at compile time (see, e.g., the discussion in this link).

Having said that, my question is: is it possible to create a CUDA kernel whose shapes are passed as template parameters, so that it is compiled for one specific shape? For example, something like:

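```cpp
#include <torch/extension.h>

#include <vector>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>

// Shape parameters as non-type template parameters. c10::ArrayRef is not a
// valid non-type template parameter type, so each dimension is an int64_t.
template <int64_t KH, int64_t KW,  // kernel height/width (from weight_size)
          int64_t PH, int64_t PW,  // padding
          int64_t SH, int64_t SW,  // stride
          int64_t DH, int64_t DW>  // dilation
at::Tensor conv2d_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    int64_t groups,
    bool benchmark,        // unused in this sketch
    bool deterministic) {  // unused in this sketch
  TORCH_CHECK(weight.size(2) == KH && weight.size(3) == KW,
              "weight does not match the compiled kernel size");
  // TODO: define the shape-specialized kernel here. Placeholder: the generic
  // ATen convolution, called with the compile-time parameters.
  return at::conv2d(input, weight, /*bias=*/{}, {SH, SW}, {PH, PW}, {DH, DW},
                    groups);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // A function template must be instantiated before it can be bound; this
  // binds the 3x3 kernel, padding 1, stride 1, dilation 1 specialization.
  m.def("conv2d_forward", &conv2d_forward<3, 3, 1, 1, 1, 1, 1, 1>,
        "Conv2d forward (3x3, padding 1, stride 1, dilation 1)");
}
```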

Then, when JIT-compiling, pass the shape and related parameters:

```python
from torch.utils.cpp_extension import load
# Proposed (hypothetical) `params` argument; load() does not accept it today.
conv2d_ours = load(name="conv2d_forward", sources=["conv2d_forward.cpp"], verbose=True,
                   params={"weight_size": weight_size, "padding": padding, "stride": stride, "dilation": dilation})
```
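As far as I know, `load()` has no `params` argument, but the existing `extra_cflags` argument (and `extra_cuda_cflags` for `.cu` sources) can pass the shape parameters to the compiler as preprocessor macros. Below is a minimal sketch under that assumption. The helper `specialization_args` and the macro names such as `CONV_KH` are hypothetical; the C++ source would have to read those macros, for example by binding `&conv2d_forward<CONV_KH, CONV_KW, ...>`. Each specialization also gets a distinct `name`, so it has its own build directory and Python module.

```python
# Hypothetical helper: turns one set of conv2d shape parameters into a unique
# extension name plus -D macro definitions for the C++ compiler.
def specialization_args(weight_size, padding, stride, dilation):
    values = {
        "CONV_KH": weight_size[-2], "CONV_KW": weight_size[-1],  # kernel H, W
        "CONV_PH": padding[0], "CONV_PW": padding[1],
        "CONV_SH": stride[0], "CONV_SW": stride[1],
        "CONV_DH": dilation[0], "CONV_DW": dilation[1],
    }
    # Encode every value in the name so different specializations never
    # share a build directory or module name.
    name = "conv2d_forward_" + "_".join(str(v) for v in values.values())
    flags = [f"-D{key}={value}" for key, value in values.items()]
    return name, flags


name, flags = specialization_args((64, 3, 3, 3), (1, 1), (1, 1), (1, 1))
# With PyTorch installed, this specialization could then be built with:
# from torch.utils.cpp_extension import load
# conv2d_3x3 = load(name=name, sources=["conv2d_forward.cpp"],
#                   extra_cflags=flags, verbose=True)
```

The catch is that every new combination of parameters means a full compiler invocation, which can take tens of seconds for a CUDA extension.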

Ideally, before invoking the convolution op, we would look at the shape parameters and check whether a version has already been JIT-compiled for the same parameters. If one exists, we simply invoke that pre-compiled kernel; otherwise, we JIT-compile the kernel, keep it for later calls, and invoke it.
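The lookup-or-compile step described above amounts to a dictionary keyed on the shape parameters. Here is a minimal sketch. `KernelCache` and its `compile_fn` callback are hypothetical names; in practice `compile_fn` could wrap `torch.utils.cpp_extension.load` with a specialization-specific `name` and `extra_cflags`.

```python
from typing import Any, Callable, Dict, Tuple

Key = Tuple[Tuple[int, ...], ...]


class KernelCache:
    """Returns a compiled module per shape signature, compiling on a miss."""

    def __init__(self, compile_fn: Callable[[Key], Any]):
        self._compile_fn = compile_fn  # hypothetical: builds one specialization
        self._cache: Dict[Key, Any] = {}

    def get(self, weight_size, padding, stride, dilation) -> Any:
        # Lists are unhashable, so normalise every parameter to a tuple;
        # [1, 1] and (1, 1) then map to the same entry.
        key = (tuple(weight_size), tuple(padding), tuple(stride), tuple(dilation))
        module = self._cache.get(key)
        if module is None:
            # Miss: JIT-compile once, then reuse on every later call.
            module = self._compile_fn(key)
            self._cache[key] = module
        return module
```

Keying on the full `weight_size` also separates kernels that differ only in channel counts. If only the kernel height and width matter to the specialized code, a coarser key would avoid redundant builds.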

My understanding is that high-level compilers such as XLA and TVM JIT-compile kernels for specific shapes to improve performance. It would be great if we could do something similar in PyTorch.