Hi,
I have a batch of k-dimensional data, meaning I have tensors of shape (batch_size, n1, n2, ..., nk).
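I would like to normalize the data over the batch axis with a batch norm module, keeping separate statistics for every position after it. That is, I would like to estimate the k-dimensional mean of my data (one value per position, shape (n1, ..., nk)) and subtract it from my tensors along the batch axis, and likewise estimate the centered L2-norm (standard deviation) of my data and divide my tensors by it, just like a batch norm does.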
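Concretely, in training mode the normalization I mean is the following. This is a minimal sketch only: it leaves out running statistics and the learnable scale and shift, and `batch_normalize` is just an illustrative name.

```python
import torch


def batch_normalize(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize x of shape (batch_size, n1, ..., nk) using per-position
    statistics computed over the batch axis only (no running stats, no affine)."""
    # Mean over the batch axis: one value per position, shape (1, n1, ..., nk).
    mean = x.mean(dim=0, keepdim=True)
    # Biased variance over the batch axis, as batch norm uses when training.
    var = x.var(dim=0, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)
```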
It would be nice not to have to flatten my data after the batch axis and reshape it back every time.
Is there a PyTorch function (e.g. in torch.nn.functional) that implements this, so that I can build an appropriate batch norm module myself?
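For reference, torch.nn.functional.batch_norm accepts an input of shape (N, C) and keeps one set of statistics per column C, so the flatten-and-reshape workaround can at least be wrapped once. A sketch, assuming the caller owns the running-statistics tensors; `per_position_batch_norm` is a hypothetical helper name, not part of PyTorch:

```python
import torch
import torch.nn.functional as F


def per_position_batch_norm(x, running_mean, running_var, weight=None, bias=None,
                            training=True, momentum=0.1, eps=1e-5):
    """Batch-normalize x of shape (batch_size, n1, ..., nk) with one set of
    statistics per position (n1, ..., nk), computed over the batch axis only.

    running_mean, running_var, weight and bias (if given) must be 1-D tensors
    with n1 * ... * nk elements. In training mode, F.batch_norm updates
    running_mean and running_var in place.
    """
    # Treat every position after the batch axis as its own "channel": (N, C).
    flat = x.reshape(x.shape[0], -1)
    out = F.batch_norm(flat, running_mean, running_var, weight, bias,
                       training, momentum, eps)
    # Restore the original k-dimensional layout.
    return out.reshape(x.shape)
```

With training=False the stored running statistics are used instead of the batch statistics.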
It would be nice to have a single batch norm module that takes the axis (or axes) to operate along, instead of the current zoo of batch norms (BatchNorm1d, BatchNorm2d, BatchNorm3d, ...).
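Roughly the interface I have in mind, as a sketch only: `AxisBatchNorm`, `feature_shape` and `reduce_dims` are made-up names, not an existing PyTorch API. It moves the reduced axes to the front, flattens to (R, C) and delegates to torch.nn.functional.batch_norm, so the running-statistics update follows that function's conventions. With reduce_dims=(0, 2, 3) it should behave like BatchNorm2d on (N, C, H, W) input, and with the default reduce_dims=(0,) it gives the per-position normalization described above.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AxisBatchNorm(nn.Module):
    """Hypothetical batch norm parameterized by the axes to reduce over.

    Statistics are computed over `reduce_dims` (normally including 0, the
    batch axis); every remaining axis is a feature axis with its own mean,
    variance, scale and shift. `feature_shape` is the shape of those
    remaining axes, in their original order.
    """

    def __init__(self, feature_shape, reduce_dims=(0,), eps=1e-5, momentum=0.1):
        super().__init__()
        self.feature_shape = tuple(feature_shape)
        self.reduce_dims = tuple(reduce_dims)
        self.eps = eps
        self.momentum = momentum
        n = math.prod(self.feature_shape)
        # Learnable affine parameters and running statistics, one per feature.
        self.weight = nn.Parameter(torch.ones(n))
        self.bias = nn.Parameter(torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        ndim = x.dim()
        reduce = sorted({d % ndim for d in self.reduce_dims})
        keep = [d for d in range(ndim) if d not in reduce]
        # Move the reduced axes to the front and the feature axes to the back.
        perm = reduce + keep
        xp = x.permute(perm)
        if tuple(xp.shape[len(reduce):]) != self.feature_shape:
            raise ValueError(f"expected feature shape {self.feature_shape}, "
                             f"got {tuple(xp.shape[len(reduce):])}")
        # (R, C): R = product of reduced axes, C = product of feature axes.
        flat = xp.reshape(-1, self.running_mean.numel())
        out = F.batch_norm(flat, self.running_mean, self.running_var,
                           self.weight, self.bias, self.training,
                           self.momentum, self.eps)
        # Undo the flattening, then invert the permutation.
        out = out.reshape(xp.shape)
        inverse = [0] * ndim
        for i, d in enumerate(perm):
            inverse[d] = i
        return out.permute(inverse)
```

Fixing feature_shape at construction keeps the parameter and buffer sizes known up front; a lazily initialized variant could instead infer it from the first input.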
Thanks!