Hi,
I spent a day debugging this and thought I’d share what I found about a model with batch normalization that seemed to overfit. Here is my setup:
I have 24*7 = 168 models, one for each hour of the week, each trained on a few hundred thousand samples. Of the 168 models, 166 trained consistently to high accuracy, one was mediocre, and one overfit badly (training loss 1E-5, test loss 0.5). The architecture uses batch normalization in its fully connected layers.
I tried
- different learning rates,
- a few different batch sizes (128, 1024, 2048, 4096, 8192, 16384)
- shuffling the data
all in vain. When I removed a few samples from the data set, the overfitting disappeared, and it did not really matter which samples I removed. It then turned out that my training data set has 196609 = 16384*12 + 1 samples. With PyTorch’s dataloader (http://pytorch.org/docs/_modules/torch/utils/data/dataloader.html) and any batch size of 2^n for n <= 15 (up to 32768), the last batch contains exactly 1 element. Because of the way the running averages are computed, this left BatchNorm1d with a running variance that was basically unusable.
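A quick way to check this arithmetic before training, in plain Python with no PyTorch needed. `last_batch_size` is a hypothetical helper, and it assumes the loader keeps the incomplete final batch rather than dropping it:

```python
# Plain-Python check of how a fixed batch size splits a data set.
# Assumes the loader keeps the final, incomplete batch (no dropping).


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Number of samples in the final batch of a sequential split."""
    remainder = num_samples % batch_size
    # If the data divides evenly, the final batch is a full one.
    return remainder if remainder else batch_size


num_samples = 196609  # 16384 * 12 + 1
for n in range(1, 16):  # batch sizes 2, 4, ..., 32768
    batch_size = 2 ** n
    print(f"batch_size={batch_size:>5}: last batch has "
          f"{last_batch_size(num_samples, batch_size)} sample(s)")
```

Every batch size in that loop leaves a single-sample final batch, because 196608 is divisible by all of them.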
In the training data set with mediocre performance there were 180227 = 16384*11 + 3 samples.
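Why a tiny last batch matters so much: the running statistics are updated once per batch with a fixed momentum, so a 3-sample batch gets the same update weight as a full batch of 16384 and, being the most recent update of every epoch, the largest weight in the final estimate. The sketch below is a simplified model, not the PyTorch source; it assumes the documented update rule `running = (1 - momentum) * running + momentum * batch_statistic` and the default momentum of 0.1:

```python
from typing import List

# Simplified model of the per-batch exponential moving average that BatchNorm
# uses for its running statistics:
#     running = (1 - momentum) * running + momentum * batch_statistic
MOMENTUM = 0.1  # PyTorch's documented default momentum for BatchNorm layers


def ema_weights(num_batches: int, momentum: float = MOMENTUM) -> List[float]:
    """Weight of each batch's statistic in the final running estimate after
    num_batches updates (weight left on the initial value is ignored)."""
    return [momentum * (1 - momentum) ** (num_batches - 1 - k)
            for k in range(num_batches)]


num_samples, batch_size = 180227, 16384
num_batches = -(-num_samples // batch_size)  # ceiling division
sizes = [batch_size] * (num_batches - 1)
sizes.append(num_samples - sum(sizes))  # the short final batch
weights = ema_weights(num_batches)

for label, idx in (("final batch", -1), ("previous batch", -2)):
    print(f"{label}: {sizes[idx]} samples "
          f"({sizes[idx] / num_samples:.2e} of the data), "
          f"EMA weight {weights[idx]:.3f}")
```

A 3-sample variance estimate is very noisy, yet it carries weight 0.1. With a single-sample batch it is worse still: the unbiased variance estimate used for the running variance divides by n - 1 = 0 and is undefined.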
The solution was to split the data carefully into training and test sets so that every batch in the training set has exactly the specified size. But BatchNorm should be more robust than this. Two options:
- either require that all batches fed into BatchNorm have the same size, and issue an error/warning otherwise,
- or compute the running averages taking the batch size into account.
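Both options can be sketched in plain Python. `batches` and `SizeAwareRunningVar` are hypothetical names, not PyTorch APIs, and the size-aware update (treating a batch of n samples as n / ref_batch_size nominal batches) is only one possible way to take the batch size into account, assumed here for illustration. For the first option, `torch.utils.data.DataLoader` accepts `drop_last=True`, which skips the incomplete final batch.

```python
import warnings
from typing import Iterator, Sequence


def batches(data: Sequence, batch_size: int, drop_last: bool = False) -> Iterator[Sequence]:
    """Yield consecutive batches of data.

    A short final batch is dropped when drop_last is True (the same idea as
    the drop_last argument of torch.utils.data.DataLoader); otherwise it is
    yielded with a warning, since it would distort BatchNorm statistics."""
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        if len(batch) < batch_size:  # only the final batch can be short
            if drop_last:
                return
            warnings.warn(f"final batch has {len(batch)} sample(s) instead of "
                          f"{batch_size}; BatchNorm running statistics may suffer")
        yield batch


class SizeAwareRunningVar:
    """Hypothetical running-variance tracker that weights each update by
    batch size: a batch of n samples counts as n / ref_batch_size nominal
    batches, so a tiny batch barely moves the estimate."""

    def __init__(self, ref_batch_size: int, momentum: float = 0.1, init: float = 1.0):
        self.ref_batch_size = ref_batch_size
        self.momentum = momentum
        self.value = init

    def update(self, batch: Sequence[float]) -> None:
        n = len(batch)
        if n < 2:
            return  # unbiased variance is undefined for one sample: skip
        mean = sum(batch) / n
        unbiased_var = sum((x - mean) ** 2 for x in batch) / (n - 1)
        # Same as n / ref_batch_size plain EMA updates with this batch's statistic.
        eff = 1 - (1 - self.momentum) ** (n / self.ref_batch_size)
        self.value = (1 - eff) * self.value + eff * unbiased_var
```

For example, `[len(b) for b in batches(list(range(10)), 4, drop_last=True)]` gives `[4, 4]`. With `SizeAwareRunningVar(ref_batch_size=16384)`, a full batch gets the usual momentum of 0.1, while a 3-sample batch moves the estimate only by a tiny fraction. A real patch would of course operate on per-channel tensors rather than Python floats.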
I’d be happy to propose a patch, but would like to hear opinions first. Perhaps this has already been covered.
David