A question about fusing torch.nn.BatchNorm2d

This is a module I wrote myself to fuse torch.nn.BatchNorm2d :face_with_head_bandage:

import torch  # needed for torch.sqrt below
import torch.nn as nn

class FuseBN(nn.Module):
    def __init__(self, layer):
        super().__init__()
        eps = layer.eps
        mean = layer.running_mean
        var = layer.running_var
        weight = layer.weight
        bias = layer.bias

        bias = bias - (weight*mean)/torch.sqrt(var + eps)
        weight = weight / torch.sqrt(var + eps)
        
        self.weight  = weight.reshape(1, -1, 1, 1)
        self.bias = bias.reshape(1, -1, 1, 1)
       
    def forward(self, x):
        out = self.weight * x + self.bias
        return out

When I compare the results of BN and FuseBN:

import torch
import torch.nn as nn

data = torch.randn(1, 3, 224, 224)
bn = nn.BatchNorm2d(3)
fuse_bn = FuseBN(bn)

bn_result = bn(data)
fuse_bn_result = fuse_bn(data)

compare_value = torch.max(torch.abs(bn_result - fuse_bn_result))

In theory, the difference should be on the order of 1e-5, but compare_value comes out as 0.143. :thinking:

I don't know why. Please help me, thanks a lot. :handshake:

See BatchNorm2d.
When you forward data through the batch norm layer in training mode (the default), you update its running averages.
Use bn.eval() to freeze these buffers.
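
To see the buffer update directly, here is a minimal sketch. It assumes the default `momentum=0.1` and `track_running_stats=True`; the variable names are only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
data = torch.randn(1, 3, 224, 224) * 10 + 5
bn = nn.BatchNorm2d(3)

# A fresh layer starts with running_mean = 0 and running_var = 1.
mean_at_init = bn.running_mean.clone()

# Training mode (the default): the forward pass updates the buffers.
bn(data)
mean_after_train_forward = bn.running_mean.clone()

# Eval mode: the forward pass reads the buffers but leaves them untouched.
bn.eval()
bn(data)
mean_after_eval_forward = bn.running_mean.clone()

print(mean_at_init)
print(mean_after_train_forward)
print(mean_after_eval_forward)
```

The second and third printed tensors should match each other and differ from the first, which is the behavior bn.eval() is meant to give you.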

You are comparing the native batchnorm layer in training mode with your FuseBN layer, which uses the eval logic.
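
For reference, the eval logic that FuseBN reproduces can be written out per channel. This is a sketch of the algebra only, with symbols chosen here for illustration: $\gamma$ is `weight`, $\beta$ is `bias`, $\mu$ is `running_mean`, $\sigma^2$ is `running_var`, and $\epsilon$ is `eps`.

```latex
\[
y = \gamma \, \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
  = \underbrace{\frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}}_{\text{fused weight}} \, x
  + \underbrace{\left( \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \epsilon}} \right)}_{\text{fused bias}}
\]
```

In training mode the layer normalizes with the statistics of the current batch instead of $\mu$ and $\sigma^2$, so the two outputs only agree once the layer is in eval mode.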
Also, right after initialization the batchnorm layer's running_mean is all zeros and its running_var is all ones, so you might want to train it for a few steps first; that way both layers normalize the data with non-trivial running stats.
This should work:

data = torch.randn(1, 3, 224, 224) * 10 + 5
bn = nn.BatchNorm2d(3)

for _ in range(100):
    out = bn(data)

print(bn.running_mean)
print(bn.running_var)
bn.eval()

fuse_bn = FuseBN(bn)
bn_result = bn(data)
fuse_bn_result = fuse_bn(data)

compare_value = torch.max(torch.abs(bn_result - fuse_bn_result))
print(compare_value)
> tensor(4.7684e-07, grad_fn=<MaxBackward1>)
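
As a possible variation (not something from the posts above), the fused scale and shift can be stored as buffers computed under `torch.no_grad()`, so they carry no `grad_fn` and follow the module through `.to(device)` and `state_dict()`. The class name `FusedBNBuffers` and the attribute names `scale` and `shift` are hypothetical, and the sketch assumes the source layer was built with `affine=True` and `track_running_stats=True`.

```python
import torch
import torch.nn as nn


class FusedBNBuffers(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, layer: nn.BatchNorm2d):
        super().__init__()
        # Fold the running stats into one scale and one shift per channel.
        with torch.no_grad():
            inv_std = torch.rsqrt(layer.running_var + layer.eps)
            scale = layer.weight * inv_std
            shift = layer.bias - layer.running_mean * scale
        # Buffers are saved in state_dict and moved by .to(), but not trained.
        self.register_buffer("scale", scale.reshape(1, -1, 1, 1))
        self.register_buffer("shift", shift.reshape(1, -1, 1, 1))

    def forward(self, x):
        return self.scale * x + self.shift


torch.manual_seed(0)
data = torch.randn(1, 3, 224, 224) * 10 + 5
bn = nn.BatchNorm2d(3)
for _ in range(100):  # let the running stats converge toward the data stats
    bn(data)
bn.eval()

fused = FusedBNBuffers(bn)
with torch.no_grad():
    max_diff = (bn(data) - fused(data)).abs().max().item()
print(max_diff)
```

The printed difference should be down at float32 rounding level, the same as in the comparison above. Note that the fused values are a snapshot: if the original layer is trained further, the fused module has to be rebuilt.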

PS: @mMagmer was a bit faster, but just posting it for the sake of completeness with the code.

Awesome. Thanks @ptrblck @mMagmer