This a function about fuse torch.nn.BatchNorm2d by myself
import torch.nn as nn
class FuseBN(nn.Module):
def __init__(self, layer):
super().__init__()
eps = layer.eps
mean = layer.running_mean
var = layer.running_var
weight = layer.weight
bias = layer.bias
bias = bias - (weight*mean)/torch.sqrt(var + eps)
weight = weight / torch.sqrt(var + eps)
self.weight = weight.reshape(1, -1, 1, 1)
self.bias = bias.reshape(1, -1, 1, 1)
def forward(self, x):
out = self.weight * x + self.bias
return out
When I try to compare the results of BN and FuseBN
import torch
import torch.nn as nn
data = torch.randn(1, 3, 224, 224)
bn = nn.BatchNorm2d(3)
fuse_bn = FuseBN(bn)
bn_result = bn(data)
fuse_bn_result = fuse_bn(data)
compare_value = torch.max(torch.abs(bn_result - fuse_bn_result))
In theory, the difference should be about 1e-5, but I get the compare_value
is 0.143.
I don’t know, why? please help me, thanks a lot.