PyTorch documentation for BatchNorm

For my current use case, I would like BatchNorm to behave as though it were in inference mode rather than training mode (just the BatchNorm layer, not the whole network).

I notice from the PyTorch documentation that

track_running_stats: a boolean value that when set to ``True``, this module
tracks the running mean and variance, and when set to ``False``, this module
does not track such statistics and always uses batch statistics in both
training and eval modes. Default: ``True``

Now I was under the impression that BatchNorm at test/eval time would not compute anything on the fly but would use the mean/variance stored during training. The track_running_stats description, however, says that when it is set to False, batch statistics are used even in eval mode.

When I see this:
http://cs231n.stanford.edu/slides/2019/cs231n_2019_lecture07.pdf

it mentions that at test time BatchNorm can be fused with the preceding layer because no separate statistics are computed at test time (quoting the slides: "during testing batchnorm becomes a linear operator! Can be fused with the previous fully-connected or conv layer"). A rough sketch of what that fusion looks like is included at the end of this post. So if I do this in PyTorch:

def __init__(self):
    ...
    self.bn = nn.BatchNorm2d(2, track_running_stats=False)
    ...

and after instantiating the net object, I do

net.bn.eval()

Will this calculate the batch mean and variance on the fly when an input batch is applied, as the documentation suggests? How do I avoid computing stats in test mode? Basically, how do I ensure that in eval mode the layer uses only previously stored stats and nothing from the current batch?
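For reference, here is a rough sketch of the fusion mentioned above (my own sketch, not from the slides or this thread; fuse_conv_bn is just an illustrative name). It assumes a Conv2d followed directly by a BatchNorm2d in eval mode with affine=True, and leaves dilation/groups at their defaults:

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # In eval mode BN computes gamma * (y - running_mean) / sqrt(running_var + eps) + beta,
    # a per-channel affine map that can be folded into the conv weights and bias.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per output channel
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        conv_bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused

With both layers in eval() mode, fused(x) should match bn(conv(x)) up to floating-point error.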

Leave track_running_stats=True and set the batchnorm layer to eval().
This setup will update the running stats in train() mode and just use them (without updating) in eval() mode.

If you set track_running_stats=False, the batch statistics will always be used as explained in the docs.
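A minimal sketch of the difference (my addition, not part of the original reply), assuming freshly initialized layers whose running stats are still at their defaults and affine disabled for clarity:

import torch
import torch.nn as nn

x = torch.randn(8, 2, 4, 4) * 5 + 10   # batch with a clearly non-zero mean

bn_running = nn.BatchNorm2d(2, affine=False, track_running_stats=True)
bn_running.eval()
print(bn_running(x).mean())   # ~10: the stored running stats (mean 0, var 1) are used

bn_batch = nn.BatchNorm2d(2, affine=False, track_running_stats=False)
bn_batch.eval()
print(bn_batch(x).mean())     # ~0: the current batch statistics are used even in eval()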

Thanks a lot for your response!

Actually, the reason I asked this question (which I should have included above) is that I was not observing what you mentioned. So let me lay out the example I used:

I have a simple model that looks like this:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 2, kernel_size=3, bias=False)
        self.bn = nn.BatchNorm2d(2, track_running_stats=True)
        self.fc1 = nn.Linear(18, 10, bias=False)

    def forward(self, x):
        x1 = self.conv1(x)
        print("conv output")
        print(x1)
        # amean = np.mean(x1.data.numpy(), axis=(0,2,3))
        # print(amean)
        # astd = np.std(x1.data.numpy(), axis=(0,2,3))
        # print(astd)
        # x11 = x1.data.numpy()
        # for i in range(2):
        #     x11[:,i,:,:] = (x1.data.numpy()[:,i,:,:] - amean[i]) / astd[i]
        # print(x11)
        x2 = self.bn(x1)
        print("bn output")
        print(x2)
        x3 = F.relu(x2)
        print("relu1 output")
        print(x3)
        x4 = x3.view(x3.size(0), -1)
        x5 = self.fc1(x4)
        print("fc output")
        print(x5)
        x6 = F.relu(x5)
        print("relu2 output")
        print(x6)
        return x6

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

torch.manual_seed(0)
net = Net()
net.bn.bias.requires_grad=False
net.bn.eval()

Now I pass one training sample through the forward function. If I check the output of bn, this is what it prints:

bn output
tensor([[[[-1.6799, -0.0496,  0.2127],
          [ 0.3922,  1.3428,  0.0598],
          [ 0.3674,  0.3883, -1.0339]],

         [[ 0.4344,  0.1697, -0.1216],
          [ 0.1510, -0.0127, -0.1253],
          [-0.0596, -0.1495, -0.2863]]]], grad_fn=<NativeBatchNormBackward>)

and output of the previous layer (conv) is:

 tensor([[[[-0.0403,  0.0103,  0.0185],
          [ 0.0240,  0.0535,  0.0137],
          [ 0.0233,  0.0239, -0.0202]],

         [[-0.1044, -0.1664, -0.2347],
          [-0.1708, -0.2092, -0.2356],
          [-0.2202, -0.2412, -0.2733]]]], grad_fn=<MkldnnConvolutionBackward>)

If eval did not calculate anything from the current batch, then the output of bn should not have been -1.6799 at the first index, for example. The -1.6799 is exactly what you get by normalizing the output of conv with the current batch statistics. In the forward function the commented-out code shows the formula used to obtain the same result; the first index of x11 is also -1.6799.

Could you please explain why I see this? If I replace eval() with train() I get the same result. This is the one and only sample (one batch) the model has ever seen, so if it did not use the current batch stats I would have expected some other output, since no previous mean and variance have been computed yet. But it looks like, even with eval(), the current batch is being used to produce the output.
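One way to check this (my own addition, assuming the Net defined above and a hypothetical input x of the right shape) is to look at the running buffers before and after a forward pass; they are only updated when the layer is in training mode:

x = torch.randn(1, 3, 5, 5)    # hypothetical input batch
print(net.bn.running_mean)     # tensor([0., 0.]) before any forward pass
out = net(x)
print(net.bn.running_mean)     # unchanged if net.bn is in eval(), updated in train()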

Thanks!

Note that batchnorm layers also have affine parameters by default (affine=True).
While the weight and bias are initialized to ones and zeros, respectively, in the current master, the weight parameter was initialized from a uniform distribution up to PyTorch 1.1.0.

If you are not using a nightly build, you might add this to your code:

torch.manual_seed(0)
net = Net()
net.bn.bias.requires_grad=False
with torch.no_grad():
    net.bn.weight.fill_(1.)
net.bn.eval()

Thanks.

But I think it is still doing the same thing: it is using the current batch in eval mode to calculate the mean and variance.
So while the weight is now initialized differently, it still uses the current batch to produce the bn output.
If you take the output of conv listed above, calculate the per-channel mean and variance, and do the following (based on my code above):

x11[:,0,:,:]*self.bn.weight.data.numpy()[0]
x11[:,1,:,:]*self.bn.weight.data.numpy()[1]

x11 is the conv output normalized with the statistics of the applied sample input.
The result above matches, for example, the first index of the bn output printed by PyTorch!

The issue is basically that the current batch is used to produce the bn output in eval mode.
I would expect the default mean and variance (if none have been computed yet) to be used, without anything from the current batch, but it uses this one batch for the calculation.

Could you check the output again?
Using your model and this code snippet:

torch.manual_seed(0)
net = Net()
net.bn.bias.requires_grad=False
net.bn.weight.requires_grad=False
with torch.no_grad():
    net.bn.weight.fill_(1.)
net.bn.eval()

x = torch.randn(1, 3, 5, 5)
net(x)

yields exactly the same outputs:

conv output
tensor([[[[ 0.4132035971, -0.7129006982,  0.0829921290],
          [ 0.5575969219, -0.5301585793, -0.2478107512],
          [ 0.2830447555, -0.8774284720, -0.7162327170]],

         [[-0.1680016071, -0.6251038313, -0.3634709716],
          [ 0.5589009523,  0.4439742863,  0.1032543331],
          [ 0.4814728498,  0.3313015699, -0.5421049595]]]],
       grad_fn=<MkldnnConvolutionBackward>)
bn output
tensor([[[[ 0.4132015407, -0.7128971219,  0.0829917118],
          [ 0.5575941205, -0.5301558971, -0.2478095144],
          [ 0.2830433249, -0.8774240613, -0.7162291408]],

         [[-0.1680007726, -0.6251006722, -0.3634691536],
          [ 0.5588981509,  0.4439720511,  0.1032538190],
          [ 0.4814704359,  0.3312999010, -0.5421022177]]]],
       grad_fn=<NativeBatchNormBackward>)
...

Sure, I can take a look. What are the default mean and variance used for the calculation? Are they set to zero by default in this case?

running_mean is initialized as zeros, while running_var as ones (in 1.1.0 and the current master).
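A quick way to verify this (my addition, not from the original reply): with the default buffers and affine disabled, eval-mode batchnorm is close to the identity, since it computes (x - running_mean) / sqrt(running_var + eps):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(2, affine=False)
bn.eval()
print(bn.running_mean)   # tensor([0., 0.])
print(bn.running_var)    # tensor([1., 1.])

x = torch.randn(1, 2, 3, 3)
print(torch.allclose(bn(x), x, atol=1e-4))   # True, up to the eps added to the variance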


I just found the issue!!

In my code I additionally do,

x = torch.randn(1, 3, 5, 5)
torch.onnx.export(net, x, 'ayeonnx.onnx')
net(x)

to convert the model to ONNX (for a separate use case) before executing net(x),

and this messes up the whole thing!

if I instead do

torch.onnx.export(copy.deepcopy(net), x, 'ayeonnx.onnx')

I get what you get. Is this a bug, in that I had to deepcopy because otherwise the net object gets modified, or is this expected behavior?


Yeah, that’s probably the issue and might be considered a bug.

As you can see here, set_training changes the current training mode of the model to the passed argument mode.
By default torch.onnx.export uses training=False, which should be fine.

However, since you are not setting the complete model to eval, net.training will still return True:

net.bn.eval()
print(net.training)
> True
print(net.bn.training)
> False

While this mixed setup is your desired use case, set_training only checks the training attribute of the parent model and afterwards sets the complete model back to that “old” mode:

torch.onnx.export(net, x, 'tmp.onnx')
print(net.training)
> True
print(net.bn.training)
> True

This will of course cause the next forward call to update the running statistics, so you should set the batchnorm layer to eval again after exporting the model using onnx.

The proper approach would maybe be to restore the training attribute for each submodule recursively, but I’m not sure if that’s an edge case.
Anyway, feel free to open an issue and link to this topic so that this can be discussed with the ONNX devs.
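To make the workaround explicit (both variants only use what was already shown in this thread; net and x are the model and input from above):

import copy
import torch

# Option 1: export a deep copy so the original module's training flags stay untouched.
torch.onnx.export(copy.deepcopy(net), x, 'ayeonnx.onnx')

# Option 2: export the model itself, then put the batchnorm layer back into
# eval mode before the next forward pass.
torch.onnx.export(net, x, 'ayeonnx.onnx')
net.bn.eval()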


Thanks! I really appreciate your consistent help in finally isolating the issue.