Use global statistics of BatchNorm in training mode

I am working on an edge detection model, and my batch_size=1 because of limited GPU memory. I am using a pretrained ResNet as my backbone. During training I want the batch norm layers to use their global (running) statistics rather than the batch statistics, since statistics computed from a batch of size 1 are not very meaningful.

One option is to call model.eval() during training so that the network uses the global statistics of the batch norm layers. Given that my network doesn't have any dropout layers, is that the right thing to do?

What if my network does have dropout layers and I still want the batch norm layers to use their global statistics during the training phase?
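To make "global statistics" concrete: in train mode a BatchNorm layer normalizes with the mean and variance of the current batch (and updates its running estimates), while in eval mode it uses the stored running_mean and running_var buffers. A minimal sketch with a standalone nn.BatchNorm2d; the channel count, input shape, and the "pretrained" running statistics are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
with torch.no_grad():
    # Illustrative "pretrained" running statistics (arbitrary values)
    bn.running_mean.copy_(torch.tensor([0.5, -1.0, 2.0]))
    bn.running_var.copy_(torch.tensor([4.0, 0.25, 1.0]))

x = torch.randn(1, 3, 4, 4)  # batch_size=1, as in the question

# Per-channel affine parameters reshaped to (1, C, 1, 1) for broadcasting
w = bn.weight.detach().view(1, -1, 1, 1)
b = bn.bias.detach().view(1, -1, 1, 1)

# Eval mode: normalize with the stored running (global) statistics
bn.eval()
out_eval = bn(x)
run_mean = bn.running_mean.clone().view(1, -1, 1, 1)
run_var = bn.running_var.clone().view(1, -1, 1, 1)
manual_eval = (x - run_mean) / torch.sqrt(run_var + bn.eps) * w + b
print(torch.allclose(out_eval, manual_eval, atol=1e-5))  # True

# Train mode: normalize with this batch's own mean and (biased) variance;
# the forward pass also updates running_mean/running_var via momentum
bn.train()
out_train = bn(x)
batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
batch_var = ((x - batch_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
manual_train = (x - batch_mean) / torch.sqrt(batch_var + bn.eps) * w + b
print(torch.allclose(out_train, manual_train, atol=1e-5))  # True
```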

Setting model.eval() should work, as long as you don't have any other modules that change their behavior between train and eval.
However, a cleaner approach would be to just set the BatchNorm layers to eval.
If you have a model definition, something like this should work:

model = MyModel()
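The snippet above breaks off after creating the model. A minimal sketch of the approach the answer describes, under these assumptions: `MyModel` here is a hypothetical stand-in (a tiny conv/BN/dropout net rather than a ResNet), and the helper name `set_bn_eval` is made up for illustration. Only the BatchNorm layers are switched to eval, so they use their running statistics while Dropout keeps its training behavior, which also covers the dropout case from the question.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


class MyModel(nn.Module):
    # Hypothetical stand-in for the real backbone (e.g. a pretrained ResNet)
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.drop = nn.Dropout(p=0.5)
        self.head = nn.Conv2d(8, 1, kernel_size=1)

    def forward(self, x):
        return self.head(self.drop(torch.relu(self.bn(self.conv(x)))))


def set_bn_eval(model):
    # Put every BatchNorm layer into eval mode so it normalizes with
    # running_mean/running_var and stops updating them.
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()


model = MyModel()
model.train()       # Dropout (and everything else) in training mode
set_bn_eval(model)  # ...except BatchNorm, which now uses global statistics
print(model.bn.training, model.drop.training)  # False True

before = model.bn.running_mean.clone()
out = model(torch.randn(1, 3, 32, 32))  # batch_size=1
print(torch.equal(before, model.bn.running_mean))  # True: stats untouched
```

Note that `model.train()` resets every submodule to train mode, so the helper has to be called again after each `model.train()` call (for example at the start of every epoch).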

Thanks @ptrblck for your reply.

I don't have any other modules that change their behavior between train and eval, so I set model.eval() in the training phase. Now I am seeing NaN values in the loss. This does not happen when I use model.train() in my code.

I also tried setting only the individual batch norm layers to eval() mode (following your code snippet). Again I am observing NaN values.

Can you suggest what might be going wrong?

That sounds strange. Can you check your input for NaN values with print((x!=x).any())?
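Building on that check: `torch.isfinite(x).all()` also catches infinities, and forward hooks can report the first layer whose output stops being finite. A hedged sketch; the helper name `add_nan_hooks` and the toy model are hypothetical, and the NaN is injected artificially into a BatchNorm weight just to show the hook firing. For the backward pass, `torch.autograd.set_detect_anomaly(True)` can additionally point at the operation that first produces NaN gradients, at a noticeable speed cost.

```python
import torch
import torch.nn as nn


def add_nan_hooks(model, report):
    # Register a forward hook on every leaf module; the name of the first
    # module whose output contains NaN/Inf is appended to `report`.
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers, check leaf layers only

        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                if not report:  # later layers inherit the NaN, keep the first
                    report.append(name)

        handles.append(module.register_forward_hook(hook))
    return handles


model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
)
model.eval()
with torch.no_grad():
    model[1].weight[0] = float("nan")  # artificially corrupt one BN weight

report = []
handles = add_nan_hooks(model, report)
out = model(torch.randn(1, 3, 8, 8))
print(report)  # ['1'] -> the BatchNorm layer
for h in handles:
    h.remove()  # detach the hooks once debugging is done
```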

I ran the code with model.eval(). After the last batch norm layer all the values become NaN. However, when I set model.train() the code works fine.

I also noticed that the weight and bias parameters of the batch norm layers are becoming NaN. This seems like quite weird behaviour.
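One thing worth ruling out: putting a BatchNorm layer in eval mode freezes its running statistics, but the optimizer still updates its affine weight and bias unless they are frozen explicitly. A hedged sketch (the helper names `freeze_bn` and `non_finite_params`, the toy model, the loss, the learning rate, and the clipping threshold are all assumptions for illustration, not a confirmed fix for the NaNs in this thread) that freezes both the statistics and the affine parameters, similar in spirit to torchvision's FrozenBatchNorm2d, and lists any parameter or gradient that has become non-finite after an optimizer step.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_bn(model):
    # eval(): use (and stop updating) running_mean / running_var
    # requires_grad_(False): keep the optimizer from changing weight / bias
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()
            if module.affine:
                module.weight.requires_grad_(False)
                module.bias.requires_grad_(False)


def non_finite_params(model):
    # Names of parameters (or their gradients) that contain NaN/Inf
    bad = []
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            bad.append(name)
        elif p.grad is not None and not torch.isfinite(p.grad).all():
            bad.append(name + ".grad")
    return bad


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)
model.train()
freeze_bn(model)

# Only pass trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

x = torch.randn(1, 3, 16, 16)      # batch_size=1
target = torch.rand(1, 1, 16, 16)  # hypothetical edge-map target
loss = nn.functional.binary_cross_entropy_with_logits(model(x), target)
optimizer.zero_grad()
loss.backward()
# Optional: clip gradients before stepping (max_norm is an arbitrary choice)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

print(non_finite_params(model))  # [] when nothing went non-finite
```

Running `non_finite_params` after every step in the real training loop would show whether the NaNs first appear in a gradient or in a parameter, and in which layer.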

People in the following threads are trying to do exactly the same thing. I followed all the solutions mentioned there, but still no luck.

Hey, I know it's been a while, but did you have any luck setting the batch norm layers to eval mode? I have tried the suggested solutions, but it still goes to NaN.