Performance highly degraded when eval() is activated in the test phase

Unfortunately, no (although it depends on your dataset / architecture). If you call model.eval(), even with the same batch size, you might run into the problem that the running stats learned by the batchnorm layers don’t actually represent the batch stats your model saw during training.
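
A quick way to check whether this is what is happening (a minimal sketch; `model` and `batch` are placeholders for your own network and a representative input batch) is to compare the outputs of the same batch in both modes:

    import torch

    # compare the same forward pass in train vs. eval mode; a large gap usually
    # points at batchnorm running stats that do not match the batch statistics
    model.eval()
    with torch.no_grad():
        out_eval = model(batch)

    model.train()  # note: this forward pass will slightly update the running stats
    with torch.no_grad():
        out_train = model(batch)

    print('max abs difference:', (out_train - out_eval).abs().max().item())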

I have a similar problem. The evaluation loss with track_running_stats = True is enormous. The only solution is to set track_running_stats = False, but unfortunately that means the model cannot be evaluated with batch_size = 1. Does the model calculate running_mean and running_var in model.eval()? I thought that with track_running_stats = False there is no need for them to be computed. Could you please take a look at my post: Batch norm training mode error despite model.eval()
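
For reference (this check is not from the original post), a batchnorm layer created with track_running_stats=False keeps no running buffers at all, so nothing is computed for them in model.eval():

    import torch.nn as nn

    bn = nn.BatchNorm2d(3, track_running_stats=False)
    print(bn.running_mean, bn.running_var)  # None None -- no running stats are kept

    bn_default = nn.BatchNorm2d(3)
    print(bn_default.running_mean, bn_default.running_var)  # zero-mean / unit-var buffers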

I had similar problems with my network, which was trained using 8 GPUs and a batch size of 64. When I wanted to perform inference on a single image, the predictions were quite unstable. I experimented with track_running_stats, but it didn’t seem to help. I found the following solution from https://github.com/yjxiong/tsn-pytorch.

    import torch

    net = CreateNetwork()
    net = torch.nn.DataParallel(net)  # for multi-GPU
    state = torch.load(model_load_path, map_location=device)
    net.module.load_state_dict(state.get('net'), strict=False)
    net.to(device)
    net.eval()

    count = 0
    for m in net.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            count += 1
            if count >= 2:  # skip the first BatchNorm layer in my ResNet50-based encoder
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

Now it all works!

Are you seeing any difference by using the for loop vs. just calling net.eval()?
If I understand the loop correctly, it skips the very first batchnorm layer and sets all others to eval mode (and also disables gradient calculation, which shouldn’t be needed during evaluation/testing anyway).
However, since net.eval() is called before, all layers should already be in eval mode, so this loop shouldn’t change anything. (I might be missing something obvious.)
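
As a side note on the gradient point, a minimal evaluation loop (with placeholder names for the model and loader) usually just combines eval mode with no_grad:

    import torch

    model.eval()                      # batchnorm/dropout switch to eval behavior
    with torch.no_grad():             # gradients are not needed for evaluation/testing
        for data, target in val_loader:
            output = model(data)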

I am having a similar problem with my network currently. My network is 3D and large, so batch sizes are small (between 1 and 4 depending on the experiment I am doing). I tried implementing

    net.eval()
    for m in net.modules():
        if isinstance(m, torch.nn.BatchNorm3d):
            m.track_running_stats = False

and

    net.eval()
    count = 0
    for m in net.modules():
        if isinstance(m, torch.nn.BatchNorm3d):
            count += 1
            if count >= 2:
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

as outlined by other people's responses, but neither gave me different results compared to not using them. I have gone through my code to look for typos, but I can't find anything (yet). I did train with track_running_stats = True, so would training again with it disabled help?

I also did try

net.eval()
for child in net.children():
    for ii in range(len(child)):
        if type(child[ii])==nn.BatchNorm3d:
            child[ii].track_running_stats = False

but it gave me errors about child not having a length (not every child module is a Sequential that supports len() and indexing).

Right now when I run the model in train mode during evaluation the results are significantly better than when using eval, so I am not quite sure what I am doing wrong.

I am at a complete loss with this discrepancy. From my understanding, the difference between net.train() and net.eval() is the behavior of track_running_stats, correct? (I am not using dropout, so that doesn’t apply here.) Isn’t manually setting the nn.BatchNorm3d layers to track_running_stats = False the same as net.eval(), or is that what I should have done during training?

As far as I can tell, when I run the evaluation of the model in net.train() the performance is significantly better than with net.eval(), and the only difference is that toggle (going between train and eval modes). Therefore it is clear the issue is the nn.BatchNorm3d layers, but I don’t know how to fix it. Any help would be greatly appreciated.

@kleingeo yes, the solutions suggested in this topic haven't worked for some time. In another post I have found the precise commit that broke them, see here

Basically, changing track_running_stats after the batch norm layer has been created will not have any effect (or at least not the intended one).

But you can do it when you create the layer:
BatchNorm(..., track_running_stats=False). This makes batch norm normalize with the current batch statistics even in eval mode.
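
A minimal sketch of that suggestion (the module and layer names here are illustrative, not from the original post):

    import torch.nn as nn

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            # set track_running_stats=False at construction time: the layer then always
            # normalizes with the current batch statistics, in train() and eval() alike
            self.bn = nn.BatchNorm2d(16, track_running_stats=False)
            self.act = nn.ReLU()

        def forward(self, x):
            return self.act(self.bn(self.conv(x)))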

Try it and let me know :slight_smile:


I trained my models a few times with track_running_stats = False and it does seem to help, at least with the performance discrepancy between running the validation in eval and train mode.

I got into the good habit of always testing on the training set as well. The accuracy should be approximately equal to the training accuracy, and any discrepancy would be due to BatchNorm or (much less likely) dropout.
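
A minimal sketch of that sanity check (evaluate, train_loader and val_loader are placeholder names, not from the original post):

    model.eval()
    train_acc = evaluate(model, train_loader)  # accuracy on the data the model was trained on
    val_acc = evaluate(model, val_loader)
    # if train_acc in eval mode is far below the accuracy seen during training,
    # the gap is most likely coming from the batchnorm running stats
    print(f'train (eval mode): {train_acc:.3f}, val: {val_acc:.3f}')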

That’s a good idea. I was trying that towards the end of this arch, which led me to figure out there was an issue with the BatchNorm; I just wasn't sure (until recently) how to fix the problem.

The following worked for me:

import torch.nn as nn

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False

I also found that shuffling the data in inference mode improved my performance a lot; this might especially help if you have small batch sizes.
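
If you run inference with batch statistics like this, a shuffled loader is a one-line change; a sketch assuming a generic dataset object:

    from torch.utils.data import DataLoader

    # shuffling mixes samples across batches, so the per-batch statistics used for
    # normalization are closer to the overall data statistics
    eval_loader = DataLoader(dataset, batch_size=32, shuffle=True)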

Hi! This fixed the bug for me thank you so much! :slight_smile:

I tried the above suggestions but found they did not work in my case.

I found the following fix worked for me: I simply added a custom train() method to my base network module which, at eval time, turns ON training for BatchNorm2d instances and sets their momentum to 0.0. It is confusing to me why this would cause any difference in behavior for batch norm, since from the batch norm paper and the PyTorch docs I would expect that turning on training and setting momentum to 0.0 results in no change to the internal estimates of the running mean and variance. Nevertheless, after applying this fix I get an order of magnitude decrease in loss at eval time. The fix is a bit wonky because it saves and restores the old momentum value, so be forewarned that it will not work if you are dynamically changing the momentum elsewhere in your code.

My setup is CUDA 11.2, PyTorch 1.8.1, Windows. Hope that helps someone.

    def train(self, mode):
        """
        Warning: weird workaround for an issue where the BatchNorm means/stds are
        accurate in the train phase but off in the eval phase.
        See: https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/61
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                # keep batchnorm layers in training mode, but freeze the running
                # stats at eval time by setting momentum to 0.0
                module.training = True
                if module.momentum != 0.0:
                    module.orig_momentum = module.momentum
                if not mode:
                    module.momentum = 0.0
                else:
                    module.momentum = module.orig_momentum
            else:
                module.training = mode
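
For what it's worth, a quick check (not from the original post) confirms the expectation that momentum = 0.0 freezes the running estimates, since the update is running = (1 - momentum) * running + momentum * batch_stat:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(3, momentum=0.0)
    bn.train()
    before = bn.running_mean.clone()
    _ = bn(torch.randn(8, 3, 16, 16) * 10 + 5)       # forward pass with non-standard stats
    print(torch.allclose(before, bn.running_mean))   # True: the running stats are unchanged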

This should be correct, but batchnorm layers in training mode normalize the input activations using the current batch stats, not the running estimates. This means that your results depend on the batch size, and performance could degrade e.g. if you lower the batch size; it could even raise an error in case the stats cannot be calculated from a single sample.

@ptrblck : Thanks for your comment. But are you sure your statement is correct that "batchnorm layers in training mode are normalizing the input activations using the current input activation batch stats not the running estimates"? That appears to me to be at odds with the PyTorch BatchNorm2d documentation: "Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1."

It is also a good point that batch size changes could create problems in general for these estimates, but my code uses the same batch size throughout (batch size 1), so it remains a mystery to me why my above snippet would make any difference. I'll try to update this thread if I ever track down what the bug is.

Yes, I’m sure: in training mode the input activation stats are used to normalize the input, while the running stats are also updated using these calculated input stats and the momentum. I don’t think these statements contradict each other.

Here is an example which shows that the input activation is normalized using its own stats instead of the running stats:

import torch
import torch.nn as nn

# create input with a defined mean and std
mean = 5.
std = 10.
x = torch.randn(10, 3, 224, 224) * std + mean

print('mean {}, std {}'.format(x.mean([0, 2, 3]), x.std([0, 2, 3])))
# > mean tensor([5.0125, 5.0295, 4.9645]), std tensor([ 9.9943, 10.0157,  9.9935])

# apply bn in training mode
bn = nn.BatchNorm2d(3)

print('running_mean {}, running_var {}'.format(bn.running_mean, bn.running_var))
# > running_mean tensor([0., 0., 0.]), running_var tensor([1., 1., 1.])

bn.train()

# normalize input activation using input stats and update running stats
output = bn(x)
print('mean {}, std {}'.format(output.mean([0, 2, 3]), output.std([0, 2, 3])))
# > mean tensor([-3.2676e-08, -5.8388e-09,  8.8647e-09], grad_fn=<MeanBackward1>), std tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward>)

print('running_mean {}, running_var {}'.format(bn.running_mean, bn.running_var))
# > running_mean tensor([0.5013, 0.5029, 0.4964]), running_var tensor([10.8887, 10.9315, 10.8870])

If the running stats were used during training then the output tensor would not have been normalized, as the initial running stats contain a zero mean and a unit variance.
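
A possible continuation of the same example (not in the original post) makes the contrast explicit: after switching to eval mode the layer normalizes with the running estimates instead, so the output for the same input is no longer zero-mean / unit-std:

    # switch to eval mode: the (partially updated) running stats are used for normalization
    bn.eval()
    output = bn(x)
    print('mean {}, std {}'.format(output.mean([0, 2, 3]), output.std([0, 2, 3])))
    # the means are clearly non-zero and the stds well above 1, because the running
    # estimates were only updated once with momentum 0.1 and still differ from the input stats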

@Valerio_Biscione Thank you!! This fixed my low accuracy in eval mode when using a smaller batch size.
In my case I didn’t have direct access to the model class, so I couldn’t initialize the batch norm layers with track_running_stats=False. As you rightly mentioned, the latest commit checks the batch norm running stats to decide whether to use training- or eval-mode behavior, so I set the running mean and var variables in the batch norm layers to None and it worked out perfectly.

import torch.nn as nn

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.track_running_stats = False
        m.running_mean = None
        m.running_var = None

model.eval()

I have encountered the same problem.
Simply put, the model seems to train well and the loss is as expected during training.
At test time, after setting model.eval(), the results seem bad and the loss is high.
Using model.train(), or setting m.track_running_stats = False, really improves the results; however, if I evaluate the model with batch_size=1, the results are bad again.
Then I checked my code and found that the batch norm layers' affine is set to False, which I think probably causes the problem. I am now retraining my model with affine=True and will report the result.
BTW, if you encounter the same problem and your batch norm's affine is False, I think that may be the reason.

I have the same problem.
I tried to overfit by training on only one picture and validating on it at the same time. The training output is very good, but the validation output is bad. Then I set:
cudnn.deterministic = True
It worked. I speculate that the cudnn benchmark algorithm causes the result of each forward pass to be different.
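
For reference, a minimal sketch of that setting (the benchmark flag is an extra precaution, not mentioned in the post above):

    import torch

    # ask cuDNN to pick deterministic algorithms; optionally also disable the
    # benchmark autotuner, which may otherwise select different kernels between runs
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False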

Thanks for your answer! This works for me! It has to be False when creating it, i.e. BatchNorm(…, track_running_stats=False).

This resolves my issue. I just applied it at test time and it performs well. Thank you.