Model.eval() gives incorrect loss for model with batchnorm layers

Nice, setting the momentum to 0.5 seems to make the loss computed in model.eval() mode similar to the loss computed in model.train() mode, but only after a few epochs of differing results.
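For reference, PyTorch's BatchNorm updates its running estimates as running = (1 - momentum) * running + momentum * batch_stat, so a momentum of 0.5 makes the running statistics follow recent batches much faster than the default 0.1. A minimal sketch checking that rule on a standalone layer (shapes, seed and data are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.5
bn = nn.BatchNorm2d(3, momentum=momentum)  # running_mean starts at 0, running_var at 1
bn.train()

expected_mean = torch.zeros(3)
expected_var = torch.ones(3)
for _ in range(5):
    x = torch.randn(8, 3, 4, 4) * 2 + 1
    with torch.no_grad():
        bn(x)  # a train-mode forward pass updates the running statistics
    batch_mean = x.mean(dim=(0, 2, 3))
    batch_var = x.var(dim=(0, 2, 3))  # unbiased by default, which is what running_var tracks
    expected_mean = (1 - momentum) * expected_mean + momentum * batch_mean
    expected_var = (1 - momentum) * expected_var + momentum * batch_var

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-4))    # True
```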

Does the second suggestion only apply when we load the model for evaluation? It should not affect the loss calculation, correct or incorrect, when we are training the network and running validation every nth step.
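For the validate-every-nth-step case, the usual pattern is to switch to eval mode only for the validation pass and back to train mode right after. Eval-mode forward passes do not update the running statistics, so validation by itself does not change training. A minimal sketch with a hypothetical toy model and random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model; any module containing BatchNorm behaves the same way.
model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

def validate(model, inputs, targets):
    model.eval()  # BatchNorm now normalizes with its running statistics
    with torch.no_grad():
        loss = criterion(model(inputs), targets).item()
    model.train()  # switch back before the next training step
    return loss

val_x, val_y = torch.randn(32, 10), torch.randn(32, 1)
for step in range(1, 101):
    x, y = torch.randn(16, 10), torch.randn(16, 1)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if step % 20 == 0:  # "every nth step", here n = 20
        before = model[1].running_mean.clone()
        val_loss = validate(model, val_x, val_y)
        # the eval-mode validation pass leaves the running statistics untouched
        assert torch.equal(before, model[1].running_mean)
        print(f"step {step}: train {loss.item():.4f}, val {val_loss:.4f}")
```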

I also ran into this problem in my project (see my answer at … and …). In short, downgrading PyTorch to 0.1.12 resolves the problem. But I really don't know what changed in the BN implementation between 0.1.12 and later versions.

I replied on the issue, but running statistics are inherently unstable when the batch size is only 1.
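To see why, compare how much the per-channel batch mean fluctuates for batch size 1 versus 32 when every sample comes from the same distribution. This sketch assumes i.i.d. Gaussian activations; real feature maps are spatially correlated, which typically makes single-image statistics noisier still:

```python
import torch

torch.manual_seed(0)

def spread_of_batch_means(batch_size, trials=200):
    # Standard deviation, across trials, of the per-channel mean of an (N, C, H, W) batch,
    # averaged over the 64 channels. All data is i.i.d. N(0, 1) (an assumption).
    means = torch.stack([
        torch.randn(batch_size, 64, 8, 8).mean(dim=(0, 2, 3)) for _ in range(trials)
    ])
    return means.std(dim=0).mean().item()

spread_1 = spread_of_batch_means(1)
spread_32 = spread_of_batch_means(32)
print(f"batch size 1: {spread_1:.4f}, batch size 32: {spread_32:.4f}")
```

The spread shrinks roughly with the square root of the number of values per channel, so the statistics that feed the running averages are far noisier with tiny batches.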

Thanks for the reply! The training batch size is 6, not 1. I have also tried a larger batch size (32) with other architectures (upsampling on ResNet18), but the bug remains. My main question is why PyTorch 0.1.12 works while >= 0.2 does not.


I don't think this is about the momentum. I have the same problem: when I call model(input) repeatedly after model.train(), every call gives almost the same output, but it differs from the output I get after model.eval().
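The two modes normalize with different statistics: in train mode BatchNorm uses the current batch's mean and (biased) variance, so repeated calls on the same input give the same output; in eval mode it uses the stored running estimates. A sketch on a standalone BatchNorm2d with arbitrary random input, checking both modes against the formula by hand:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 5, 5) * 3 + 2  # statistics far from the initial running stats (mean 0, var 1)

with torch.no_grad():
    bn.train()
    out_train = bn(x)  # normalizes with this batch's mean/variance (and updates the running stats)
    bn.eval()
    out_eval = bn(x)   # normalizes with running_mean / running_var instead

    w = bn.weight.view(1, -1, 1, 1)
    b = bn.bias.view(1, -1, 1, 1)
    batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
    batch_var = ((x - batch_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
    manual_train = (x - batch_mean) / torch.sqrt(batch_var + bn.eps) * w + b

    rm = bn.running_mean.view(1, -1, 1, 1)
    rv = bn.running_var.view(1, -1, 1, 1)
    manual_eval = (x - rm) / torch.sqrt(rv + bn.eps) * w + b

print(torch.allclose(out_train, manual_train, atol=1e-4))  # True
print(torch.allclose(out_eval, manual_eval, atol=1e-4))    # True
print((out_train - out_eval).abs().max().item())           # large: the two modes disagree
```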


Is there any solution to this? The solution provided in "Performance highly degraded when eval() is activated in the test phase" only partially solves the problem for me (accuracy improves, but not all the way).

This problem is related to one I posted months ago, "Conflict between model.eval() and .train() with multiprocess training and evaluation". No new solution is outlined here.

Same problem here: PyTorch 0.4.1 gives much worse results even when the evaluation is done on the training data.


Of course it is different: you update the running statistics every time you run a forward pass in training mode, which changes the eval behavior.
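This is easy to check: each train-mode forward pass (even under torch.no_grad()) moves running_mean and running_var, so the same eval-mode forward gives a different result before and after. A small sketch with arbitrary random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 5, 5) * 3 + 2

bn.eval()
with torch.no_grad():
    eval_before = bn(x)

bn.train()
with torch.no_grad():  # no_grad does NOT stop the running-statistics update
    for _ in range(10):
        bn(x)

bn.eval()
with torch.no_grad():
    eval_after = bn(x)

print(bn.num_batches_tracked.item())               # 10 train-mode passes were recorded
print((eval_before - eval_after).abs().max().item())
```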

I believe this is likely a PyTorch bug rather than a model instability issue. I encounter this issue randomly when training/testing my models, and the only fixes have been either removing all BatchNorm layers from my model or downgrading to PyTorch 0.1.12. A reproducible example can be found in this repository for a recent CVPR18 paper, and discussion surrounding the issue can be found here. This code is consistently stable only with PyTorch 0.1.12.


Hi, I have the same problem. Have you solved it yet?

If the mean and variance of the training data are non-stationary, as mentioned, which may arise from a small batch size, you could try nn.BatchNorm2d(out_channels, track_running_stats=False). I believe this disables the running statistics and normalizes with the current batch's mean and variance instead. It worked for me.
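With track_running_stats=False the layer keeps no running estimates at all, and it uses batch statistics in eval mode too, so train and eval outputs coincide. A minimal sketch on random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None: no running statistics are kept

x = torch.randn(8, 4, 5, 5) * 3 + 2
with torch.no_grad():
    bn.train()
    out_train = bn(x)
    bn.eval()
    out_eval = bn(x)  # still normalizes with the current batch's statistics

print(torch.allclose(out_train, out_eval))  # True
```

The trade-off is that each sample's output now depends on the rest of its batch, including at test time.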

Thanks a lot! Your solution may be worth a try. But I think BatchNorm may not be the best choice for some applications.

I like this approach; it works for me. Thank you very much.

I also run into the instability issue randomly.
Sometimes model.eval() works well, and sometimes it does not.
This should be a bug, since it does not occur in TensorFlow or Keras (all under the same setting: same network, same batch size, etc.).
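One setting that is easy to get wrong when comparing against Keras: Keras' BatchNormalization uses momentum as the weight on the old running value (default 0.99, epsilon 1e-3), while PyTorch uses it as the weight on the new batch statistic (default 0.1, eps 1e-5). A sketch of a PyTorch layer configured to mimic the Keras defaults; the helper name is hypothetical:

```python
import torch
import torch.nn as nn

def keras_like_batchnorm2d(num_features, keras_momentum=0.99, keras_epsilon=1e-3):
    # Hypothetical helper. Keras: running = m * running + (1 - m) * batch
    # PyTorch: running = (1 - m) * running + m * batch, so m_torch = 1 - m_keras.
    return nn.BatchNorm2d(num_features, eps=keras_epsilon, momentum=1.0 - keras_momentum)

torch.manual_seed(0)
bn = keras_like_batchnorm2d(3)
x = torch.randn(8, 3, 4, 4) + 5.0
bn.train()
with torch.no_grad():
    bn(x)

# running_mean started at 0, so after one step it is 0.01 * batch_mean,
# the same as Keras' 0.99 * 0 + 0.01 * batch_mean
expected = 0.01 * x.mean(dim=(0, 2, 3))
print(torch.allclose(bn.running_mean, expected, atol=1e-5))  # True
```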

Could you post a reproducible code snippet so that we could have a look?
Also, have a look at the reproducibility docs in case you haven’t seen them.
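The usual recipe from those notes is to seed every random number generator and pin cuDNN to deterministic kernels before building the model. A minimal sketch (the `seed_everything` and `run_once` names are hypothetical; add `numpy.random.seed` if your pipeline uses NumPy):

```python
import random
import torch
import torch.nn as nn

def seed_everything(seed=0):
    # Hypothetical helper following the common reproducibility recipe.
    random.seed(seed)
    torch.manual_seed(seed)                    # seeds the CPU and all CUDA generators
    torch.backends.cudnn.deterministic = True  # pick deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False     # disable autotuning, which can vary per run

def run_once(seed):
    seed_everything(seed)
    model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    x = torch.randn(4, 1, 16, 16)
    model.train()
    with torch.no_grad():
        model(x)  # one train-mode pass updates the BN running statistics
    model.eval()
    with torch.no_grad():
        return model(x)

out_a = run_once(0)
out_b = run_once(0)
print(torch.allclose(out_a, out_b))  # True: same seed, same eval-mode output
```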

Replying to bookmark this thread.

I had the same issue: the loss I was getting was NaN. Changing momentum=0 to a value with 0 < momentum < 1 got rid of the problem. Thanks!

This code snippet reproduces the issue for me on PyTorch 1.1.0:

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNet(nn.Module):
    def __init__(self, image_size_total):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.max_pool1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear((image_size_total//4) * 64, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.max_pool1(x)

        x = x.view(-1, self.num_flat_features(x))
        x = self.fc1(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:] # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

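width = 64
height = 64

# use the second GPU if present, otherwise the CPU
device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cpu")
# the model must live on the same device as its inputs
network = SimpleNet(width * height).to(device)

batchone = torch.ones([4, 1, 64, 64], dtype=torch.float, device=device)
# one target row per sample in the batch (same loss as broadcasting, without the size warning)
outputone = torch.tensor([.5, .5], device=device).expand(4, 2)
batchtwo = torch.randn([4, 1, 64, 64], dtype=torch.float, device=device)
outputtwo = torch.tensor([.01, 1.0], device=device).expand(4, 2)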

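def train_net(net, batch, output):
    net.train()
    optimizer = optim.SGD(net.parameters(), 0.0001)
    criterion = nn.MSELoss()

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    netoutput = net(batch)
    loss = criterion(netoutput, output)
    loss.backward()
    optimizer.step()
    return float(loss)

def evaluate_batch(net, batch, output, shouldeval):
    # eval mode normalizes with the running statistics;
    # train mode uses (and updates) the current batch statistics
    if shouldeval:
        net.eval()
    else:
        net.train()
    criterion = nn.MSELoss()

    # forward pass only, no gradients needed
    with torch.no_grad():
        netoutput = net(batch)
        loss = criterion(netoutput, output)
    return float(loss)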

for i in range(100):
    print("t loss1:", train_net(network, batchone, outputone))
    print("t loss2:", train_net(network, batchtwo, outputone))

print("v loss1:", evaluate_batch(network, batchone, outputone, True))
print("v loss2:", evaluate_batch(network, batchtwo, outputone, True))
print("train v loss1:", evaluate_batch(network, batchone, outputone, False))
print("train v loss2:", evaluate_batch(network, batchtwo, outputone, False))

If I remove the BatchNorm line, the discrepancy disappears.
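If removing BatchNorm is not an option, one common workaround (not proposed in this thread) is to recompute the running statistics after training as a plain average over training batches, using momentum=None, so eval mode uses statistics computed with the final weights. A minimal sketch; `recalibrate_bn` and the toy model are hypothetical, and recent PyTorch versions ship a similar utility, torch.optim.swa_utils.update_bn:

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def recalibrate_bn(model, batches):
    # Hypothetical helper: recompute BN running stats as an equal-weight average over `batches`.
    bn_layers = [m for m in model.modules() if isinstance(m, BN_TYPES)]
    saved = [m.momentum for m in bn_layers]
    for m in bn_layers:
        m.reset_running_stats()  # running_mean = 0, running_var = 1, num_batches_tracked = 0
        m.momentum = None        # None => cumulative average instead of an exponential one
    was_training = model.training
    model.train()                # BN only updates its statistics in train mode
    with torch.no_grad():
        for x in batches:
            model(x)
    for m, momentum in zip(bn_layers, saved):
        m.momentum = momentum
    model.train(was_training)    # restore the caller's mode
    return model

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8))
batches = [torch.randn(4, 1, 16, 16) * 2 + 1 for _ in range(10)]
model.eval()
recalibrate_bn(model, batches)
```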

Somewhere I saw the suggestion of setting momentum = 0 instead of calling eval on the BatchNorm2d module.
Using this code removes the discrepancy:

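def evaluate_batch(net, batch, output, shouldeval):
    net.train()  # stay in train mode: BN normalizes with the batch statistics
    if shouldeval:
        net.bn1.momentum = 0.0  # freeze the running statistics for this pass
    with torch.no_grad():
        loss = nn.MSELoss()(net(batch), output)
    # before returning, restore the default momentum
    net.bn1.momentum = 0.1
    return float(loss)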
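Note what this workaround does: the network stays in train mode, so BatchNorm still normalizes each batch with that batch's own mean and variance (which is why the loss matches the training loss), and with momentum = 0 the running statistics are simply left unchanged. It is therefore not an evaluation with running statistics. A small sketch on a standalone layer with random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 5, 5) * 3 + 2

bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    bn.momentum = 0.0
    out_frozen = bn(x)                  # batch statistics, running stats untouched
    after_frozen = bn.running_mean.clone()
    bn.momentum = 0.1
    out_default = bn(x)                 # same normalization, but running stats now move
    bn.eval()
    out_eval = bn(x)                    # running statistics

print(torch.equal(before, after_frozen))          # True
print(torch.allclose(out_frozen, out_default))    # True
print(torch.allclose(out_frozen, out_eval, atol=1e-2))  # False
```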

How large is the difference for the runs?