Higher-Dimensional Batch Norm

Hi,

I have a batch of k-dimensional data, meaning I have tensors of size (batch_size, n1, n2, ..., nk).
I would like to normalize the data after the batch axis with a batch norm module. That is, I would like to estimate the k-dimensional mean of my data (one value per position in (n1, ..., nk)), subtract it from my tensors along the batch axis, estimate the corresponding standard deviation (the centered L2 norm divided by the square root of the batch size), and divide by it, just like batch norm does.
It would be nice not to have to flatten everything after the batch axis and reshape it back every time.
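For reference, here is a minimal sketch of the computation described above, assuming PyTorch. ElementwiseBatchNorm is a hypothetical name, not a library class: it reduces over the batch axis only, so the mean, variance, weight and bias all have shape (n1, ..., nk) and no flattening is needed. The momentum convention and the unbiased running variance are chosen to mirror nn.BatchNorm1d; treat it as an illustration rather than a drop-in replacement.

```python
import torch
import torch.nn as nn


class ElementwiseBatchNorm(nn.Module):
    """Hypothetical module: batch norm with one mean/variance per element of
    the (n1, ..., nk) shape, computed over the batch axis (dim 0) only."""

    def __init__(self, shape, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable per-element scale and shift, like BatchNorm's affine params.
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        # Running estimates, used instead of batch statistics in eval mode.
        self.register_buffer("running_mean", torch.zeros(shape))
        self.register_buffer("running_var", torch.ones(shape))

    def forward(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)  # biased variance for normalizing
            with torch.no_grad():
                n = x.shape[0]
                # BatchNorm stores the unbiased variance in its running estimate.
                unbiased_var = var * n / max(n - 1, 1)
                # lerp_: running = (1 - momentum) * running + momentum * new
                self.running_mean.lerp_(mean, self.momentum)
                self.running_var.lerp_(unbiased_var, self.momentum)
        else:
            mean, var = self.running_mean, self.running_var
        # mean/var have shape (n1, ..., nk) and broadcast over the batch axis.
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return x_hat * self.weight + self.bias
```

Usage would be something like bn = ElementwiseBatchNorm((n1, n2)) followed by bn(x) on a (batch_size, n1, n2) tensor; in training mode it should agree with nn.BatchNorm1d(n1 * n2) applied to the flattened input.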

Is there a function in PyTorch's functional API that implements this, so that I can write an appropriate batch norm module myself?
It would be nice to have a single batch norm module that takes the axis (or axes) along which to operate, instead of the current zoo of batch norms…
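On the single-module-with-axes idea, here is a minimal functional sketch, assuming PyTorch. batch_norm_over is a hypothetical helper, not an existing PyTorch function: it normalizes with statistics reduced over whichever axes you pass. It only covers training-mode (batch) statistics; running estimates would need buffers as in the nn.BatchNorm* modules.

```python
import torch


def batch_norm_over(x, axes=(0,), weight=None, bias=None, eps=1e-5):
    """Hypothetical helper: normalize x with the mean/variance computed over
    `axes` (batch statistics only, no running estimates)."""
    axes = tuple(a % x.dim() for a in axes)  # allow negative axes
    # keepdim=True keeps the reduced axes as size 1 so the stats broadcast.
    mean = x.mean(dim=axes, keepdim=True)
    var = x.var(dim=axes, unbiased=False, keepdim=True)
    y = (x - mean) / torch.sqrt(var + eps)
    # Optional affine parameters must be broadcastable to x's shape.
    if weight is not None:
        y = y * weight
    if bias is not None:
        y = y + bias
    return y
```

With axes=(0,) this gives the per-position normalization asked about here; with axes=(0, 2, 3) on an (N, C, H, W) input it should match nn.BatchNorm2d(C, affine=False) in training mode.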

Thanks!

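Any chance SyncBatchNorm could help you? Note that it expects tensors of size (batch_size, channels, n1, n2, ...), so it has an extra channels dimension compared to your example. You could unsqueeze a length-1 dimension in there to make the shapes fit, but keep in mind that batch norm keeps one mean and variance per channel, so with a single channel the statistics would be shared across all of n1, ..., nk rather than computed per position.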
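To see what the channel dimension does to the statistics, here is a small sketch, assuming PyTorch. Per its documentation, SyncBatchNorm computes statistics per channel just like the regular BatchNorm layers, so this uses nn.BatchNorm2d and nn.BatchNorm1d as stand-ins that run without a distributed process group. It compares a length-1 channel dimension with treating every position as its own channel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3, 4)  # (batch_size, n1, n2), i.e. k = 2

# Unsqueezing a length-1 channel dim: one mean/var shared by all 3*4 positions.
shared = nn.BatchNorm2d(num_features=1)
y_shared = shared(x.unsqueeze(1)).squeeze(1)
print(shared.running_mean.shape)  # torch.Size([1])

# Treating every position as its own channel: one mean/var per position.
per_pos = nn.BatchNorm1d(num_features=3 * 4)
y_per_pos = per_pos(x.reshape(x.shape[0], -1)).reshape_as(x)
print(per_pos.running_mean.shape)  # torch.Size([12])
```

Only the second variant centers each position separately, which is what the per-position mean in the question requires; the cost is the flatten/reshape around the layer.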

Though I’ve never used it, its description seems promising.


Thanks, this looks promising!