Why does batch norm slow down when Conv2D is executed in unfolded fashion?

I created a Conv2d layer that uses unfolding followed by a matrix-vector multiplication (MVM). I combine it with a BatchNorm2d operation in a Sequential model, and I build the same model with a regular Conv2d layer. I then profile both and compare the outputs.
I see that the batch norm call aten::batch_norm takes about 3.5x longer with the unfolded convolution. Everything runs on CUDA.
Why am I seeing this slow-down in the batch norm? Is there some fusing going on internally? Here is a small snippet to reproduce:

import torch

class ConvUnfold(torch.nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        bias,
        device,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=bias,
            device=device,
        )
        # Flatten the conv weight into a (out_channels, in_channels * kH * kW)
        # matrix so the unfolded patches can be multiplied with it directly.
        self.linear_weight = self.weight.reshape(
            out_channels, in_channels * self.kernel_size[0] * self.kernel_size[1]
        ).to(device)

    def _mvm(self, input):
        # (B, L, C_in * kH * kW) @ (C_in * kH * kW, C_out) -> (B, L, C_out)
        return input @ self.linear_weight.T

    def _forward_unfold(self, x_input):
        im_shape = x_input.shape
        # Extract sliding patches: (B, C_in * kH * kW, L) -> (B, L, C_in * kH * kW)
        x_input_ = torch.nn.functional.unfold(
            x_input, kernel_size=self.kernel_size, dilation=self.dilation,
            padding=self.padding, stride=self.stride,
        ).transpose(1, 2)
        out = self._mvm(x_input_).transpose(1, 2)
        # Output height from the standard conv size formula; the width is
        # inferred by the -1 in the view below.
        out_size = (im_shape[2] + 2 * self.padding[0]
                    - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        return out.view(im_shape[0], self.out_channels, out_size, -1)

    def forward(self, input):
        return self._forward_unfold(input)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input = torch.randn(128, 3, 32, 32, device=device)
    unf = torch.nn.Sequential(ConvUnfold(3, 16, 3, bias=False, device=device), torch.nn.BatchNorm2d(16))
    conv = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3, bias=False), torch.nn.BatchNorm2d(16))
    unf.eval(); conv.eval()

    unf.to(device); conv.to(device)
    # Copy the weights so both models produce comparable outputs.
    conv.load_state_dict(unf.state_dict())

    from torch.profiler import profile, record_function, ProfilerActivity
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof_unf:
        with record_function("model_inference"):
            unf(input)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof_conv:
        with record_function("model_inference"):
            conv(input)

    print(prof_unf.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    print(prof_conv.key_averages().table(sort_by="cuda_time_total", row_limit=10))

Could you explain a bit more which tool is reporting these stats and whether it reports the actual kernel execution time?
Are you seeing a kernel slowdown for the batchnorm layer when using e.g. nsys nvprof python script.py args?
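
As a side note, the torch profiler can also record the input shapes each op sees, which might help to narrow this down (a minimal sketch; record_shapes and group_by_input_shape are the relevant options):

from torch.profiler import profile, record_function, ProfilerActivity

# record_shapes=True stores the input shapes of each op, so key_averages()
# can group the stats by input shape afterwards.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    with record_function("model_inference"):
        unf(input)

print(prof.key_averages(group_by_input_shape=True)
          .table(sort_by="cuda_time_total", row_limit=10))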

Hi.

When I run the script with the torch profiler and compare the respective execution times, I see a difference in the batch norm operation time.
I don’t know if this is an artifact of something else, but I would like to understand it.
If you execute the script, it should print both profiles, showing a slower operation time for batch_norm when the unfolded conv operation is used.

I haven’t tried nsys or nvprof yet; I thought the torch profiler would abstract that away.
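
One thing I can try on my side is a warmup pass before profiling, in case the gap is a one-time startup cost rather than the kernel itself (a minimal sketch of the change to the script above):

# Warm up both models so one-time costs (CUDA context setup, cudnn
# algorithm selection) don't end up in the profiled region.
for _ in range(10):
    unf(input)
    conv(input)
torch.cuda.synchronize()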

Thanks for the help.

I’m unsure what exactly should be abstracted away.

In any case, I’m seeing almost the same execution time with only a minor difference, as it seems you are changing the memory layout from the default NCHW to NHWC:

CUDA Kernel Statistics:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
      2.8           24,800          1  24,800.0  24,800.0    24,800    24,800          0.0  void cudnn::bn_fw_inf_1C11_kernel_NHWC<float, float, (bool)1, (bool)1>(T2, T2, cudnnTensorStruct, c…
      2.4           21,312          1  21,312.0  21,312.0    21,312    21,312          0.0  void cudnn::bn_fw_inf_1C11_kernel_NCHW<float, float, (bool)1, (int)1>(T2, T2, cudnnTensorStruct, co…

which could explain the difference in the runtime.
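
You can verify this directly by checking the layout of the unfolded conv’s output (a small sketch, assuming the ConvUnfold class from your script):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(128, 3, 32, 32, device=device)
unf = ConvUnfold(3, 16, 3, bias=False, device=device)
out = unf(x)

# The transpose + view in _forward_unfold leaves the output with
# channels-last (NHWC) strides instead of the default NCHW layout:
print(out.is_contiguous())                                   # should be False
print(out.is_contiguous(memory_format=torch.channels_last))  # should be True

If you need the default layout for the subsequent ops, returning out.view(...).contiguous() from _forward_unfold would restore NCHW at the cost of an extra copy.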

Ok, thanks. It must be that I was using the tool incorrectly.