Unexplained overhead when using custom CUDA kernel in torch.compile reduce-overhead mode

Hi,

I am working on inference optimization of a diffusion model for pedagogical purposes. Since the model uses lots of small kernels and inference runs at batch-size 1, the PyTorch launch overhead is huge. I am trying to use CUDA graphs to eliminate this overhead, and also to integrate a custom CUDA kernel for one of the blocks (1D conv + group norm + Mish).

When I use torch.compile in reduce-overhead mode, standard PyTorch is able to more or less fully graph the network, and all kernels run back to back in a single GPU stream. When I swap in the custom kernel for the 1D conv + GN + Mish block, the alignment between the PyTorch kernels and my kernel across streams looks off, and there is a huge amount of idle time where the trace shows no activity on either the CPU or the GPU. Initially I thought this was CUDA stream related, but I see the same behavior even after passing the current CUDA stream with the ATen C++ API (getCurrentCUDAStream()). The CUDA graph also seems to break into many pieces with the custom kernel.
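To make the setup concrete, here is a minimal sketch of the kind of thing I'm doing (shapes and the stand-in module are illustrative, not my actual model); in the custom-kernel run, the fused CUDA kernel replaces the `Block` below:

```python
import torch
import torch.nn as nn

# Runnable stand-in for the block in question: in the real model, this
# 1D conv + GroupNorm + Mish is what the custom CUDA kernel replaces.
class Block(nn.Module):
    def __init__(self, channels=64, groups=8):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=5, padding=2)
        self.norm = nn.GroupNorm(groups, channels)

    def forward(self, x):
        return nn.functional.mish(self.norm(self.conv(x)))

model = Block().cuda().eval()
compiled = torch.compile(model, mode="reduce-overhead")

x = torch.randn(1, 64, 256, device="cuda")   # batch-size 1 inference
with torch.no_grad():
    for _ in range(3):    # warm-up so reduce-overhead can capture a CUDA graph
        _ = compiled(x)
    out = compiled(x)     # steady state: ideally replayed from the captured graph
```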

Any advice on how to fit all of the PyTorch + custom kernel execution for my network into a single CUDA graph? My custom block is about 3x faster than the PyTorch implementation on its own, so if it integrated into the network more organically I think I would get much better performance. Traces and profiler stats for both the custom-kernel and standard PyTorch runs are attached to this post.

Thanks in advance!
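For reference, the stats below were gathered with the PyTorch profiler, roughly along these lines (a sketch reusing `compiled` and `x` from above; the exact options may differ from what I actually ran):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Rough sketch of how the numbers below were collected.
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        _ = compiled(x)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("trace.json")   # source of the trace screenshots below
```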

Standard PyTorch graphed run-time, according to the profiler:
Self CPU time total: 1.953ms
Self CUDA time total: 1.676ms
https://i.imgur.com/UI4vMh7.png

Custom kernel graphed run-time:
Self CPU time total: 754.000us
Self CUDA time total: 1.226ms
https://imgur.com/a/ggT566o

Notice in the traces that, despite the lower CPU/CUDA times, the overall run-time with the custom kernel is longer because of the idle time.

Update: I was able to get a single graph using the standard stream capture API. I had tried this a few days ago, but I still had the torch.compile(mode='reduce-overhead') decorator applied, so I got some strange errors. With that removed, my original issue is fixed.
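In case it helps anyone else, the manual capture that gave me a single graph looks roughly like this (a sketch with illustrative shapes, reusing the stand-in `Block` from above; the custom kernel has to launch on the current/capture stream, which is where getCurrentCUDAStream() comes in):

```python
import torch

model = Block().cuda().eval()                      # real network + custom kernel in practice
static_x = torch.randn(1, 64, 256, device="cuda")  # static input buffer for the graph

# Warm up on a side stream before capture, per the CUDA graphs docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        _ = model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture the whole forward pass, including the custom kernel, into one graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_out = model(static_x)

# Replay: copy fresh inputs into the static buffer and launch everything at once.
static_x.copy_(torch.randn(1, 64, 256, device="cuda"))
g.replay()
result = static_out.clone()
```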

But if anyone has tips on making custom CUDA kernels play nicely with torch.compile in reduce-overhead mode, I would love to hear them!