Model.eval() gives incorrect loss for model with batchnorm layers

This code snippet reproduces the issue for me on PyTorch 1.1.0:

import torch
import torch.nn as nn
import torch.optim as optim        
                                                                       
class SimpleNet(nn.Module):
    def __init__(self, image_size_total):
        super(SimpleNet, self).__init__()
                           
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.max_pool1 = nn.MaxPool2d(2)
                                                               
        self.fc1 = nn.Linear((image_size_total//4) * 64, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)
        x = self.max_pool1(x)

        x = x.view(-1, self.num_flat_features(x))
        x = self.fc1(x)
        return x


    def num_flat_features(self, x):
        size = x.size()[1:] # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

width = 64
height = 64
network = SimpleNet(width * height)

batchone = torch.ones([4, 1, 64, 64], dtype=torch.float, device=torch.device("cuda:1"))
outputone = torch.tensor([.5, .5]).to(torch.device("cuda:1"))
batchtwo = torch.randn([4, 1, 64, 64], dtype=torch.float, device=torch.device("cuda:1"))
outputtwo = torch.tensor([.01, 1.0]).to(torch.device("cuda:1"))

def train_net(net, batch, output):
        net.train()

        optimizer = optim.SGD(net.parameters(), 0.0001)
        criterion = nn.MSELoss()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        netoutput = net(batch)
        loss = criterion(netoutput, output)
        loss.backward()
        optimizer.step()
        return float(loss)

def evaluate_batch(net, batch, output, shouldeval):
        if shouldeval:
            net.eval()
        else:
            net.train()

        criterion = nn.MSELoss()

        # forward + backward + optimize
        netoutput = net(batch)
        loss = criterion(netoutput, output)
        return float(loss)

network.to(torch.device("cuda:1"))

for i in range(100):
    print("t loss1:", train_net(network, batchone, outputone))
    print("t loss2:", train_net(network, batchtwo, outputone))

print("v loss1:", evaluate_batch(network, batchone, outputone, True))
print("v loss2:", evaluate_batch(network, batchtwo, outputone, True))
print("train v loss1:", evaluate_batch(network, batchone, outputone, False))
print("train vv loss2:", evaluate_batch(network, batchtwo, outputone, False))

If I remove the batchnorm layer, the discrepancy disappears.

Somewhere I saw the suggestion of using momentum = 0 instead of setting eval() on the BatchNorm2d module.
Using this code removes the discrepancy:

def evaluate_batch(net, batch, output, shouldeval):
        if shouldeval:
            net.eval()
            net.bn1.train()
            net.bn1.momentum = 0.0
        else:
            net.train()
...
        #before returning
        net.bn1.momentum = 0.1

How large is the difference for the runs?

It's pretty huge: with net.train() it's around 7e-8, with net.eval() it's around 99.3.

Here is the truncated output (by the way, I fixed a bug in the original repro code where the input batch size did not match the output batch size; batchone is now torch.ones([2, 1, 64, 64], …) instead of [4, 1, 64, 64]).

EDIT: If you do test the repro code, I suggest running it multiple times; the network sometimes explodes or fails to converge depending on the random initialization. Validating with net.eval() on BatchNorm2d is consistently worse than validating with net.train(), though, so there definitely seems to be a bug.

t loss1: 0.12320221960544586
t loss2: 0.09333723783493042
t loss1: 0.24439916014671326
t loss2: 0.5218815207481384
t loss1: 0.6079472899436951
t loss2: 3.729689598083496
t loss1: 2.112783908843994
t loss2: 28.2861328125
t loss1: 10.315664291381836
t loss2: 220.0238037109375
t loss1: 62.95467758178711
t loss2: 1715.936279296875
t loss1: 380.442138671875
t loss2: 12547.478515625
t loss1: 537.3038330078125
t loss2: 53827.6484375
t loss1: 10580.1064453125
t loss2: 484.6171875
t loss1: 625.4744873046875
t loss2: 174.572509765625
t loss1: 340.2962646484375
t loss2: 110.48974609375
t loss1: 199.84153747558594
t loss2: 81.0360107421875
t loss1: 124.49140167236328
t loss2: 59.60415267944336
t loss1: 80.25579071044922
t loss2: 44.34978485107422
t loss1: 53.18634796142578
t loss2: 33.68086624145508
t loss1: 36.14666748046875
t loss2: 26.259246826171875
t loss1: 25.240781784057617
t loss2: 20.230873107910156
t loss1: 17.960424423217773
t loss2: 15.522289276123047
t loss1: 12.958020210266113
t loss2: 11.859451293945312
....network converges....
t loss1: 8.753518159210216e-07
t loss2: 3.449472387728747e-06
t loss1: 7.535315944551257e-07
t loss2: 3.1326808311860077e-06
t loss1: 6.516746680063079e-07
t loss2: 2.851738372555701e-06
t loss1: 5.662100193148945e-07
t loss2: 2.603399707368226e-06
t loss1: 4.90573881961609e-07
t loss2: 2.3800134840712417e-06
t loss1: 4.2635625163711666e-07
t loss2: 2.181120635214029e-06
t loss1: 3.708064468810335e-07
t loss2: 2.0018892428197432e-06
t loss1: 3.229307878882537e-07
t loss2: 1.8403325157123618e-06
t loss1: 2.822781368649885e-07
t loss2: 1.6948088159551844e-06
t loss1: 2.4606848114672175e-07
t loss2: 1.5627991842848132e-06
t loss1: 2.1536402527999599e-07
t loss2: 1.4432781654249993e-06
t loss1: 1.8830129988600675e-07
t loss2: 1.3343917544261785e-06
t loss1: 1.6481742193263926e-07
t loss2: 1.2349712505965726e-06
t loss1: 1.4438909090586094e-07
t loss2: 1.1446034022810636e-06
t loss1: 1.2657532977300434e-07
t loss2: 1.0619536396916374e-06
t loss1: 1.1093613494495003e-07
t loss2: 9.862114893621765e-07
t loss1: 9.728184124924155e-08
t loss2: 9.167142138721829e-07
t loss1: 8.523475969468564e-08
t loss2: 8.52758375913254e-07
v loss1: 99.31019592285156
v loss2: 15.398092269897461
train v loss1: 7.480510788582251e-08
train v loss2: 7.63321224894753e-07

The high validation loss is due to the wrong estimates of the running stats.
Since you are feeding a constant tensor (batchone: mean=1, std=0) and a random tensor (batchtwo: mean~=0, std~=1), the running estimates will be shaky and wrong for both inputs.

During training the current batch stats are used to compute the output, so the model is still able to converge.
However, during evaluation the batchnorm layer tries to normalize both inputs with skewed running estimates, which yields the high loss values.
Usually we assume that all inputs are from the same domain and thus have approx. the same statistics.

If you set track_running_stats=False in your BatchNorm layer, the batch statistics will also be used during evaluation, which will reduce the eval loss significantly.
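
For reference, a minimal sketch of that change applied to the SimpleNet above (only the batchnorm line changes):

# batch statistics are used in both train and eval mode; no running stats are kept
self.bn1 = nn.BatchNorm2d(64, track_running_stats=False)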


Oh I thought batchnorm would use the running statistics during training as well as validation.

I’ve seen this issue crop up with networks that I thought had images all from the same domain. Does that mean that batchnorm failed to actually capture the input statistics in those cases?

Not really. Since you are feeding samples from two different distributions (mean=0 and mean=1), the BatchNorm layer will in fact try to capture the statistics from “this dataset” as mean~=0.5.
(Since it’s using an exponential moving average you most likely see a bias towards the stats from the last used batch.)

Since you are not using samples from the same distribution, this will yield a bad validation performance.
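
A quick illustration of that averaging effect, assuming the same two kinds of batches as in the repro (shapes here are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(1)            # train mode by default, so running stats get updated
ones = torch.ones(4, 1, 8, 8)     # mean = 1
noise = torch.randn(4, 1, 8, 8)   # mean ~= 0

for _ in range(100):
    bn(ones)
    bn(noise)

print(bn.running_mean)            # ends up near 0.5, i.e. matching neither input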

I meant in the cases where all the images were in the same domain. I was trying to replicate that situation with the example code.

I ran into this issue again with a different network (it only happened when reloading a checkpoint), and I think it might have been due to using the default momentum value for BatchNorm2d. After retraining from the checkpoint with the momentum changed to 0.01, the validation results were reasonable. I think the high default momentum value made the running mean and variance too sensitive to changes within the image domain for my network architecture.
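
Roughly, the change amounts to something like this (sweeping over all BatchNorm2d modules is just one way to apply it):

for m in network.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = 0.01   # smaller momentum -> running stats react more slowly to each batch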

Lowering the momentum might help in situations where the data is quite noisy.
So it’s good to hear you got reasonable results for your validation set by changing it! :slight_smile:

In the example code, you were feeding samples from two domains, so the running estimates settled around their average (which is out of domain for both).

Yup, I understood what you wrote earlier. I didn’t succeed in replicating my actual issue with my example code above.

How do you do this? Can we do it with a torch.no_grad() block?

Yes, you can call model.train() and perform some forward passes in a torch.no_grad() block.
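
A rough sketch of that recalibration (model and calibration_loader are placeholder names):

model.train()                         # batchnorm uses and updates batch statistics
with torch.no_grad():                 # no gradients or optimizer steps needed
    for inputs, _ in calibration_loader:
        model(inputs)                 # forward passes refresh running_mean / running_var
model.eval()                          # evaluate with the refreshed running stats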


Thanks for your solution. Recently I tried a pretrained MnasNet model that has this problem. The momentum in that model is small (~3e-4), so beware of that if you are using it.

Hi, I had similar issues, and one thing I realized was that I had defined one batchnorm layer and used it after every layer. This might be the mistake you are making.

For example:
self.batch_norm_hidden = nn.BatchNorm1d(num_hidden_nodes)

and then later:

for layer in layers:
    x = layer(x)
    x = self.activation_fn(x)
    x = self.batch_norm_hidden(x)

This is obviously wrong, as the same batch_norm_hidden is used everywhere. You need to define a new batchnorm module for every layer (otherwise the running stats are shared across the layers).
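
A sketch of the fix, assuming the hidden layers live in an nn.ModuleList (names here are illustrative):

# one BatchNorm1d per hidden layer, so each one keeps its own running stats
self.batch_norms = nn.ModuleList(
    [nn.BatchNorm1d(num_hidden_nodes) for _ in range(num_hidden_layers)]
)

# later, in forward():
for layer, bn in zip(self.layers, self.batch_norms):
    x = layer(x)
    x = self.activation_fn(x)
    x = bn(x)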

Also, note that the momentum for the batchnorm layer is defined differently from the momentum in the optimizers. It might be helpful to check the docs for how the new batch statistics are weighted.
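
For reference, the update rule from the docs is roughly this (so momentum weights the new batch statistic, unlike optimizer momentum):

momentum = 0.1   # default for nn.BatchNorm2d
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var = (1 - momentum) * running_var + momentum * batch_var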

I had a similar problem, and setting track_running_stats=False fixed it for me (UNet, no dropout, just batchnorm). But I still don’t understand why… Firstly, does setting that flag affect training in any way, or only the evaluation process?
And secondly, for debugging purposes I tried training and validating (and evaluating after training) on the same single image, and also on the same single batch of 8 images. In both cases the problem persisted (discrepancies between validation and training loss, or between network outputs under net.eval() vs. net.train()).
But if I understand your explanation of track_running_stats correctly, it should not make a difference in my case, since all the data comes from the same distribution; it is literally the same in all stages (train, val and test). So why does the flag still seem to help in my case?


This argument affects the validation, as no running stats would be calculated and all validation inputs will be normalized using the batch statistics.

The running stats will be updated using the momentum as described in the docs, so you would probably need more forward passes to let the running stats converge towards the batch stats.

Thanks for the answer.
Just to be sure: during training, when the running stats are computed, it feels like with a single batch and momentum 0.1 they should definitely converge to the batch statistics after 100 epochs, right?
And they still haven’t in my case (generally, the discrepancies do become smaller, but we’re still talking about, for example, a Dice score of 0.6 vs. 0.4, or sometimes a larger gap).
So I wonder whether I have another issue in the code, or would you generally expect it to take this long for the running stats to converge to the batch statistics?

Yes, I would assume that the running stats converge after 100 iterations as is also shown here:

# input with per-channel mean ~7 and var ~25, far from the default running stats
x = torch.randn(1, 3, 100, 100) * 5. + 7.
bn = nn.BatchNorm2d(3)

# per-channel sample statistics of the input
mean, var = x.mean([0, 2, 3]), x.var([0, 2, 3])

# repeated forward passes in train mode update the running estimates
for _ in range(100):
    out = bn(x)

bn_mean, bn_var = bn.running_mean, bn.running_var

print('sample mean {}\nsample var {}'.format(mean, var))
> sample mean tensor([7.0567, 7.0104, 6.9218])
  sample var tensor([24.7264, 24.9284, 24.9140])

print('bn mean {}\nbn var {}'.format(bn_mean, bn_var))
> bn mean tensor([7.0565, 7.0102, 6.9216])
  bn var tensor([24.7258, 24.9278, 24.9134])

EDIT:

It depends a bit on your use case. Note that the intermediate activations are not static, since the parameters are updated in each iteration. This could also mean that the stats are changing and that the bn layers are tracking these changes, so you cannot directly assume that x iterations will make the stats converge perfectly.


A possible solution is to set track_running_stats=False for both training and testing; that may work.
In my model, I generate data from a pool, so the data fed to the network keeps changing. In that case it seems we shouldn’t track the running mean/var and should instead use the statistics computed from the current batch.

It does work, thanks a lot.