I want to share my experience with this problem. I’m doing video deep learning (gesture detection), which is particularly demanding in terms of memory, at least on my private dataset.
I first encountered this problem in July 2018, in the thread “Conflict between model.eval() and .train() with multiprocess training and evaluation”.
I usually have a batch size below 10, and that’s across multiple GPUs, so more like 3-5 examples per GPU.
I kept BN batch statistics enabled in validation, by leaving model.train() set, to get a realistic view of my validation metrics. But of course this leads to very poor results when I actually want to use my model on real-world data.
So what exactly is happening? Your model has been fitted to understand small batches: it has “overfitted the batch”. This is very counter-intuitive, because we’re used to thinking that the samples in a batch are processed independently of one another, or that batch normalization is just some kind of helper or augmentation. But this view is wrong, especially at small batch sizes. If you look at the values that a normalized sample takes inside a small batch, you can see that they vary a lot: the standard deviation of the values your sample takes across random batches of your dataset is high. With a larger batch size, say 32, the distribution of values that the same sample takes after a BN layer, across random batches, would be much narrower. I think this is a good way to picture what I mean when I say the model overfits the batch: it is trained on a much wider distribution of values, and it does not handle especially well the narrower subset of values it sees when model.eval() is set.
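To make the “wider distribution” argument concrete, here is a minimal sketch using only the Python standard library. It is an illustration under stated assumptions, not a measurement from my setup: the “dataset” is 10,000 scalar activations drawn from a unit Gaussian, and the function name and the fixed sample value 0.5 are made up for the example. It normalizes the same sample inside many random batches and reports how much its normalized value moves around.

```python
import random
import statistics


def normalized_value_spread(dataset, sample, batch_size, n_batches=2000, seed=0):
    """Std of the batch-normalized value of `sample` across random batches."""
    rng = random.Random(seed)
    eps = 1e-5
    normalized = []
    for _ in range(n_batches):
        # The fixed sample plus (batch_size - 1) randomly drawn batch mates.
        batch = [sample] + rng.sample(dataset, batch_size - 1)
        mean = statistics.fmean(batch)
        # Population (biased) variance of the batch, as used for normalization.
        var = statistics.pvariance(batch, mu=mean)
        normalized.append((sample - mean) / (var + eps) ** 0.5)
    return statistics.pstdev(normalized)


# Synthetic stand-in for one activation channel (assumption: unit Gaussian).
data_rng = random.Random(42)
dataset = [data_rng.gauss(0.0, 1.0) for _ in range(10_000)]
sample = 0.5

spread_small = normalized_value_spread(dataset, sample, batch_size=4)
spread_large = normalized_value_spread(dataset, sample, batch_size=32)
print(f"batch size 4:  std of normalized value = {spread_small:.3f}")
print(f"batch size 32: std of normalized value = {spread_large:.3f}")
```

The spread at batch size 4 should come out clearly larger than at batch size 32, which is the effect described above: the layers after the BN are trained on inputs that jitter far more than anything they will see with fixed running statistics.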
I’ve tested a few solutions, some of them outlined in previous answers.
Perform a few forward passes after training with big batch sizes, without gradient descent, and with model.train() set. This WON’T work. Naturally, if you do that, your BN layers will change while your other layers stay frozen. But the problem isn’t that the batch norm statistics (mean and std) are wrong. It is that the actual mean and std of the dataset are bad for this network: your convolutions have been fitted in a way that needs to see extreme values. Tuning your BNs this way produces effects akin to making your network see purely grey images, and it will be hard for it to pick out anything salient at any step of the forward pass.
Increase the momentum of the BN, meaning that the means and stds “learned” are much more stable during training. With the same reasoning, you can understand why this won’t work. Training still sees the same widened distribution. But if you set up your BNs so that, during training, they have a better chance of capturing the real means and stds of the dataset, they will capture values that are not suited to your convolutions!
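One thing to watch if you try this in PyTorch: “more momentum” in the sense used here (more stable running statistics) corresponds to a smaller value of the momentum argument of the BatchNorm layers, because PyTorch updates running statistics as new = (1 - momentum) * old + momentum * batch_stat. A small stdlib-only sketch of that update rule follows; the batch means are made-up numbers and the helper name is hypothetical.

```python
def update_running_mean(running_mean, batch_mean, momentum):
    # PyTorch BatchNorm convention: momentum is the weight of the NEW batch statistic.
    return (1.0 - momentum) * running_mean + momentum * batch_mean


# Made-up, noisy per-batch means, as you would get with very small batches.
batch_means = [0.9, -0.7, 1.1, -1.0, 0.8, -0.9]

swings = {}
for momentum in (0.1, 0.01):
    running = 0.0
    trace = []
    for batch_mean in batch_means:
        running = update_running_mean(running, batch_mean, momentum)
        trace.append(running)
    # How far the running mean moved over these updates.
    swings[momentum] = max(trace) - min(trace)
    print(f"momentum={momentum}: swing of running mean = {swings[momentum]:.4f}")
```

The smaller value gives the steadier running mean. As explained above, that only changes which statistics end up stored for model.eval(); it does nothing about what the other layers were fitted to during training.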
Skipping batches so that you artificially have a higher batch size. That is especially wrong. In this manner, exactly as before, you will have a higher chance of capturing the real means and stds of the dataset. But what is especially false is the assumption that the forward and backward passes then happen any differently. If you have a higher BS, you can hope to reproduce the results of this article https://arxiv.org/pdf/1711.00489.pdf but that’s a matter of convergence, not of BN. With this tactic, your model will have a different approach to overfitting the batch (with less stochastic variability), but that will still be its goal.
Increasing your memory by adding GPUs to your training setup. This is wrong and wasteful. The standard way of doing BN in most frameworks is “GPU-specific”: the BN batch mean and std are computed from the examples sitting on an individual GPU. That means that with two GPUs, at forward time, you actually have two batch means and two stds at each BN layer. Consequently, thinking that 4 V100s will solve the problem you had with one is a really bad strategy for you and your wallet. What would work is a GPU so large that it can fit 32 samples in its own memory. For my problem, even a 32GB Tesla V100 didn’t cut it (it brought me to 20 samples per GPU, which is not bad, but I could still observe the bad effects of BN). Nevertheless, depending on your situation, you should try this. Unfortunately, the only 32GB V100s I could use are the ones on the p3dn.24xlarge on AWS, which has 8 of them and is especially costly (and you won’t be able to keep it more than 1 hour in spot mode).
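A tiny stdlib-only illustration of the “two batch means” point, with made-up activation values for one channel: a global batch of 8 split evenly over 2 GPUs is normalized with two different sets of statistics, each computed from only 4 samples.

```python
import statistics

# Made-up activations of one channel for a global batch of 8 samples.
global_batch = [0.2, 1.5, -0.3, 0.9, -1.2, 0.1, 2.0, -0.6]
n_gpus = 2
per_gpu = len(global_batch) // n_gpus

# Each GPU only sees its own shard of the batch.
shards = [global_batch[i * per_gpu:(i + 1) * per_gpu] for i in range(n_gpus)]

# Per-GPU statistics: what plain (unsynchronized) BN normalizes with.
shard_stats = [(statistics.fmean(s), statistics.pstdev(s)) for s in shards]
# Statistics over the whole batch: what a single large GPU would use.
global_stats = (statistics.fmean(global_batch), statistics.pstdev(global_batch))

for gpu_id, (mean, std) in enumerate(shard_stats):
    print(f"gpu {gpu_id}: mean={mean:.3f} std={std:.3f}")
print(f"global: mean={global_stats[0]:.3f} std={global_stats[1]:.3f}")
```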
Using group normalization instead of batch norms. As advertised in this article https://arxiv.org/pdf/1803.08494.pdf this WORKS! BUT it has some disadvantages, not the least of which is that NO ONE WHO RELEASES PRETRAINED MODELS USES IT! It sounds exaggerated, but this is really a hurdle for me, because as you might know, doing video deep learning without transferring knowledge is like trying to win the Olympic 100m when you’re obese. I’m pretty much forced to perform all pretrainings myself, which is really tedious when you’re benchmarking many video architectures. So, if you don’t use pretrained models, you might not care, and you should definitely do it. Yet in my experience, group norm also slows down training and demands more memory, and is thus a bit irritating. And don’t even think you can take a pretrained model using BNs and replace them with GN. It doesn’t work, and you might just as well throw away everything learned after the first BN.
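For a model trained from scratch, here is a minimal PyTorch sketch of why group norm sidesteps the issue: its statistics are computed per sample, so the output for a clip does not depend on which other clips share the batch. The toy block, the tensor shapes and num_groups=4 are arbitrary choices for illustration (num_groups must divide the channel count), not a recommended architecture.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy block for video input, with GroupNorm in place of BatchNorm3d.
block = nn.Sequential(
    nn.Conv3d(3, 16, kernel_size=3, padding=1),
    nn.GroupNorm(num_groups=4, num_channels=16),
    nn.ReLU(),
)
block.train()  # GroupNorm behaves the same in train and eval mode

clip = torch.randn(1, 3, 4, 16, 16)    # one clip: (N, C, T, H, W)
others = torch.randn(3, 3, 4, 16, 16)  # three unrelated batch mates

with torch.no_grad():
    alone = block(clip)
    in_batch = block(torch.cat([clip, others]))[:1]

# True if the clip's output is the same with or without batch mates.
same = torch.allclose(alone, in_batch, atol=1e-4)
print("output independent of batch mates:", same)
```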
THE SOLUTION I FOUND: BN synchronization works! It means sharing stats between GPUs at forward time, so that only one mean and std are computed per BN layer for the whole multi-GPU setup. With this layer, your GPUs will act as one. This still requires a cumulative GPU memory high enough to hold 32 samples per batch, so it’s still a bit tedious. Still, it’s cool to finally have a solution. Plus you can easily transfer from a BN model with this: it won’t break, and it’s a proper transfer if you think about it!
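This post does not depend on one particular implementation of synchronized BN. As one option, PyTorch ships torch.nn.SyncBatchNorm, whose convert_sync_batchnorm helper swaps the BatchNorm layers of an existing (possibly pretrained) model while keeping their weights and running statistics. A minimal sketch on a hypothetical toy model:

```python
import torch.nn as nn

# Hypothetical toy model standing in for a pretrained BN network.
model = nn.Sequential(
    nn.Conv3d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm3d(8),
    nn.ReLU(),
)

# Replace every BatchNorm*d layer with SyncBatchNorm, copying the affine
# parameters and running statistics of each replaced layer.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

n_sync = sum(isinstance(m, nn.SyncBatchNorm) for m in sync_model.modules())
print("SyncBatchNorm layers:", n_sync)
```

The conversion itself runs anywhere, but the synchronized forward pass in training only works once torch.distributed is initialized and the model is wrapped in DistributedDataParallel with one process per GPU.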
By the way, if you use a pretrained model and encounter this BN problem, note that you’re also using your pretrained model very poorly, because it has (probably) been trained to look at very different distributions of means and stds itself. In my experience it’s still better than nothing, but it’s pretty suboptimal.
Another word on why this might not be too problematic for you. If you can use the exact same batch size in validation as in training, make sure that your batches are exactly as random as in training, and set model.train(), you will pretty much have the best validation metrics you can get. This is especially true if your test set is not “reality” but another dataset sitting on your hard drive. It’s not a satisfying solution for me, because I want to use my model efficiently: setting the batch size to the highest value I can, optimizing my model with layer fusion, and (without going too much into detail) feeding very similar data in each batch at inference. But if you don’t have these constraints, you’ll be alright.
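A minimal PyTorch sketch of that validation recipe, under stated assumptions: the function name, the toy classifier and the random data are all made up for illustration. One detail I added on top of the recipe: forward passes in train mode update the BN running statistics even under torch.no_grad(), so the sketch snapshots the state dict and restores it afterwards.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def validate_in_train_mode(model, dataset, train_batch_size):
    """Accuracy computed with batch statistics, as during training."""
    # Same batch size and shuffling as in training; drop the last, smaller batch.
    loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, drop_last=True)
    # Train-mode forward passes update BN running stats, so keep a copy to restore.
    snapshot = copy.deepcopy(model.state_dict())
    model.train()
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.numel()
    model.load_state_dict(snapshot)  # undo the running-stat updates
    return correct / total


torch.manual_seed(0)
# Hypothetical toy classifier and random data, only to make the sketch runnable.
model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 3))
data = TensorDataset(torch.randn(40, 10), torch.randint(0, 3, (40,)))

running_mean_before = model[1].running_mean.clone()
accuracy = validate_in_train_mode(model, data, train_batch_size=4)
print(f"validation accuracy with batch statistics: {accuracy:.3f}")
```

Note that model.train() also keeps dropout active if your model has any, and that the model is left in train mode when the function returns.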