I am trying to reproduce the forward pass of `nn.BatchNorm2d` in training mode as follows:

```
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
data = 2.2 + torch.randn(32, 16, 10, 10).float()
out_bn = bn(data)
my_out_bn = (data - data.mean([0, 2, 3], keepdim=True)) / torch.sqrt(data.var([0, 2, 3], keepdim=True) + bn.eps)
my_out_bn = bn.weight.reshape(1, -1, 1, 1) * my_out_bn + bn.bias.reshape(1, -1, 1, 1)
print((out_bn - my_out_bn).abs().max())
```

I find that the maximum absolute error can go as high as 7e-4, and it does not seem to shrink when I switch to double precision. Is there anything wrong in my own implementation? Thanks!
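One thing worth checking (this is my guess at the source of the gap, not a confirmed answer): `Tensor.var` defaults to the unbiased estimator (dividing by N-1), while `BatchNorm2d` normalizes with the biased batch variance (dividing by N). A sketch of the same comparison with `unbiased=False`, which in my understanding should close the gap to float32 rounding:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(16)
data = 2.2 + torch.randn(32, 16, 10, 10)

out_bn = bn(data)  # training mode by default

mean = data.mean([0, 2, 3], keepdim=True)
# Biased (population) variance, dividing by N rather than N-1,
# which is what BatchNorm uses when normalizing the batch.
var = data.var([0, 2, 3], keepdim=True, unbiased=False)

my_out_bn = (data - mean) / torch.sqrt(var + bn.eps)
my_out_bn = bn.weight.reshape(1, -1, 1, 1) * my_out_bn + bn.bias.reshape(1, -1, 1, 1)

err = (out_bn - my_out_bn).abs().max().item()
print(err)  # should be far smaller than the 7e-4 seen with the unbiased default
```

Note that the unbiased estimator is still used for updating `bn.running_var`, so the two conventions coexist inside the layer.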