Getting different outputs from batchnorm network with different batch sizes at test time

I trained a DenseNet to compute image embeddings of size 256 for image matching.
At test time, when I use different batch sizes (64, 256, 512), the network outputs (embeddings) are slightly different, even though I call model.eval().
The difference seems to be proportional to the difference in batch size.

I suspect this is due to batchnorm (track_running_stats).

Update: Indeed, it is related to batchnorm’s track_running_stats.

If I set track_running_stats to False and use a large batch size (512), it works well (high matching accuracy). When I use a small batch size (8), the accuracy drops by a few percent.

If I leave track_running_stats as True (the default), the outputs for different batch sizes still differ slightly, but the matching accuracy is consistent across batch sizes, and slightly lower than with track_running_stats=False and a large batch size.
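
To see why track_running_stats=False depends on the batch, here is a standalone toy sketch with a single batchnorm layer (not my actual model; random data):

import torch
import torch.nn as nn

torch.manual_seed(0)

# Without running stats, batchnorm always normalizes with the statistics of
# the current batch, even in eval() mode.
bn = nn.BatchNorm2d(3, track_running_stats=False).eval()
x = torch.randn(8, 3, 4, 4)

with torch.no_grad():
    out_batch = bn(x)        # sample 0 normalized together with 7 other samples
    out_single = bn(x[:1])   # sample 0 normalized alone

# The same sample produces clearly different outputs.
print((out_batch[0] - out_single[0]).abs().max())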

How can I solve this?
I should get the same output regardless of the batch size.
Is that possible with a network that uses batchnorm?

PyTorch v1.4

Update 2: Typical use-case scenario:

  • I train a CNN with batchnorm and a single embedding layer of size 256, using a metric loss and a large batch size (256).
  • I also validate/test using a large batch size (512).
  • I get a good model, so let me deploy it: compute the embeddings for millions of images in the database using a large batch size (512). Great!
  • At test time, though, the network receives a single image at a time, in real time! Oops! The batch size is now 1. Even if I feed one of the images from the database, I do not get the same embedding! If I match them, I do not get a distance of 0 (see the sketch after this list)! (We cannot explain this to an outsider.)
  • Hence, the problem.
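
To make the last point concrete, here is a rough sketch of the check (model and images are placeholders for the trained network and a batch of preprocessed inputs):

import torch

# `model` is the trained embedding network, `images` a batch of preprocessed
# image tensors (both placeholders for illustration).
model.eval()
with torch.no_grad():
    emb_db = model(images)          # "database" embeddings, large batch
    emb_single = model(images[:1])  # the same first image, now with batch size 1

# Ideally this is exactly 0; in practice it is small but nonzero.
print((emb_db[0] - emb_single[0]).abs().sum())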

Hi,
do you use model.eval() before you pass the inputs to the network?

Yes, as I said in my post.

The problem is related to the batch statistics. I would like to see if there is a good solution.

This won’t be possible with track_running_stats=False, since the batchnorm layers will calculate the current batch statistics, which will vary based on the batch size.
However, track_running_stats=True should give you the same outputs (up to floating point precision, if non-deterministic methods were used or if you haven’t used the setup from the reproducibility docs).
How large is the absolute difference with track_running_stats=True?
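
For reference, the usual setup from the reproducibility docs looks roughly like this (set before running the model):

import torch

# Trade speed for deterministic GPU results.
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False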

Thanks Peter.

I have updated my post with a typical use-case scenario (and how I use the model).

Indeed, based on my preliminary experiments, it is not possible to get the same embeddings with track_running_stats=False. Now I am trying to find a solution with track_running_stats=True.

The distances between embeddings of the same images with different batch sizes are summarized below.

I wish it were possible to also save the running mean/std of the batchnorm layers from training and use them at test time with a small momentum. This would mitigate the problem.

A workaround might be to run the model on the training set with a large batch size right before switching to test mode. I have not tried this yet.
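
A rough sketch of that workaround (train_loader and the momentum value are just placeholders for illustration):

import torch

# Optionally use a smaller momentum so the refreshed stats change slowly
# (the "small momentum" idea above); 0.01 is just an illustrative value.
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.momentum = 0.01

# Refresh the batchnorm running stats by forwarding large training batches,
# without any gradient updates.
model.train()
with torch.no_grad():
    for images, _ in train_loader:  # assumed to yield (images, labels)
        model(images.cuda())

model.eval()  # the refreshed running stats are now used directly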

Do you have a better idea?

If you are loading the state_dict, the running stats will also be restored.
Since the validation and test cases are executed via model.eval(), the momentum term won’t be used. Instead the running stats will be used directly to normalize the data.
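In code, an eval-mode batchnorm layer computes roughly the following with its stored running stats (standalone toy example):

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)

# Pretend some training happened, so the running stats hold non-trivial values.
bn.train()
_ = bn(torch.randn(16, 3, 8, 8))
bn.eval()

x = torch.randn(4, 3, 8, 8)
with torch.no_grad():
    # The running mean/var are used directly; the momentum plays no role here.
    mean = bn.running_mean.view(1, -1, 1, 1)
    var = bn.running_var.view(1, -1, 1, 1)
    weight = bn.weight.view(1, -1, 1, 1)
    bias = bn.bias.view(1, -1, 1, 1)
    manual = (x - mean) / torch.sqrt(var + bn.eps) * weight + bias

    # Matches nn.BatchNorm2d up to floating point precision.
    print((bn(x) - manual).abs().max())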

Could you post the model architecture, so that we could have a look?

If the running stats are saved, restored, and used at test time (the docs also say so), then there should not be any difference, regardless of the batch size. I use torch.save(model, 'model.pkl') and model = torch.load('model.pkl').

Here is the model, which is just a DenseNet:

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class DenseNet(nn.Module):

    def __init__(self, layers=169, pretrained=True, emb_size=256):
        super(DenseNet, self).__init__()

        # Pick the torchvision DenseNet variant and its feature dimension.
        if layers == 121:
            self.model = models.densenet121(pretrained=pretrained)
            fc_in = 1024
        elif layers == 161:
            self.model = models.densenet161(pretrained=pretrained)
            fc_in = 2208
        elif layers == 169:
            self.model = models.densenet169(pretrained=pretrained)
            fc_in = 1664
        elif layers == 201:
            self.model = models.densenet201(pretrained=pretrained)
            fc_in = 1920
        else:
            raise ValueError('Unsupported number of layers: {}'.format(layers))

        # Replace the classifier with an embedding layer.
        self.model.classifier = nn.Linear(fc_in, emb_size, bias=False)

    def forward(self, x, norm=True):
        x = self.model.features(x)
        x = F.relu(x, inplace=True)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        # Project to the embedding space and (optionally) L2-normalize.
        x = self.model.classifier(x)
        if norm:
            x = F.normalize(x)

        return x

I used it like this:

import torch

model = DenseNet(layers=169, pretrained=True, emb_size=256)
model = model.cuda()
model.train()
torch.set_grad_enabled(True)
# train
torch.save(model, 'model.pkl')

# Test time
model = torch.load('model.pkl')
model.eval()
torch.set_grad_enabled(False)
# do test
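
For completeness, the state_dict route mentioned above would look roughly like this (the file name is just an example):

# The batchnorm running_mean/running_var buffers are part of the state_dict,
# so they are saved and restored along with the weights.
torch.save(model.state_dict(), 'model_state.pth')

# Later / at test time:
model = DenseNet(layers=169, pretrained=False, emb_size=256)
model.load_state_dict(torch.load('model_state.pth'))
model = model.cuda()
model.eval()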

I suppose model.train()/model.eval() on the DenseNet object recursively calls train/eval on all the modules it contains. If that were not the case, it would explain the problem [update: it does recurse; model.train()/eval() works as expected].
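
A quick way to check this:

model.eval()
# .training should now be False for every submodule, including the batchnorm layers.
print(all(not m.training for m in model.modules()))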

Update: here are some numbers: L1 distances between the DenseNet 169 embeddings of the same 5 images (out of 512), computed with different batch sizes in model.eval() mode.

# load the model, switch to eval(), run it on 512 images with different batch sizes, and compute L1 distances in memory, without saving to disk.
# batch sizes: 512 vs 1 - five runs, seems to be non-deterministic
tensor([1.2099e-05, 1.2104e-05, 5.7479e-06, 1.1539e-05, 9.9048e-06])
tensor([1.2099e-05, 1.2104e-05, 5.7479e-06, 1.1539e-05, 9.9048e-06])
tensor([1.4106e-05, 1.2881e-05, 6.3117e-06, 1.1826e-05, 1.0850e-05])
tensor([1.4106e-05, 1.2881e-05, 6.3117e-06, 1.1826e-05, 1.0850e-05])
tensor([1.2099e-05, 1.2104e-05, 5.7479e-06, 1.1539e-05, 9.9048e-06])

# batch sizes: 512 vs 1
# Forward 10 batches of 512 in train mode, then switch to eval mode
tensor([1.3889e-05, 1.3013e-05, 7.1302e-06, 1.2253e-05, 1.3526e-05])

# batch sizes: 512 vs 4
tensor([1.4723e-05, 1.4445e-05, 8.1474e-06, 1.2522e-05, 1.2616e-05])

# batch sizes: 512 vs 64
tensor([1.4181e-05, 1.3389e-05, 6.6453e-06, 1.1660e-05, 1.1504e-05])

# batch sizes: 512 vs 256
tensor([3.9932e-06, 4.5871e-06, 2.6375e-06, 3.3856e-06, 4.4151e-06])

# batch sizes: 512 vs 512
tensor([0., 0., 0., 0., 0.])

# ImageNet Pre-trained DenseNet 169 with new non-trained embedding layer
# batch sizes: 512 vs 1
tensor([1.6982e-05, 1.5111e-05, 1.6363e-05, 1.5956e-05, 1.8549e-05])

# Random non-trained DenseNet
# batch sizes: 512 vs 1
tensor([4.8725e-06, 4.9665e-06, 4.8112e-06, 4.9490e-06, 4.7400e-06])

# ImageNet Pre-trained ResNet50 with new non-trained embedding layer
# batch sizes: 512 vs 1
tensor([1.7733e-05, 1.3676e-05, 1.5779e-05, 1.1592e-05, 1.1086e-05])

Based on these numbers, the difference changes with batch size, and hence it is most probably related to batchnorm. PyTorch's batchnorm layers use an epsilon value of 1e-5, and the L1 distances are in the same range. Could that be related?

These differences are most likely created by the limited floating point precision, not by batchnorm or any other layers.
Usually you would expect a relative error of ~1e-6, but if you are using a deep model, these errors might accumulate.

Initially, I also thought so. After the above experiments, I think there is a problem somewhere (probably due to floating point precision), and most probably in the batchnorm layers: the distance grows with the difference in batch size, and it is zero when the batch sizes are the same (see 512 vs 512).

How did you narrow it down to the batchnorm layers? Did you replace them and get a zero absolute error, or what's the reasoning?
Due to the limited floating point precision, operations such as sum will also suffer from this effect:

import torch

# Summing the same values in a different order gives a slightly different result.
x = torch.randn(100, 100, 100)
sum1 = x.sum()                  # one big reduction
sum2 = x.sum(0).sum(0).sum(0)   # three chained reductions
print((sum1 - sum2).abs())
> tensor(0.0016)

No, I have not done that. That is probably the best way to show it.

But, as I said: (1) the difference changes with batch size, and (2) I get a distance of 0 if I use the same batch size.

In your example, you are using a different sequence of operations, whereas in my case the network is supposed to be the same (the same sequence of operations).
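
For completeness, a sketch of the isolation test suggested above: feed the same sample through a single eval-mode batchnorm layer (track_running_stats=True, i.e. the default) alone and inside a larger batch, and compare:

import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3)              # track_running_stats=True by default
bn.train()
_ = bn(torch.randn(32, 3, 8, 8))    # populate the running stats
bn.eval()

x = torch.randn(512, 3, 8, 8)
with torch.no_grad():
    out_big = bn(x)                 # batch size 512
    out_small = bn(x[:1])           # batch size 1, same first sample

# If batchnorm alone caused the differences above, this should be clearly nonzero.
print((out_big[0] - out_small[0]).abs().max())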