Why limit batchnorm to 3d?

In the native source code for batch norm (lines 110 and 283), the input is checked for dimensions between 2 and 5, even though each `BatchNorm*d` class already enforces its own dimension check in the Python source. This isn't a major issue, but it makes it impossible to create an N-dimensional batch norm, which would be useful in some cases. If there isn't a reason for the native check (does something break without it?), it could easily be removed, letting Python validate the dimensions as it already does. That would allow, for example, a 4D (or nD) batchnorm to be implemented easily:

class BatchNorm4d(torch.nn.modules.batchnorm._BatchNorm):
    def __init__(self, num_features, **kwargs):
        super().__init__(num_features=num_features, **kwargs)

    def _check_input_dim(self, input):
        # 6D = (batch, channels, d1, d2, d3, d4)
        if input.dim() != 6:
            raise ValueError("expected 6D input (got {}D input)".format(input.dim()))

This is how the 1d-3d batchnorm variants are already implemented on the Python side.
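To illustrate why the Python-side check alone is sufficient, here is a minimal torch-free sketch of the same pattern: the base class delegates dimension validation to a single overridable method, so adding a new dimensionality never needs changes outside the subclass. The `FakeTensor` and `_BatchNormBase` names are stand-ins invented for this sketch, not PyTorch APIs:

    # Stand-in for a tensor; in PyTorch, input.dim() plays this role.
    class FakeTensor:
        def __init__(self, *shape):
            self.shape = shape

        def dim(self):
            return len(self.shape)

    class _BatchNormBase:
        # Base class delegates dimension checking to each subclass,
        # mirroring how _BatchNorm calls self._check_input_dim(input).
        def _check_input_dim(self, input):
            raise NotImplementedError

        def forward(self, input):
            self._check_input_dim(input)
            return input  # normalization itself omitted from the sketch

    class BatchNorm4d(_BatchNormBase):
        def _check_input_dim(self, input):
            # 6D = (batch, channels, d1, d2, d3, d4)
            if input.dim() != 6:
                raise ValueError(f"expected 6D input (got {input.dim()}D input)")

With this pattern, `BatchNorm4d().forward(FakeTensor(2, 3, 4, 5, 6, 7))` passes the check, while a 4D input raises `ValueError` — exactly the validation the native 2-to-5 dimension check duplicates.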