Batch Normalization disambiguation

No, in this case you would use L channels.

However, both approaches (nn.BatchNorm2d vs. nn.BatchNorm1d) should yield the same result. Since you are using the affine batchnorm transformation, you would have to make sure the weight parameters are set to equal values (the bias should be all zeros in both cases anyway).

import torch
import torch.nn as nn

N, C, H, W = 10, 3, 24, 24
x = torch.randn(N, C, H, W)

bn2d = nn.BatchNorm2d(3)
bn1d = nn.BatchNorm1d(3)

# Share the affine parameters so both layers apply the same scale and shift
with torch.no_grad():
    bn2d.weight = bn1d.weight
    bn2d.bias = bn1d.bias

# nn.BatchNorm1d expects (N, C, L), so flatten the spatial dimensions
output2d = bn2d(x)
output1d = bn1d(x.view(N, C, -1))
print((output2d.view(N, C, -1) == output1d).all())
> tensor(1, dtype=torch.uint8)

Alternatively, you could set affine=False, which lets you skip the parameter assignment entirely.
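As a minimal sketch (reusing x, N, and C from the snippet above), the affine=False variant would look like this:

# Without affine parameters there is nothing to synchronize between the layers
bn2d = nn.BatchNorm2d(3, affine=False)
bn1d = nn.BatchNorm1d(3, affine=False)

output2d = bn2d(x)
output1d = bn1d(x.view(N, C, -1))
print((output2d.view(N, C, -1) == output1d).all())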

For completeness: this PR should change the initialization of the affine parameters so that weight is initialized with ones.
