Getting parameters of `torch.nn.BatchNorm2d` during training

Hi all, I’m trying to figure out exactly how a 2D BN layer is applied to an input tensor. Here is a toy example:

import torch
import torch.nn as nn

num_features = 3
X = torch.randn(2, num_features, 4, 4)
bn = nn.BatchNorm2d(num_features)

Y = bn(X)

What I would like to do is to “reproduce” the output of the BN layer, Y, not by passing X through bn(), but by using the formula shown here.
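For reference, the `nn.BatchNorm2d` documentation writes the transform as below. With the default settings, it is applied per channel: γ and β are `bn.weight` and `bn.bias` (absent when `affine=False`), ε is `bn.eps`, and the mean and variance are the current batch statistics over (N, H, W) in training mode, or `bn.running_mean` and `bn.running_var` in eval mode.

```latex
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
```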

That is, I would like something like:

Z = bn.xxx * torch.div(X - bn.yyy, torch.sqrt(bn.zzz + bn.eps)) + bn.ppp

I’m looking for the correct attributes of bn (xxx, yyy, zzz, and ppp), or, even better, a cleaner way to do this, since during training I want to get the parameters of the BN layer and apply them to another tensor.

Thanks a lot.
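One way to apply a BN layer's current parameters to another tensor, without calling `bn` itself (and so without updating its running statistics), is `torch.nn.functional.batch_norm` with `training=False`. This is a minimal sketch, assuming that normalizing with the running statistics is what is wanted; `other` is a hypothetical second tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_features = 3
bn = nn.BatchNorm2d(num_features)

# A few training-mode forward passes so the running statistics are not trivial
for _ in range(5):
    bn(torch.randn(8, num_features, 4, 4) * 2 + 1)

other = torch.randn(2, num_features, 4, 4)  # hypothetical tensor to normalize

# Apply bn's running statistics and affine parameters to `other`.
# With training=False the running buffers are read but not updated.
running_mean_before = bn.running_mean.clone()
Z = F.batch_norm(
    other,
    bn.running_mean,
    bn.running_var,
    weight=bn.weight,
    bias=bn.bias,
    training=False,
    eps=bn.eps,
)
```

Passing `training=True` instead would normalize with the batch statistics of `other` and update the running buffers in place, like calling `bn(other)` in training mode.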

If I understand correctly, you want to fold the batch norm parameters into, e.g., a convolution. Then you can do something like this:

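    import torch
    import torch.nn as nn

    # Method of an nn.Conv2d subclass: folds an (eval-mode) BatchNorm2d
    # into this convolution's weight and bias.
    def incorporateBatchNorm(self, bn):

        gamma = bn.weight
        beta = bn.bias
        mean = bn.running_mean
        var = bn.running_var
        eps = bn.eps

        with torch.no_grad():
            var_sqrt = torch.sqrt(var + eps)
            # A convolution created with bias=False behaves like one with a zero bias
            bias = self.bias if self.bias is not None else torch.zeros_like(mean)

            # Scale each output channel's filter by gamma / sqrt(var + eps)
            w = self.weight * (gamma / var_sqrt).reshape(self.out_channels, 1, 1, 1)
            b = (bias - mean) * gamma / var_sqrt + beta

        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(b)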

And then you can just do the forward pass of the convolution layer.
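A standalone sketch of the same folding, with a check that the folded convolution matches `bn(conv(x))` once the BN layer is in eval mode. The function name `fold_bn_into_conv` is hypothetical, and the check assumes the convolution uses the default stride and padding so that the plain `F.conv2d` call matches it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fold_bn_into_conv(conv, bn):
    """Return (weight, bias) so that conv2d(x, weight, bias) == bn(conv(x)) in eval mode."""
    with torch.no_grad():
        # Per-channel scale gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # A convolution without bias behaves like one with a zero bias
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        weight = conv.weight * scale.reshape(-1, 1, 1, 1)
        new_bias = (bias - bn.running_mean) * scale + bn.bias
    return weight, new_bias


torch.manual_seed(0)
conv = nn.Conv2d(3, 6, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(6)
with torch.no_grad():
    # Non-trivial affine parameters (the defaults are gamma=1, beta=0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    # A few training-mode passes to populate the running statistics
    for _ in range(5):
        bn(conv(torch.randn(8, 3, 10, 10)))
bn.eval()

x = torch.randn(2, 3, 10, 10)
w, b = fold_bn_into_conv(conv, bn)
with torch.no_grad():
    expected = bn(conv(x))
    folded = F.conv2d(x, w, b)  # default stride/padding, same as conv
```

`folded` and `expected` should agree up to floating-point error; the equivalence only holds in eval mode, because in training mode BN uses the statistics of each new batch.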

Hi @sami, thank you very much for your answer. I understand what you did above, and even though I don’t want to use it in a convolution (or any other parametric operation), your approach should do the job for me. But it doesn’t…

Let’s say that you have a 4-dimensional tensor X of size (bs, ch, d, d):

bs, ch, d = 2, 3, 4
X = torch.randn(bs, ch, d, d)

Then, Y = bn(X) will forward X through the BN layer, by computing bn.running_mean, bn.running_var, etc…

I would like to reproduce the value of Y by manually applying the BN operation. That should be the following:

Y_ = bn.weight.reshape(1, ch, 1, 1) * (X - bn.running_mean.reshape(1, ch, 1, 1)) / torch.sqrt(bn.running_var.reshape(1, ch, 1, 1) + bn.eps) + bn.bias.reshape(1, ch, 1, 1)

But it doesn’t seem to work. I also tried a BN layer without any affine transformation, but I still cannot reproduce Y.
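A likely reason for the mismatch: in training mode, `nn.BatchNorm2d` normalizes with the mean and biased variance of the current batch, not with `bn.running_mean` and `bn.running_var`; the running buffers are only updated (with momentum, using the unbiased variance) as a side effect. The sketch below reproduces both under the default settings (`momentum=0.1`, `affine=True`); the helper `per_channel` is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, ch, d = 2, 3, 4
X = torch.randn(bs, ch, d, d)
bn = nn.BatchNorm2d(ch)  # a new module is in training mode; momentum defaults to 0.1

mean_before = bn.running_mean.clone()
var_before = bn.running_var.clone()
Y = bn(X)


def per_channel(t):
    # Reshape a (C,) tensor to (1, C, 1, 1) so it broadcasts over (N, C, H, W)
    return t.reshape(1, -1, 1, 1)


# Training mode normalizes with the statistics of the current batch,
# computed per channel over (N, H, W) with the biased variance.
batch_mean = X.mean(dim=(0, 2, 3))
batch_var = X.var(dim=(0, 2, 3), unbiased=False)
X_hat = (X - per_channel(batch_mean)) / torch.sqrt(per_channel(batch_var) + bn.eps)
Y_ = per_channel(bn.weight) * X_hat + per_channel(bn.bias)

# Afterwards the running buffers are updated with momentum; the running
# variance uses the unbiased estimator.
m = bn.momentum
expected_running_mean = (1 - m) * mean_before + m * batch_mean
expected_running_var = (1 - m) * var_before + m * X.var(dim=(0, 2, 3), unbiased=True)
```

After `bn.eval()`, the layer switches to `bn.running_mean` and `bn.running_var`, so the formula with the running statistics only matches `bn(X)` in eval mode.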

Hi @nullgeppetto, I’m facing the same problem as you. Have you found the answer?

Hi @ZacharyGong, the truth is that I abandoned the idea quickly, so I didn’t try to figure it out, although I would still like to sort it out. I may take another look soon and let you know, but I think we are both missing something here (it should be pretty straightforward after all). Maybe @sami wants to take a look too (sorry for the spam!).

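import torch
import torch.nn.functional as F

# `network` and `image` come from the surrounding code; conv1 has bias=False
bn = network.bn1
t = F.conv2d(image, network.conv1.weight)
# Reshape per-channel BN tensors to (1, C, 1, 1) so they broadcast over (N, C, H, W)
w = bn.weight.reshape(1, -1, 1, 1)
b = bn.bias.reshape(1, -1, 1, 1)
m = bn.running_mean.reshape(1, -1, 1, 1)
v = bn.running_var.reshape(1, -1, 1, 1)
# nn.BatchNorm2d adds bn.eps to the variance before taking the square root
t = (t - m) / torch.sqrt(v + bn.eps) * w + b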

The code above gives exactly the same results as

# in __init__:
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, bias=False)
self.bn1 = nn.BatchNorm2d(6)

def forward(self, t):
    t = self.bn1(self.conv1(t))
    return t

with the network in eval mode, i.e. after calling network.eval(). @ZacharyGong @nullgeppetto Hope it helps.