BatchNorm 4D (or N-dim)

Does BatchNorm care about input dimensionality? The 1d/2d/3d variants seem to differ only in the input-rank assertion; everything else is the same. Is the following valid for the N-dim (or at least 4D) case?

import torch.nn as nn

class BatchNormNd(nn.modules.batchnorm._BatchNorm):
    # No-op rank check: accept input of any dimensionality.
    def _check_input_dim(self, input):
        pass
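
A quick sanity check (a sketch using the class above) that the no-op _check_input_dim lets inputs of different ranks through the same module:

>>> import torch
>>> bn = BatchNormNd(65)
>>> bn(torch.randn(3, 65, 9)).shape        # 3D input is accepted
torch.Size([3, 65, 9])
>>> bn(torch.randn(3, 65, 9, 9)).shape     # 4D input is accepted too
torch.Size([3, 65, 9, 9])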

It should be valid, since the underlying functionality generalizes to higher dimensions: normalization is per channel (dim 1), and the data layout is unchanged no matter how many trailing dims there are; e.g.,

>>> import torch
>>> a = torch.randn(3, 65, 9, 9)
>>> bn2 = torch.nn.BatchNorm2d(65)
>>> bn1 = torch.nn.BatchNorm1d(65)
>>> out2 = bn2(a)
>>> out1 = bn1(a.reshape(3, 65, 81)).reshape(3, 65, 9, 9)
>>> torch.allclose(out2, out1)
True
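
The per-channel statistics can also be reproduced by hand, which makes it clear the normalization only depends on dim 1. A sketch, relying on bn2 being freshly constructed in training mode with the default affine init (weight=1, bias=0):

>>> mean = a.mean(dim=(0, 2, 3), keepdim=True)
>>> var = a.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
>>> manual = (a - mean) / torch.sqrt(var + bn2.eps)  # normalize per channel
>>> torch.allclose(out2, manual, atol=1e-6)
True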

I verified this much, but then I wonder: what is the point of having 1d, 2d, and 3d variants instead of a single BatchNormNd? Maybe an API thing, but as long as the forward and backward passes are the same within a network, I'm fine with it.
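
One concrete benefit of the dimension-specific classes is early shape checking: the rank assertion catches accidentally mis-shaped inputs at the module boundary, e.g. (exact error text may vary across versions):

>>> torch.nn.BatchNorm1d(65)(torch.randn(3, 65, 9, 9))
ValueError: expected 2D or 3D input (got 4D input)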


For one, there’s

RuntimeError: Expected 2 to 5 dimensions, but got 6-dimensional tensor
for argument #1 'input' (while checking arguments for cudnn_batch_norm)

which is fixable with a view (sketch below), but I'm unsure what motivates the restriction.
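
A sketch of that view workaround for a 6D input (hypothetical shapes), folding the trailing dims so the tensor falls within the supported 2-to-5-D range; folding spatial dims does not change the per-channel statistics:

>>> x = torch.randn(2, 65, 4, 4, 4, 4)
>>> bn = BatchNormNd(65)
>>> out = bn(x.view(2, 65, 4, 4, 16)).view(x.shape)  # fold the last two dims
>>> out.shape
torch.Size([2, 65, 4, 4, 4, 4])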