Shape of data affects speed of batch normalization

While looking for the reason why the inference time of my network had suddenly increased by 25% compared to a previous version, I realized that one of the changes I had made was to rearrange the layers so that each PixelShuffle layer came before its corresponding BatchNorm2d layer instead of after it, as it was before. Simply by making sure the batch normalization is done before the pixel shuffling, I was able to reduce the inference time back to what it was before.
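To make the two arrangements concrete, here is a minimal sketch (the channel count of 64 and upscale factor of 2 are assumptions for illustration; my actual network differs):

```python
import torch
import torch.nn as nn

# Fast arrangement: BatchNorm2d sees the (N, 64, H, W) tensor,
# then PixelShuffle reshapes it to (N, 16, 2H, 2W).
fast_order = nn.Sequential(
    nn.BatchNorm2d(64),
    nn.PixelShuffle(upscale_factor=2),
)

# Slow arrangement: PixelShuffle runs first, so BatchNorm2d
# sees the (N, 16, 2H, 2W) tensor instead.
slow_order = nn.Sequential(
    nn.PixelShuffle(upscale_factor=2),
    nn.BatchNorm2d(16),  # 64 / 2**2 channels after the shuffle
)

x = torch.randn(8, 64, 32, 32)
print(fast_order(x).shape)  # torch.Size([8, 16, 64, 64])
print(slow_order(x).shape)  # torch.Size([8, 16, 64, 64])
```

Both orderings produce a tensor with exactly the same number of elements; only the split between channels and spatial size differs.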

Since I doubt that the speed of the pixel shuffling depends on whether the data has been batch normalized or not (please correct me if I’m wrong), I suppose that the speed of the batch normalization operation must depend on the shape of the data, which the pixel shuffling changes. Specifically, if the image data is reshaped to be taller and wider but have fewer feature maps, batch normalization of the data becomes significantly slower.
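A rough timing sketch of what I mean, using assumed shapes (the real network's sizes differ) and measuring only the BatchNorm2d call on its own:

```python
import torch
import torch.nn as nn

device = "cuda"
# Both tensors hold the same number of elements; only the channel/spatial split changes.
x_many_channels = torch.randn(8, 64, 32, 32, device=device)  # before pixel shuffle
x_few_channels = torch.randn(8, 16, 64, 64, device=device)   # after pixel shuffle (r=2)

bn_many = nn.BatchNorm2d(64).to(device)
bn_few = nn.BatchNorm2d(16).to(device)

def time_op(fn, x, iters=100):
    # Warm up, then time with CUDA events so the asynchronous kernels are measured properly.
    for _ in range(10):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

print("BatchNorm2d on (8, 64, 32, 32):", time_op(bn_many, x_many_channels), "ms")
print("BatchNorm2d on (8, 16, 64, 64):", time_op(bn_few, x_few_channels), "ms")
```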

Is this your experience too? Why is the speed of the batch normalization operation so dependent on the shape of the data, even though the total number of elements is the same? And can this knowledge be used to make the BatchNorm2d operation run even faster?

Yes, this might be expected: even though “equally many pixels” are processed in both approaches, the reductions and elementwise operations would not be the same.
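For illustration (reusing the shapes assumed above), BatchNorm2d computes its statistics per channel by reducing over the batch and spatial dimensions, so fewer channels with a larger spatial extent means fewer but much larger reductions:

```python
import torch

x_before = torch.randn(8, 64, 32, 32)  # 64 reductions over 8*32*32 = 8192 elements each
x_after = torch.randn(8, 16, 64, 64)   # 16 reductions over 8*64*64 = 32768 elements each

# These are the per-channel means batchnorm needs in training mode;
# the variances are computed over the same (N, H, W) dimensions.
mean_before = x_before.mean(dim=(0, 2, 3))  # shape (64,)
mean_after = x_after.mean(dim=(0, 2, 3))    # shape (16,)
print(mean_before.shape, mean_after.shape)
```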
I don’t know which backend you are using or whether you are on the GPU, but you might want to take a look at this nvFuser tutorial, which explains how PyTorch uses it to automatically generate fast kernels for these kinds of operations (normalizations are already working) to avoid these performance cliffs.
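As a rough sketch of the workflow the tutorial walks through (assuming a CUDA build of PyTorch where nvFuser is available as the "fuser2" TorchScript fuser; adapt the module and shapes to your model):

```python
import torch
import torch.nn as nn

# A small normalization block, scripted so the fuser can see the whole graph.
model = nn.Sequential(
    nn.BatchNorm2d(16),
    nn.ReLU(),
).cuda()
scripted = torch.jit.script(model)

x = torch.randn(8, 16, 64, 64, device="cuda")

with torch.jit.fuser("fuser2"):   # route TorchScript fusion through nvFuser
    for _ in range(5):            # a few warm-up calls let the fuser specialize the kernels
        scripted(x)
    out = scripted(x)
print(out.shape)
```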

What do you mean by which backend I’m using? Yes, I’m working with the data on the GPU. Thanks for the link; I will take a look!