How to absorb batch norm layer weights into Convolution layer weights?

I am trying to implement the Split-Brain Autoencoder in PyTorch. In their implementation, they first pre-train two networks after splitting across the channel dimension, then combine the channels and absorb the BatchNorm layer weights into the convolution layer weights, and finally perform the semantic segmentation task. Paper Reference (the implementation is in the Appendix, Page 9)

I am not able to understand the significance of absorbing BatchNorm, and if it is significant, how to implement it in PyTorch. My initial network is:

class AlexNet_BN(nn.Module):

    def __init__(self, in_channel=3, out_channel=3, layers=[96, 256, 384, 384, 256], out_size=180):
        super(AlexNet_BN, self).__init__()

        self.out_size = out_size

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, layers[0], kernel_size=11, stride=4, padding=2),  # padding 5
            nn.BatchNorm2d(layers[0]),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)  # padding 1

        self.conv2 = nn.Sequential(
            nn.Conv2d(layers[0], layers[1], kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(layers[1]),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)  # padding 1

        self.conv3 = nn.Sequential(
            nn.Conv2d(layers[1], layers[2], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(layers[2]),
            nn.ReLU(inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(layers[2], layers[3], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(layers[3]),
            nn.ReLU(inplace=True)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(layers[3], layers[4], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(layers[4]),
            nn.ReLU(inplace=True)
        )
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1)  # padding 1 and stride 1

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.conv4(x)

        x = self.conv5(x)
        x = self.pool5(x)
        return x

I would appreciate it if someone could help me in this regard.

Thanks


The only sentence about “absorbing” the BatchNorm layer I could find is this one:

We remove LRN layers and add BatchNorm layers after every convolution layer. After pre-training, we remove BatchNorm layers by absorbing the parameters into the preceding conv layers.

I don’t know what “absorbing” means, but maybe they simply remove the BatchNorm layers and fold their parameters into the Conv layers?
In my opinion, this only makes sense if you set bias=False in your Conv layers, which can be done anyway if the BatchNorm layer is used with affine=True.
In this approach, you could scale the conv weights with the batchnorm weights and just add the batchnorm bias. I’m not sure how to deal with the running stats (mean and var).

However, I don’t see any reason to do this. Did you find any more information in the paper?
Also, is there a reference implementation?

Thanks @ptrblck for your response. My guess would be that it’s only for a comparison with the same architecture as FCN-AlexNet, since the default FCN-AlexNet architecture does not have batch norm.
How about freezing the BatchNorm layers while fine-tuning?

Yes, the authors provide an implementation for this on GitHub, but I don’t have much idea about Caffe Blob objects.
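
Regarding freezing, a rough sketch of what I had in mind (it keeps both the running stats and the affine parameters fixed while tuning; the helper name is just illustrative):

import torch.nn as nn

def freeze_batchnorm(model):
    # keep BatchNorm fixed while fine-tuning the rest of the network
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                            # stop updating the running stats
            if m.affine:
                m.weight.requires_grad_(False)  # freeze the scale (gamma)
                m.bias.requires_grad_(False)    # freeze the shift (beta)

Keep in mind that calling model.train() puts the BatchNorm layers back into training mode, so freeze_batchnorm would have to be re-applied after every model.train() call.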

There was a PR in pytorch doing that (running stats can also be absorbed into the previous layer): https://github.com/pytorch/pytorch/pull/901. It can be done once the network is trained and the parameters no longer change. Other than for performance, it does not matter whether batch norm is absorbed or run as a separate layer; the results are the same.
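
Concretely, in eval mode batch norm computes

y = bn.weight * (x - running_mean) / sqrt(running_var + eps) + bn.bias

so folding it into the preceding conv layer (weight W, bias b) gives

W_fused = W * bn.weight / sqrt(running_var + eps)                    # scales each output channel
b_fused = (b - running_mean) * bn.weight / sqrt(running_var + eps) + bn.bias

which absorbs the running stats as well (this matches the fuse code posted below).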


The following example should work. The outputs may differ by a magnitude of around 1e-9, which is caused by floating-point precision and I think is normal.


import torch
import torch.nn as nn

# the fuse function code is adapted from https://zhuanlan.zhihu.com/p/49329030
def fuse(conv, bn):
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    beta = bn.weight   # BN scale (usually called gamma)
    gamma = bn.bias    # BN shift (usually called beta)
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    # scale each output channel of the conv weight and fold the BN shift into the bias
    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * beta + gamma
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           conv.dilation,  # carry over dilation and groups so the
                           conv.groups,    # fused conv matches the original one
                           bias=True)
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv

class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU(inplace=True)
        # fuse() bakes in the running stats as they are at this point
        self.fuse1 = fuse(self.conv1, self.bn1)
        self.fuse2 = fuse(self.conv2, self.bn2)

    def forward(self, x, fusion=False):
        if fusion:
            x = self.fuse1(x)
        else:
            x = self.bn1(self.conv1(x))
        x = self.relu1(x)
        if fusion:
            x = self.fuse2(x)
        else:
            x = self.bn2(self.conv2(x))
        x = self.relu2(x)
        return x

def test_net():
    model = DummyModule()
    model.eval()  # eval mode so BatchNorm uses its running stats
    p = torch.randn([1, 3, 224, 224])
    import time
    s = time.time()
    o_output = model(p)
    print("Original time: ", time.time() - s)
    s = time.time()
    f_output = model(p, True)
    print("Fused time: ", time.time() - s)
    print("Max abs diff: ", (o_output - f_output).abs().max().item())
    # assert(o_output.argmax() == f_output.argmax())
    print("MSE diff: ", nn.MSELoss()(o_output, f_output).item())

test_net()
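
One caveat on the example above: fuse() bakes in whatever running stats the BatchNorm layers hold at call time, so in a real workflow you would fuse after training rather than in __init__. Roughly (reusing DummyModule from above, with the training loop omitted):

model = DummyModule()
# ... train the model so the BatchNorm running stats are meaningful ...
model.eval()
model.fuse1 = fuse(model.conv1, model.bn1)  # re-fuse with the final stats
model.fuse2 = fuse(model.conv2, model.bn2)
output = model(torch.randn([1, 3, 224, 224]), fusion=True)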

Is there any performance gain here? Anything noticeable?

I’ve tried manually merging batchnorm and conv layers. A sample of the code:

def merge_bn_conv(bn_layer, conv_layer):
    # BatchNorm adds eps inside the sqrt, so use bn_layer.eps rather than a fixed constant
    std = (bn_layer.running_var.detach() + bn_layer.eps).sqrt()
    # fold the BN scale into the conv weight, channel by channel
    weight = (bn_layer.weight.detach() * conv_layer.weight.permute(1, 2, 3, 0) / std).permute(3, 0, 1, 2)
    conv_layer.weight = nn.Parameter(weight)
    # assumes conv_layer.bias is None; an existing bias would have to be folded in as well
    bias = bn_layer.bias.detach() - bn_layer.running_mean.detach() * bn_layer.weight.detach() / std
    conv_layer.bias = nn.Parameter(bias)

In modules:

def drop_batchnorm(self):
    merge_bn_conv(self.bn1, self.conv1)
    del self.bn1
    ...
    self.other_module.drop_batchnorm()
    ...
    self.batchnorm_dropped = True

In .forward:

def forward(self, input):
    ...
    out = self.conv1(out)
    if not self.batchnorm_dropped:
        out = self.bn1(out)
    ...

I compared the same model in eval mode before and after calling drop_batchnorm.
Surprisingly, after drop_batchnorm the throughput of the model dropped by around 10%…

I could understand if there were no difference (pytorch doing this optimization for you), but how can this optimization actually make things worse?
It would be great if someone could clarify this (and double-check that the effect exists).
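
For anyone double-checking: GPU timing is easy to get wrong, so it may be worth measuring with a harness along these lines (a sketch; it assumes a CUDA device and a model that takes a fusion flag like DummyModule above):

import time
import torch

def benchmark(model, x, fusion, iters=100, warmup=10):
    with torch.no_grad():
        # warm-up so cudnn autotuning and lazy allocations don't skew the timing
        for _ in range(warmup):
            model(x, fusion)
        torch.cuda.synchronize()  # CUDA kernels run asynchronously
        start = time.time()
        for _ in range(iters):
            model(x, fusion)
        torch.cuda.synchronize()
    return (time.time() - start) / iters

Comparing benchmark(model, x, False) against benchmark(model, x, True) on the actual batch size should show whether the 10% drop is real or a measurement artifact.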