Slow convolutions in Triton

When using torch.compile() on torchvision models on the GPU, the Triton version is usually faster than the CUDA version (DenseNet and AlexNet are the exceptions, where it is slower). However, when benchmarking single convolutions (torch conv2d), the Triton version is consistently ~20% slower. Compilation overhead is accounted for, and the median of multiple runs is taken. I replicated the convolution parameters from the source code of the vision models and also ran an ablation over kernel size, stride, and padding. In and out channels are fixed at 3 and 64 respectively, as is the case in all of the vision models.
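For reference, this is roughly the kind of benchmark I mean. This is a minimal sketch (my own reconstruction, not the exact script): a single conv2d with in/out channels fixed at 3/64, timed in eager mode vs. under torch.compile, taking the median over several runs after warmup so compilation overhead is excluded. The batch size, input resolution, and the particular kernel/stride/padding point are assumptions for illustration.

```python
import statistics
import time

import torch

def bench(fn, x, warmup=10, iters=50):
    # Warmup runs absorb torch.compile / autotuning overhead.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(x)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    return statistics.median(times)

device = "cuda"
x = torch.randn(32, 3, 224, 224, device=device)

# One parameter point from the ablation (kernel size, stride, padding).
conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3).to(device)
compiled_conv = torch.compile(conv)

print("eager   :", bench(conv, x))
print("compiled:", bench(compiled_conv, x))
```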

When profiled with the torch profiler, the compiled convolutions appear to show large speedups, but wall-clock time shows them to be much slower.
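To make the discrepancy concrete, here is a hedged sketch of the two views being compared: the torch profiler's per-kernel table vs. end-to-end wall-clock latency around the same compiled call. Shapes and conv parameters are again assumptions for illustration.

```python
import time

import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda"
x = torch.randn(32, 3, 224, 224, device=device)
conv = torch.compile(
    torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1).to(device)
)

# Warmup so compilation is excluded from both measurements.
for _ in range(10):
    conv(x)
torch.cuda.synchronize()

# Profiler view: per-kernel CUDA times, where the Triton kernels can look fast.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(20):
        conv(x)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Wall-clock view: end-to-end latency of the same call, which is what the
# ~20% regression above refers to.
start = time.perf_counter()
for _ in range(20):
    conv(x)
torch.cuda.synchronize()
print("wall-clock per call:", (time.perf_counter() - start) / 20)
```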

Any ideas on why this seems to be the case?


For 1x1 convolutions, I have consistently observed 2x-3x higher latency than baseline PyTorch. I would also like to know the reason for this.
In general, compiled Triton does not seem to handle convolutions very well (there are multiple reports of this in the Triton GitHub issues).
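A minimal repro of the 1x1 case, under my own assumed shapes (a typical bottleneck-style 1x1 conv, not necessarily what the original poster measured):

```python
import time

import torch

def median_ms(fn, x, warmup=10, iters=50):
    # Warmup absorbs compilation/autotuning; return the median latency in ms.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(x)
        torch.cuda.synchronize()
        times.append((time.perf_counter() - t0) * 1e3)
    times.sort()
    return times[len(times) // 2]

device = "cuda"
x = torch.randn(32, 256, 56, 56, device=device)  # assumed 1x1 bottleneck shape
conv = torch.nn.Conv2d(256, 64, kernel_size=1).to(device)

print("eager    ms:", median_ms(conv, x))
print("compiled ms:", median_ms(torch.compile(conv), x))
```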