On both PyTorch 2.3.1+cu121 (Colab A100) and 2.4.0+cu124 (H100), torch.compile inserts unnecessary memory-layout-conversion kernels around Conv3d layers. The Conv3d layers in question have a 3x3x3 kernel, 1x1x1 padding, no bias, no activation, and preserve the channel dimension.
import torch
import torch.nn as nn
from torch.profiler import record_function


class ConvNet(nn.Module):
    """Simple convolutional network."""

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.conv_1 = nn.Conv3d(in_channels, hidden_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        self.conv_2 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        self.conv_3 = nn.Conv3d(hidden_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)

    def forward(self, x) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input 3D tensor. Shape: [B, C, T, H, W].
        """
        assert x.ndim == 5
        with record_function("conv_layers"):
            x = self.conv_1(x)
            x = self.conv_2(x)
            x = self.conv_3(x)
        assert x.is_contiguous(memory_format=torch.channels_last_3d)
        return x
Without compilation, the network takes 6.482 ms / step after warmup on an input of shape (4, 8, 64, 256, 256), with 64 channels in the network.
When compiled with torch.compile, the same network takes 10.878 ms / step.
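The full benchmarking harness is in the linked Colab; as a rough sketch, the per-step numbers can be reproduced along these lines (warmup iterations followed by CUDA-event timing; the helper name time_step and the iteration counts are assumptions, not the actual script):

def time_step(fn, x, warmup: int = 10, iters: int = 50) -> float:
    """Average milliseconds per forward pass, measured with CUDA events."""
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters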
The PyTorch compiler inserts an unnecessary nhwcToNchwKernel immediately before an nchwToNhwcKernel between every pair of Conv3d layers.
This leads to a 1.68x slowdown on A100.
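The layout-conversion kernels are visible when capturing a trace around the compiled forward pass with torch.profiler; a minimal capture along these lines (a sketch, not the exact profiling script from the Drive folder) is enough to see the nhwcToNchwKernel / nchwToNhwcKernel pairs on the CUDA stream:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    compiled_net(x)
    torch.cuda.synchronize()

# Inspect the exported trace in chrome://tracing or Perfetto; the layout-change
# kernels appear between consecutive Conv3d launches.
prof.export_chrome_trace("conv3d_compiled_trace.json")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))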
Minimal reproducible example (Google Colab): the input tensor and the network are both in channels_last_3d format.
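A condensed version of that setup, assuming in_channels=8 and hidden/out channels of 64 to match the shapes quoted above (the authoritative version is the Colab notebook):

device = torch.device("cuda")

# Both the module and the input are converted to channels_last_3d,
# matching the assert in forward().
net = ConvNet(in_channels=8, hidden_channels=64, out_channels=64).to(
    device, memory_format=torch.channels_last_3d
)
x = torch.randn(4, 8, 64, 256, 256, device=device).to(
    memory_format=torch.channels_last_3d
)

eager_ms = time_step(net, x)
compiled_net = torch.compile(net)
compiled_ms = time_step(compiled_net, x)  # warmup inside time_step also triggers compilation
print(f"eager: {eager_ms:.3f} ms/step, compiled: {compiled_ms:.3f} ms/step")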
Profiles and code are here: Conv3d bug - Google Drive