Support for Batch Renormalization

For the past few days, I’ve been training a model that uses batch normalization. While batch normalization is crucial for speeding up training, performance drops severely once I switch from train to eval mode. While browsing the PyTorch forums, I noticed that many other people have the same problem. It seems to be caused by the fact that the running estimates of the mean and variance, which replace the batch statistics in eval mode, are unreliable when the batch size is small. For many problems (e.g. segmentation), however, increasing the batch size is not feasible due to memory constraints. One of the authors of the batch normalization paper acknowledged this issue and wrote a follow-up paper on batch renormalization, a similar technique that should also work with smaller batches.
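For reference, the training-time forward pass described in the batch renormalization paper can be summarized per feature as below. Here μ and σ are the running estimates, μ_B and σ_B are the batch statistics over m samples, α is the moving-average rate, and r_max, d_max are clipping bounds. With r = 1 and d = 0 it reduces to ordinary batch norm, and at inference the layer computes y = γ(x − μ)/σ + β, just as batch norm does.

$$
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B = \sqrt{\epsilon + \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2} \\
r &= \operatorname{stop\_gradient}\!\left(\operatorname{clip}_{[1/r_{\max},\, r_{\max}]}\!\left(\frac{\sigma_B}{\sigma}\right)\right) \\
d &= \operatorname{stop\_gradient}\!\left(\operatorname{clip}_{[-d_{\max},\, d_{\max}]}\!\left(\frac{\mu_B - \mu}{\sigma}\right)\right) \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sigma_B}\, r + d, \qquad y_i = \gamma \hat{x}_i + \beta \\
\mu &\leftarrow \mu + \alpha(\mu_B - \mu), \qquad \sigma \leftarrow \sigma + \alpha(\sigma_B - \sigma)
\end{aligned}
$$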

I was wondering whether there are any plans to implement batch renormalization in PyTorch. It seems like a feature that could be very helpful for all the users experiencing problems with regular batch normalization. I experimented with group, layer and instance normalization, and none of them performed as well as batch normalization, so batch renormalization may turn out to be very useful. I’m aware that it’s possible to implement it directly in Python, but other users have already attempted this and reported significantly higher memory usage and run time. So a native CPU/GPU implementation seems necessary.
Unfortunately, I’m afraid that I, like most users, lack the necessary skills to implement this feature myself for CPU/GPU in PyTorch.
Hence my question: are there any plans to implement this, or does anyone have tips for making my model perform well in both eval and train mode with small batch sizes?
Thanks in advance.

You could always try instance norm (unless you have very few features per channel), which removes the difference between training and evaluation, since it normalizes each sample with its own statistics in both modes.
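A quick way to check this, as a sketch assuming a recent PyTorch: `nn.InstanceNorm2d` defaults to `track_running_stats=False`, so its train and eval outputs match, whereas `nn.BatchNorm2d` switches to its running estimates in eval mode. The tensor shapes below are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)  # batch of 2, 3 channels, 8x8 feature maps

# InstanceNorm2d (default track_running_stats=False) always uses
# per-sample, per-channel statistics, in train and eval mode alike.
inorm = nn.InstanceNorm2d(3, affine=True)
inorm.train()
y_train = inorm(x)
inorm.eval()
y_eval = inorm(x)

# BatchNorm2d uses batch statistics in train mode but running
# estimates in eval mode, so the two outputs differ.
bn = nn.BatchNorm2d(3)
bn.train()
z_train = bn(x)
bn.eval()
z_eval = bn(x)

print(torch.allclose(y_train, y_eval))  # True
print(torch.allclose(z_train, z_eval))  # False
```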
If you want to experiment with batch renorm, starting from current master makes that easier. In its native CUDA kernels, we split gathering the statistics from applying the transformation. So you could reuse the former and just adapt the pointwise transformation to use the amended formula. You can use the vanilla backward function (that is what the “stop gradient” in the paper asks for: ignore the amendments in the backward pass).
I would recommend doing this as a custom op or extension module wrapped in a torch.autograd.Function.
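For comparison, here is a pure-PyTorch reference sketch. It is not the fused custom op recommended above, so expect the extra memory and time reported in the original post, but it could serve as a numerical reference when testing a native kernel. The class name `BatchRenorm1d` and its default hyperparameters are hypothetical choices. It follows the same split: gather batch statistics, compute the clipped correction factors `r` and `d` under `torch.no_grad()` (which plays the role of the stop gradient), then apply the amended pointwise transformation and let autograd handle the backward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchRenorm1d(nn.Module):
    """Hypothetical pure-PyTorch batch renormalization layer for (N, C) inputs."""

    def __init__(self, num_features, eps=1e-5, momentum=0.01, r_max=3.0, d_max=5.0):
        super().__init__()
        # Illustrative defaults; the paper relaxes r_max and d_max gradually
        # during training, which this sketch leaves to the caller.
        self.eps = eps
        self.momentum = momentum
        self.r_max = r_max
        self.d_max = d_max
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Running mean and running *standard deviation* (not variance).
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_std", torch.ones(num_features))

    def forward(self, x):
        if not self.training:
            # Inference: identical to batch norm with running estimates.
            x_hat = (x - self.running_mean) / self.running_std
            return self.weight * x_hat + self.bias

        # Step 1: gather batch statistics, exactly as batch norm does.
        batch_mean = x.mean(dim=0)
        batch_std = torch.sqrt(x.var(dim=0, unbiased=False) + self.eps)

        with torch.no_grad():
            # Step 2: correction factors computed without autograd,
            # i.e. the paper's "stop gradient" on r and d.
            r = (batch_std / self.running_std).clamp(1.0 / self.r_max, self.r_max)
            d = ((batch_mean - self.running_mean) / self.running_std).clamp(-self.d_max, self.d_max)
            # Update the running estimates (after r and d were computed).
            self.running_mean += self.momentum * (batch_mean - self.running_mean)
            self.running_std += self.momentum * (batch_std - self.running_std)

        # Step 3: the amended pointwise transformation.
        x_hat = (x - batch_mean) / batch_std * r + d
        return self.weight * x_hat + self.bias


torch.manual_seed(0)
x = torch.randn(8, 4) * 2 + 1

# With r_max = 1 and d_max = 0, r == 1 and d == 0, so training mode
# reduces to plain batch norm.
plain = BatchRenorm1d(4, r_max=1.0, d_max=0.0)
reference = F.batch_norm(x, None, None, training=True, eps=1e-5)
print(torch.allclose(plain(x), reference, atol=1e-6))  # True

# Gradients flow through the batch statistics, but not through r and d.
layer = BatchRenorm1d(4)
layer(x).sum().backward()
print(layer.weight.grad.shape)  # torch.Size([4])
```

Computing `r` and `d` inside `torch.no_grad()` keeps them out of the autograd graph, so the gradient only sees the batch-norm part of the transformation, scaled by the constant `r`.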

Best regards

Thomas

We would also love Batch Renorm support, as it seems to be very important for Continual Learning with small batch sizes!
