Model.eval() gives incorrect loss for model with batchnorm layers

Is there any solution to this? The solution provided in “Performance highly degraded when eval() is activated in the test phase” only partially solves the problem for me (some accuracy is recovered, but not all of it).

This problem is related to one I posted months ago, “Conflict between model.eval() and .train() with multiprocess training and evaluation”. No new solution is outlined there.

Same problem here: pytorch 0.4.1 gives much worse results even when the evaluation is done on the training data.

Of course it is different… you are updating the running statistics every time you do a forward pass in training mode, hence changing the eval behavior.
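
To make this concrete, here is a minimal sketch (my own illustration, not from the original post) showing that a forward pass in train() mode updates the running statistics, while a forward pass in eval() mode leaves them untouched:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 16, 16)

print(bn.running_mean)        # starts at zeros
bn.train()
bn(x)
print(bn.running_mean)        # updated with this batch's statistics

bn.eval()
before = bn.running_mean.clone()
bn(x)
print(torch.equal(before, bn.running_mean))  # True: no update in eval mode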

I believe this is likely a pytorch bug rather than a model instability issue. I encounter this issue randomly when training/testing my models and the solution to the instability issue has been either removing all BatchNorm layers from my model or downgrading to pytorch 0.1.12. A reproducible example can be found in this repository for a recent CVPR18 paper. Discussion surrounding the issue can be found here. This code is consistently stable only with pytorch 0.1.12.

Hi, I have the same problem. Have you solved it yet?

If the mean and variance of the training data are, as mentioned, non-stationary (which may arise from a small batch size), you could try nn.BatchNorm2d(out_channels, track_running_stats=False). This disables the running statistics of the batches and, I believe, uses the current batch’s mean and variance to do the normalization. It worked for me :slight_smile:
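
For reference, a minimal sketch of that suggestion (the channel count is just an example): with track_running_stats=False the layer keeps no running estimates at all and normalizes with the current batch statistics in both train() and eval() mode.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64, track_running_stats=False)
print(bn.running_mean)   # None: no running estimates are kept

bn.eval()
x = torch.randn(4, 64, 8, 8)
out = bn(x)              # still normalized with this batch's mean/variance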

Thanks a lot! Your solution may be worth a try. But I think BatchNorm may not be the best choice for some applications.

I like this approach; it works for me. Thank you very much.

I also hit the instability issue randomly.
Sometimes using model.eval() works well, and sometimes it does not.
This should be a bug, since it does not occur in tensorflow or Keras (all under the same settings: same network, same batch size, etc…).

Could you post a reproducible code snippet so that we could have a look?
Also, have a look at the reproducibility docs in case you haven’t seen them.

Replying to bookmark this thread.

I had the same issue: the loss I was obtaining was NaN. Changing momentum=0 to 0 < momentum < 1 allowed me to get rid of this problem. Thanks!
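
In case the fix is unclear, a tiny sketch (my own, with an arbitrary channel count) of what changing the momentum means for the layer constructor:

import torch.nn as nn

# momentum=0 freezes the running estimates at their initialization
# (mean 0, variance 1), since new batch statistics get zero weight;
# any value in (0, 1) lets them track the data, which is what the fix relies on.
bn_frozen = nn.BatchNorm2d(64, momentum=0)
bn_tracking = nn.BatchNorm2d(64, momentum=0.1)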

This code snippet reproduces the issue for me on pytorch 1.1.0

import torch
import torch.nn as nn
import torch.optim as optim        
                                                                       
class SimpleNet(nn.Module):
    def __init__(self, image_size_total):
        super(SimpleNet, self).__init__()
                           
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.max_pool1 = nn.MaxPool2d(2)
                                                               
        self.fc1 = nn.Linear((image_size_total//4) * 64, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.max_pool1(x)

        x = x.view(-1, self.num_flat_features(x))
        x = self.fc1(x)
        return x


    def num_flat_features(self, x):
        size = x.size()[1:] # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

width = 64
height = 64
network = SimpleNet(width* height)

batchone = torch.ones([4, 1, 64, 64], dtype=torch.float, device=torch.device("cuda:1"))
outputone = torch.tensor([.5, .5]).to(torch.device("cuda:1"))
batchtwo = torch.randn([4, 1, 64, 64], dtype=torch.float, device=torch.device("cuda:1"))
outputtwo = torch.tensor([.01, 1.0]).to(torch.device("cuda:1"))

def train_net(net, batch, output):
        net.train()

        optimizer = optim.SGD(net.parameters(), 0.0001)
        criterion = nn.MSELoss()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        netoutput = net(batch)
        loss = criterion(netoutput, output)
        loss.backward()
        optimizer.step()
        return float(loss)

def evaluate_batch(net, batch, output, shouldeval):
        if shouldeval:
            net.eval()
        else:
            net.train()

        criterion = nn.MSELoss()

        # forward pass only (no backward/optimizer step during evaluation)
        netoutput = net(batch)
        loss = criterion(netoutput, output)
        return float(loss)

network.to(torch.device("cuda:1"))

for i in range(100):
    print("t loss1:", train_net(network, batchone, outputone))
    print("t loss2:", train_net(network, batchtwo, outputone))

print("v loss1:", evaluate_batch(network, batchone, outputone, True))
print("v loss2:", evaluate_batch(network, batchtwo, outputone, True))
print("train v loss1:", evaluate_batch(network, batchone, outputone, False))
print("train vv loss2:", evaluate_batch(network, batchtwo, outputone, False))

If I remove the batchnorm layer, the discrepancy disappears.

Somewhere I saw the suggestion of using momentum = 0 instead of setting eval() on the BatchNorm2d module.
Using the code below removes the discrepancy:

def evaluate_batch(net, batch, output, shouldeval):
        if shouldeval:
            net.eval()
            # keep bn1 in training mode so the current batch statistics are used,
            # but set momentum to 0 so the running estimates are not updated
            net.bn1.train()
            net.bn1.momentum = 0.0
        else:
            net.train()
...
        # before returning, restore the default momentum
        net.bn1.momentum = 0.1

How large is the difference for the runs?

It’s pretty huge: with net.train() the loss is about 7e-8, with net.eval() it is about 99.3.

Here is the truncated output. (By the way, I fixed a bug in the original repro code where the input batch size did not match the output batch size; batchone is now torch.ones([2, 1, 64, 64], …) instead of [4, 1, 64, 64].)

EDIT: If you do test the repro code, I suggest running it multiple times; the network sometimes explodes or fails to converge, depending on the random initialization. Validating with net.eval() on BatchNorm2d is consistently worse than validating with net.train(), though, so there definitely seems to be a bug.

t loss1: 0.12320221960544586
t loss2: 0.09333723783493042
t loss1: 0.24439916014671326
t loss2: 0.5218815207481384
t loss1: 0.6079472899436951
t loss2: 3.729689598083496
t loss1: 2.112783908843994
t loss2: 28.2861328125
t loss1: 10.315664291381836
t loss2: 220.0238037109375
t loss1: 62.95467758178711
t loss2: 1715.936279296875
t loss1: 380.442138671875
t loss2: 12547.478515625
t loss1: 537.3038330078125
t loss2: 53827.6484375
t loss1: 10580.1064453125
t loss2: 484.6171875
t loss1: 625.4744873046875
t loss2: 174.572509765625
t loss1: 340.2962646484375
t loss2: 110.48974609375
t loss1: 199.84153747558594
t loss2: 81.0360107421875
t loss1: 124.49140167236328
t loss2: 59.60415267944336
t loss1: 80.25579071044922
t loss2: 44.34978485107422
t loss1: 53.18634796142578
t loss2: 33.68086624145508
t loss1: 36.14666748046875
t loss2: 26.259246826171875
t loss1: 25.240781784057617
t loss2: 20.230873107910156
t loss1: 17.960424423217773
t loss2: 15.522289276123047
t loss1: 12.958020210266113
t loss2: 11.859451293945312
....network converges....
t loss1: 8.753518159210216e-07
t loss2: 3.449472387728747e-06
t loss1: 7.535315944551257e-07
t loss2: 3.1326808311860077e-06
t loss1: 6.516746680063079e-07
t loss2: 2.851738372555701e-06
t loss1: 5.662100193148945e-07
t loss2: 2.603399707368226e-06
t loss1: 4.90573881961609e-07
t loss2: 2.3800134840712417e-06
t loss1: 4.2635625163711666e-07
t loss2: 2.181120635214029e-06
t loss1: 3.708064468810335e-07
t loss2: 2.0018892428197432e-06
t loss1: 3.229307878882537e-07
t loss2: 1.8403325157123618e-06
t loss1: 2.822781368649885e-07
t loss2: 1.6948088159551844e-06
t loss1: 2.4606848114672175e-07
t loss2: 1.5627991842848132e-06
t loss1: 2.1536402527999599e-07
t loss2: 1.4432781654249993e-06
t loss1: 1.8830129988600675e-07
t loss2: 1.3343917544261785e-06
t loss1: 1.6481742193263926e-07
t loss2: 1.2349712505965726e-06
t loss1: 1.4438909090586094e-07
t loss2: 1.1446034022810636e-06
t loss1: 1.2657532977300434e-07
t loss2: 1.0619536396916374e-06
t loss1: 1.1093613494495003e-07
t loss2: 9.862114893621765e-07
t loss1: 9.728184124924155e-08
t loss2: 9.167142138721829e-07
t loss1: 8.523475969468564e-08
t loss2: 8.52758375913254e-07
v loss1: 99.31019592285156
v loss2: 15.398092269897461
train v loss1: 7.480510788582251e-08
train vv loss2: 7.63321224894753e-07

The high validation loss is due to the wrong estimates of the running stats.
Since you are feeding a constant tensor (batchone: mean=1, std=0) and a random tensor (batchtwo: mean~=0, std~=1), the running estimates will be shaky and wrong for both inputs.

During training the current batch stats will be used to compute the output, so that the model might converge.
However, during evaluation the batchnorm layer tries to normalize both inputs with skewed running estimates, which yields the high loss values.
Usually we assume that all inputs are from the same domain and thus have approx. the same statistics.

If you set track_running_stats=False in your BatchNorm layer, the batch statistics will also be used during evaluation, which will reduce the eval loss significantly.
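
As a small sketch of why the estimates end up wrong here (my own illustration, not part of the answer above): the running mean follows an exponential moving average, running = (1 - momentum) * running + momentum * batch_mean, so alternating batches with mean 1 (batchone) and mean ~0 (batchtwo) leave the estimate stuck in between, fitting neither input.

# Exponential moving average of the running mean, using the default
# momentum of 0.1 from nn.BatchNorm2d.
momentum = 0.1
running_mean = 0.0
for _ in range(100):
    running_mean = (1 - momentum) * running_mean + momentum * 1.0  # batchone: mean = 1
    running_mean = (1 - momentum) * running_mean + momentum * 0.0  # batchtwo: mean ~= 0
print(running_mean)  # ~0.47: wrong for both inputs, so eval() normalizes both badly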

Oh I thought batchnorm would use the running statistics during training as well as validation.

I’ve seen this issue crop up with networks whose images I thought were all from the same domain. Does that mean that batchnorm failed to actually capture the input statistics in those cases?

Not really. Since you are feeding samples from two different distributions (mean=0 and mean=1), the BatchNorm layer will in fact try to capture the statistics from “this dataset” as mean~=0.5.
(Since it’s using an exponential moving average you most likely see a bias towards the stats from the last used batch.)

Since you are not using samples from the same distribution, this will yield a bad validation performance.

I meant in the cases where all the images were in the same domain. I was trying to replicate that situation with the example code.

I ran into this issue again with a different network (it only happened when reloading a checkpoint), and I think it might have been due to using the default momentum value for batchnorm2d. After retraining from the checkpoint with the momentum changed to 0.01, I got reasonable validation results. I think the high default momentum value made the running mean and variance too sensitive to changes within the image domain for my network architecture.
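
In case it helps someone, this is roughly how the momentum can be lowered on an already-built model (a sketch with a made-up toy model; replace it with your own network):

import torch.nn as nn

# Toy model for illustration only.
model = nn.Sequential(
    nn.Conv2d(1, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Lower the BatchNorm momentum so the running estimates react
# more slowly to any single batch.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = 0.01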