Compiling a stack of Conv3d layers increases runtime due to redundant format conversions

On both PyTorch 2.3.1+cu121 (Colab A100) and 2.4.0+cu124 (H100), torch.compile inserts unnecessary memory-layout-conversion kernels around Conv3d layers. The Conv3d layers have a 3x3x3 kernel, 1x1x1 padding, no bias, no activation, and preserve the channel dimension.

import torch
import torch.nn as nn
from torch.profiler import record_function


class ConvNet(nn.Module):
    """Simple convolutional network: three Conv3d layers with 3x3x3 kernels, 1x1x1 padding, no bias."""

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.conv_1 = nn.Conv3d(in_channels, hidden_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        self.conv_2 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        self.conv_3 = nn.Conv3d(hidden_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input 3D tensor. Shape: [B, C, T, H, W].

        Returns:
            Output tensor of shape [B, out_channels, T, H, W] in channels_last_3d format.
        """
        assert x.ndim == 5

        with record_function("conv_layers"):
            x = self.conv_1(x)
            x = self.conv_2(x)
            x = self.conv_3(x)

        assert x.is_contiguous(memory_format=torch.channels_last_3d)

        return x

Without compilation, the stack takes 6.482 ms/step after warmup on an input of shape (4, 8, 64, 256, 256) with 64 hidden channels in the network.

When compiled with torch.compile, the same network takes 10.878 ms/step.
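
For context, a minimal sketch of one way to collect such per-step timings (forward-only here; the timing helper, the warmup/iteration counts, and the out_channels=64 choice are illustrative assumptions, not the original benchmark script):

import torch

@torch.no_grad()
def time_model(model, x, steps=20, warmup=10):
    """Average milliseconds per forward pass, measured with CUDA events."""
    for _ in range(warmup):
        model(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(steps):
        model(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / steps

model = ConvNet(in_channels=8, hidden_channels=64, out_channels=64).cuda().to(memory_format=torch.channels_last_3d)
x = torch.randn(4, 8, 64, 256, 256, device="cuda").to(memory_format=torch.channels_last_3d)

print(f"eager:    {time_model(model, x):.3f} ms/step")
print(f"compiled: {time_model(torch.compile(model), x):.3f} ms/step")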

The PyTorch compiler is introducing an unnecessary nhwcToNchwKernel immediately before an nchwToNhwcKernel between every pair of Conv3d layers.

This leads to a 1.68x slowdown on A100.
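
The redundant kernels are visible in a profiler trace. A short sketch of capturing one, reusing model and x from the timing sketch above (the trace filename is made up):

from torch.profiler import profile, ProfilerActivity

compiled = torch.compile(model)
compiled(x)  # warm up so compilation itself is not captured in the trace

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    compiled(x)
    torch.cuda.synchronize()

# Look for nhwcToNchwKernel / nchwToNhwcKernel pairs interleaved with the conv kernels.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("conv3d_trace.json")  # illustrative filename; view in chrome://tracing or Perfetto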

Minimal reproducible example (Google Colab): the input tensor and the network are in channels_last_3d format.

Profiles and code are here: Conv3d bug - Google Drive

CC @marksaroufim and @eqy - would you know if torch.compile tries to use the channels-last format by default for these layers and adds these transformations?

@eellison is the lead for Inductor and would likely know.

Yes, we do attempt to use channels last by default depending on the convs seen. You can disable with TORCHINDUCTOR_LAYOUT_OPTIMIZATION=0 or torch._inductor.config.layout_optimization = False.
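
For reference, the two ways of disabling it look like this (a minimal sketch; repro.py is a placeholder script name and model is the network from the report):

# Option 1: environment variable, set before launching the process:
#   TORCHINDUCTOR_LAYOUT_OPTIMIZATION=0 python repro.py

# Option 2: the Inductor config flag, set before calling torch.compile:
import torch
import torch._inductor.config as inductor_config

inductor_config.layout_optimization = False
compiled_model = torch.compile(model)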

Maybe we need to be less aggressive with conv3d. Would you mind filing an issue?

Separately, is there a reason you are using autocast instead of converting the parameters to bfloat16 before running the model?
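
To make the distinction concrete (an illustrative sketch using the model and x from above, not a recommendation): autocast keeps the parameters in float32 and casts on the fly inside the context, whereas converting the module stores the parameters in bfloat16 outright.

import torch

# Variant A: autocast; parameters stay float32, eligible ops run in bfloat16.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)

# Variant B: convert the parameters (and the input) to bfloat16 up front; no autocast context needed.
model_bf16 = model.to(torch.bfloat16)
out = model_bf16(x.to(torch.bfloat16))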

I believe I’ve been forwarded this issue previously and have passed it on to cuDNN. To comment: I’m not sure compile is really to blame, as we observed the format conversions even in eager mode.