Model breaks in evaluation mode

Hi, I am following this procedure:

  1. Train a network with train.py, and save the model with torch.save(model.state_dict(), 'model.pth')
  2. Next I load the model in classify.py with model.load_state_dict(torch.load(opt.model)) and call model.eval().

However, I now notice the model gives entirely different results depending on whether I call .eval() or not. The model includes a couple of BatchNorm2d and Dropout layers and performs fine without calling .eval(), but I was hoping for a slight improvement in performance with .eval() activated. What are things to look out for when encountering this problem?

BatchNorm and Dropout layers have different behavior in training and evaluation modes. Try to narrow down your problem by setting all the layers to “eval” mode except your BatchNorm layers. Then try the same thing with the Dropout layers.
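
A minimal sketch of what I mean (assuming model is your loaded network):

import torch.nn as nn

model.eval()  # put everything in eval mode
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.train()  # switch only the BatchNorm layers back to train mode
# repeat with nn.Dropout instead to isolate the Dropout layers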

BatchNorm uses a running average of mean and variance during training to compute the statistics used during eval mode. Check the running_mean and running_var properties on your BatchNorm layers.
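
For example, something like this should show whether the running estimates look reasonable (a quick sketch, assuming model is the loaded network):

import torch
import torch.nn as nn

for name, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        # running_mean / running_var are the statistics used in eval mode
        print(name,
              'mean:', m.running_mean.mean().item(),
              'var:', m.running_var.mean().item(),
              'NaNs:', torch.isnan(m.running_mean).any().item())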

One thing I’ve seen mess up BatchNorm is sending in uninitialized inputs in train mode. For example:

output_size = network(torch.FloatTensor(16, 3, 224, 224).cuda()).size() # BAD!

The uninitialized input can mess up the running averages, especially if it happens to contain NaNs.
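
If you only need the output shape, a safer pattern (just a sketch, reusing the network from the example above) is to probe it in eval mode with a zero-filled tensor, so the running averages are not updated:

import torch

with torch.no_grad():
    network.eval()
    output_size = network(torch.zeros(16, 3, 224, 224).cuda()).size()
    network.train()  # switch back before continuing training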

I have narrowed down the problem to the BatchNorm layers. Additionally, I plotted the running_mean and running_var properties and they seem fine (no NaNs, for example).

My guess is that it is related to the way I am sampling while training. I am using triplet loss, which means the same model is used three times on three batches, each containing samples of a single label.

anchor, positive, negative = model(input1), model(input2), model(input3)
loss = F.triplet_margin_loss(anchor, positive, negative) 

Taking MNIST as an example, this means the mini-batch input1 at each iteration contains randomly picked samples from a single label (for example, only 9's). input2 contains different samples from the same label as input1. Finally, input3 contains samples from another label, for example only 4's.

Will mini-batches like these cause problems with BatchNorm properties? Should I switch to another sampling strategy?

Figured it out, it did have to do with the sampling strategy.

Since BatchNorm uses mini-batch statistics as a proxy for the population statistics, the triplets should be sampled randomly across many classes instead of from single-class batches.
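
Roughly, the sampling now looks something like this (a simplified sketch; the labels_to_indices mapping, the batch size and the dataset indexing are just placeholders):

import random
import torch

def sample_triplets(dataset, labels_to_indices, batch_size):
    # draw each triplet from randomly chosen classes, so every mini-batch
    # mixes many labels and BatchNorm sees statistics close to the population
    anchors, positives, negatives = [], [], []
    classes = list(labels_to_indices.keys())
    for _ in range(batch_size):
        pos_cls, neg_cls = random.sample(classes, 2)
        a, p = random.sample(labels_to_indices[pos_cls], 2)
        n = random.choice(labels_to_indices[neg_cls])
        anchors.append(dataset[a][0])    # dataset[i] is assumed to return (image, label)
        positives.append(dataset[p][0])
        negatives.append(dataset[n][0])
    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)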

I seem to have the same issue in my binary classification. When evaluating in .train() mode my model gets a high accuracy on my test data, whereas when I set the model to .eval() mode to switch off dropout etc., it always outputs class 0.

Switching everything but the BatchNorm layers to eval() mode gives me good predictions. Unlike @bartholssthoorn, my mini-batches already contain both classes (i.e. they are shuffled), so that would not solve it.

My running_means do not contain any NaNs.

What else can I check to figure out why the BatchNorm layers break my model in evaluation mode?

How large is your batch size? A small batch size might yield noisy estimates of the running stats.

16, the same as I used for training. Using a smaller batch size (4 or 1) didn't help. I'm using 3D data so I cannot really increase it above 16.

You could try to tune the momentum hyperparameter of BatchNorm, or alternatively use GroupNorm, which should work better with small batch sizes.
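
For example (a sketch with made-up channel counts):

import torch.nn as nn

# smaller momentum -> the running estimates average over more batches
bn = nn.BatchNorm2d(64, momentum=0.01)  # default momentum is 0.1

# GroupNorm normalizes per sample, so it is independent of the batch size
gn = nn.GroupNorm(num_groups=8, num_channels=64)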

I finally figured out that the issue lay with the preprocessing of my data: it wasn't exactly the same at test time as during training. A different model I trained without BatchNorm didn't work on the test data at all either, which confirmed the mismatch. I find it quite interesting that the model with BatchNorm is actually able to perform quite well in train() mode, presumably because the per-batch statistics partially compensate for the different preprocessing, while the stored running statistics no longer match in eval() mode, causing the model to fail there.

Thanks for the pointer to GroupNorm, I wasn't aware of it but will surely check it out!