I am trying to reproduce the forward pass of
nn.BatchNorm2d in training mode as follows:
```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
data = 2.2 + torch.randn(32, 16, 10, 10).float()

out_bn = bn(data)

# Manual batch norm over the (N, H, W) dimensions, per channel
my_out_bn = (data - data.mean([0, 2, 3], keepdim=True)) / torch.sqrt(
    data.var([0, 2, 3], keepdim=True) + bn.eps
)
my_out_bn = bn.weight.reshape(1, -1, 1, 1) * my_out_bn + bn.bias.reshape(1, -1, 1, 1)

print((out_bn - my_out_bn).abs().max())
```
I find that the maximum absolute error can go as high as 7e-4, and it does not seem to shrink with double precision. Is there anything wrong in my implementation? Thanks!
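For reference, here is a minimal sketch of how I checked double precision (same comparison as above, just run in float64; the seed is arbitrary):

```python
import torch
import torch.nn as nn

# Repeat the comparison in float64 to rule out single-precision
# rounding as the source of the discrepancy.
torch.manual_seed(0)
bn = nn.BatchNorm2d(16).double()
data = 2.2 + torch.randn(32, 16, 10, 10, dtype=torch.float64)

out_bn = bn(data)

my_out = (data - data.mean([0, 2, 3], keepdim=True)) / torch.sqrt(
    data.var([0, 2, 3], keepdim=True) + bn.eps
)
my_out = bn.weight.reshape(1, -1, 1, 1) * my_out + bn.bias.reshape(1, -1, 1, 1)

# The gap is still on the order of 1e-4 here, so it does not look
# like a floating-point precision issue.
print((out_bn - my_out).abs().max())
```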