Batch normalization, batch size, and data loader's last batch


I spent a day debugging this and thought I’d share what I found about batch normalization seemingly causing overfitting. Here is my setup:

I have 24*7 = 168 models, one for each hour of the week, each trained on a few hundred thousand samples. Out of the 168 models, 166 trained with consistently high accuracy, one was mediocre, and one was overfitting badly (training loss at 1E-5, test loss 0.5). The architecture used batch normalization in the fully connected layers.

I tried

  • different learning rates,
  • a few different batch sizes (128, 1024, 2048, 4096, 8192, 16384),
  • shuffling the data

all in vain. When I removed a few samples from the data set, the overfitting disappeared, and it did not really matter which samples I removed. Then it turned out that my training data set has 196609 = 16384*12 + 1 samples. With PyTorch’s DataLoader (which by default keeps the last, incomplete batch) and any power-of-two batch size 2^n with n <= 16 (up to 65536, since 196608 = 3*2^16), the last batch contains exactly 1 element. Because of the way the running averages are computed, this left the running variance of BatchNorm1d basically unusable.

The training data set of the model with mediocre performance had 180227 = 16384*11 + 3 samples, so with the batch sizes above its last batch had only 3 elements.
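A quick way to catch this before training is to compute the size of the final batch directly. The helper name below is hypothetical; it assumes the loader keeps the incomplete last batch, which is what PyTorch's DataLoader does with its default drop_last=False.

```python
def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final batch when incomplete batches are kept
    (PyTorch's DataLoader behaviour with the default drop_last=False)."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


# Powers of two from 2^7 = 128 up to 2^17 = 131072
batch_sizes = [2 ** n for n in range(7, 18)]

for num_samples in (196609, 180227):
    sizes = {bs: last_batch_size(num_samples, bs) for bs in batch_sizes}
    print(num_samples, sizes)
```

For 196609 samples the last batch has a single element for every power of two up to 2^16; for 180227 samples it has three elements for every power of two up to 16384.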

The fix was to split the data into training and test sets carefully, so that the size of the training set is a multiple of the batch size and all training batches have the same size. But something more robust is needed to make BatchNorm less fragile here:

  • either ensure that all batches fed into BatchNorm have the same size, and issue an error or warning otherwise,
  • or compute the running averages taking the batch size into account.
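Both options can be sketched for a single feature in plain Python. The class below is hypothetical, not a PyTorch API: it follows the update rule in the PyTorch BatchNorm documentation, where the running variance is fed with the unbiased (n - 1) batch variance, but it warns on a mismatched batch size (option 1) and scales momentum by actual/target batch size (option 2).

```python
import statistics
import warnings
from typing import Sequence


class SizeAwareRunningStats:
    """Hypothetical running mean/variance tracker for one feature.

    Update rule as documented for PyTorch BatchNorm layers:
        running = (1 - momentum) * running + momentum * batch_stat
    but it
      * warns when a batch differs from target_batch_size (option 1), and
      * scales momentum by batch_size / target_batch_size (option 2).
    """

    def __init__(self, target_batch_size: int, momentum: float = 0.1):
        self.target_batch_size = target_batch_size
        self.momentum = momentum
        self.mean = 0.0  # same initial values as PyTorch's running_mean/running_var
        self.var = 1.0

    def update(self, batch: Sequence[float]) -> None:
        n = len(batch)
        if n != self.target_batch_size:
            warnings.warn(f"batch of size {n}, expected {self.target_batch_size}")
        # Smaller batches get proportionally less weight in the running averages.
        weight = self.momentum * min(n / self.target_batch_size, 1.0)
        self.mean = (1 - weight) * self.mean + weight * statistics.fmean(batch)
        if n > 1:
            # The unbiased variance divides by n - 1, so it is undefined for
            # a single sample; the running variance is left untouched then.
            self.var = (1 - weight) * self.var + weight * statistics.variance(batch)


stats = SizeAwareRunningStats(target_batch_size=4)
stats.update([1.0, 2.0, 3.0, 4.0])  # full batch: weight 0.1
stats.update([100.0])               # warns; mean uses weight 0.025, variance unchanged
```

Scaling the momentum linearly with the batch size is only one possible weighting; the warning alone already covers option 1.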

I’d be happy to propose a patch, but would like to hear opinions first; perhaps this has already been discussed.



Hi @dtolpin,

thank you for sharing this interesting problem and the detailed analysis.
To me, your first option (making all batches the same size) sounds like the more reasonable one in practice. Quite likely, you could just pick random samples to duplicate until the last batch is full and be done with it.
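A minimal sketch of the duplication idea, assuming the training set is addressed by a list of indices (the helper name pad_to_multiple is hypothetical). The padded list could then be passed to torch.utils.data.Subset; alternatively, DataLoader's drop_last=True simply discards the incomplete batch.

```python
import random
from typing import List, Optional


def pad_to_multiple(indices: List[int], batch_size: int,
                    rng: Optional[random.Random] = None) -> List[int]:
    """Return `indices` plus randomly chosen duplicates, so that the length is a
    multiple of `batch_size` and every batch (including the last) is full."""
    if rng is None:
        rng = random.Random()
    shortfall = -len(indices) % batch_size  # 0 if already a multiple
    return indices + rng.choices(indices, k=shortfall)


train_indices = list(range(196609))  # the problematic training set size
padded = pad_to_multiple(train_indices, batch_size=16384, rng=random.Random(0))
print(len(padded), len(padded) % 16384)  # 212992 0
```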

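I must admit that I was quite unsure how to interpret PyTorch’s momentum parameter. If it meant something like alpha in

running_mean_estimate = alpha * running_mean_estimate + (1-alpha) * minibatch_mean,

I would expect something more like 0.9 rather than PyTorch’s default of 0.1. However, the BatchNorm documentation defines momentum as the weight on the new minibatch statistic (running = (1 - momentum) * running + momentum * minibatch_statistic), so the default of 0.1 already corresponds to alpha = 0.9. Lowering momentum (i.e. raising alpha) might still help, in particular if your analysis for option 2 (using the minibatch size in the running average computation) is correct.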
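Assuming the update rule from the PyTorch BatchNorm documentation, running = (1 - momentum) * running + momentum * batch_stat, a small pure-Python sketch (function names made up for illustration) checks that momentum=0.1 and alpha=0.9 give identical updates, and shows how far one degenerate statistic moves the estimate for two momentum values.

```python
def update_alpha(running: float, batch_stat: float, alpha: float) -> float:
    """Convention from the reply: alpha weights the OLD estimate."""
    return alpha * running + (1 - alpha) * batch_stat


def update_momentum(running: float, batch_stat: float, momentum: float) -> float:
    """Rule from the PyTorch BatchNorm docs: momentum weights the NEW statistic."""
    return (1 - momentum) * running + momentum * batch_stat


# momentum=0.1 (PyTorch's default) and alpha=0.9 produce the same updates.
running = 1.0
for batch_stat in [1.2, 0.8, 1.1]:
    assert abs(update_alpha(running, batch_stat, 0.9)
               - update_momentum(running, batch_stat, 0.1)) < 1e-12
    running = update_momentum(running, batch_stat, 0.1)

# One degenerate statistic (say 50.0) pulls an estimate of 1.0 to about 5.9
# with momentum=0.1, but only to about 1.49 with momentum=0.01.
print(update_momentum(1.0, 50.0, 0.1), update_momentum(1.0, 50.0, 0.01))
```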

If you wanted to go down the route of option 2, there is another shortcoming of batch normalization as described in Algorithm 1 of Ioffe and Szegedy’s original article (and I would almost expect it to be the more significant one): during training, the mean and variance are taken from the current minibatch alone. For very small minibatches, I would expect that to be disadvantageous, and I would expect a regularization like

regularized_mean_estimate = (actual_batchsize * minibatch_mean + (target_batchsize - actual_batchsize) * running_mean_estimate) / target_batchsize

regularized_variance_estimate = ((actual_batchsize - 1) * minibatch_variance + (target_batchsize - actual_batchsize) * running_variance_estimate) / (target_batchsize - 1)

to work much better. (You could also average them in a fancier, Bayesian way, and find out why and how my weights above are rubbish, but this might be a starting point.)
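A plain-Python sketch of these regularized estimates for a single feature, reading the variance formula as blending the unbiased (n - 1) minibatch variance with a running variance estimate, which makes the weights in each formula sum to one; it needs target_batchsize >= 2. The function names are hypothetical, and the scale and shift parameters of Algorithm 1 are omitted. Note that with a single sample, plain Algorithm 1 maps every input to 0 (the sample equals its own mean), while the blended statistics keep some signal.

```python
import statistics
from typing import List, Sequence, Tuple


def regularized_stats(batch: Sequence[float], target_batch_size: int,
                      running_mean: float, running_var: float) -> Tuple[float, float]:
    """Blend minibatch statistics with running estimates, weighting the
    minibatch by how full it is (assumes len(batch) <= target_batch_size)."""
    n = len(batch)
    t = target_batch_size
    batch_mean = statistics.fmean(batch)
    # The unbiased variance needs n >= 2; its weight (n - 1) is zero when n == 1.
    batch_var = statistics.variance(batch) if n > 1 else 0.0
    mean = (n * batch_mean + (t - n) * running_mean) / t
    var = ((n - 1) * batch_var + (t - n) * running_var) / (t - 1)
    return mean, var


def normalize(batch: Sequence[float], target_batch_size: int,
              running_mean: float, running_var: float, eps: float = 1e-5) -> List[float]:
    """Algorithm-1-style normalization (without scale/shift) using blended statistics."""
    mean, var = regularized_stats(batch, target_batch_size, running_mean, running_var)
    return [(x - mean) / (var + eps) ** 0.5 for x in batch]


# A full batch uses its own statistics; a single sample leans on the running ones.
print(regularized_stats([1.0, 2.0, 3.0, 4.0], 4, running_mean=10.0, running_var=10.0))
print(regularized_stats([100.0], 4, running_mean=0.0, running_var=1.0))
```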

As I said above, in practice I would probably amend the data to fill up the last minibatch. On the other hand, it might be fun to see which works best: your suggested batch-size-aware running mean/variance updates, the blanket momentum adjustment, or regularizing the batch statistics used by batch normalization during training.

Best regards