Any information on torch.nn.functional.batch_norm/torch.batch_norm


Could anyone point me towards more information about the torch.nn.functional.batch_norm function?

When I checked the Python source code, I was only able to trace it as far as this call

return F.batch_norm(
    input, self.running_mean, self.running_var, self.weight, self.bias,
    self.training or not self.track_running_stats,
    exponential_average_factor, self.eps)

in the torch/nn/modules/batchnorm module of the PyTorch source code.
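For reference, the functional call can also be exercised directly. This is a minimal sketch assuming a recent PyTorch release and the documented signature F.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05); in the module code above, exponential_average_factor is passed in the momentum position. With training=True the function normalizes with the current batch statistics and updates the running buffers in place. The buffer values and variable names here are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_features = 3
x = torch.randn(8, num_features)  # [batch_size, features]

# Running buffers a BatchNorm module would normally own (illustrative values).
running_mean = torch.zeros(num_features)
running_var = torch.ones(num_features)
momentum = 0.1  # plays the role of exponential_average_factor

# training=True is what the module passes when
# `self.training or not self.track_running_stats` is True:
# normalize with batch statistics and update the buffers in place.
out = F.batch_norm(x, running_mean, running_var, weight=None, bias=None,
                   training=True, momentum=momentum, eps=1e-5)

# Running mean: exponential moving average of the per-feature batch mean.
expected_mean = (1 - momentum) * 0.0 + momentum * x.mean(dim=0)
# Running variance: same update, but with the unbiased batch variance.
expected_var = (1 - momentum) * 1.0 + momentum * x.var(dim=0)  # unbiased by default

print(torch.allclose(running_mean, expected_mean, atol=1e-5))  # True
print(torch.allclose(running_var, expected_var, atol=1e-5))  # True
```

Note that the normalization itself uses the biased batch variance; only the stored running variance uses the unbiased estimate.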

My question is: over which dimensions does functional.batch_norm compute the batch statistics?
Judging from the code, I would suspect the second dimension.
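One way to test this suspicion empirically is to compare F.batch_norm against a hand-written version. This is a minimal sketch assuming a recent PyTorch release; manual_batch_norm is a hypothetical helper, not part of PyTorch. It keeps separate statistics for each index of dim 1, pools them over every other dimension, and normalizes with the biased variance.

```python
import torch
import torch.nn.functional as F

def manual_batch_norm(x, eps=1e-5):
    """Hypothetical reference: normalize x per channel (dim 1) using
    statistics pooled over every other dimension."""
    # e.g. (0,) for [N, C], (0, 2) for [N, C, L], (0, 2, 3) for [N, C, H, W]
    dims = tuple(d for d in range(x.dim()) if d != 1)
    mean = x.mean(dim=dims, keepdim=True)
    # Normalization uses the biased (population) variance.
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
for shape in [(8, 3), (8, 3, 5), (8, 3, 4, 4), (8, 3, 2, 4, 4)]:
    x = torch.randn(*shape)
    # training=True without running buffers: batch statistics only.
    ref = F.batch_norm(x, None, None, training=True)
    print(shape, torch.allclose(ref, manual_batch_norm(x), atol=1e-5))
```

If every shape prints True, dim 1 is the dimension that gets its own statistics, while the batch dimension and any spatial dimensions are the ones reduced over.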
The same function, functional.batch_norm, is used for BatchNorm1d, BatchNorm2d and BatchNorm3d.
So for BatchNorm1d with [batch_size, features] inputs, the second (feature) dimension is evidently the relevant one, and for BatchNorm2d with [batch_size, channels, height, width] inputs, it is again the second (channel) dimension.
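The module classes can be checked the same way. This is a minimal sketch, assuming default constructor arguments (affine=True, track_running_stats=True, so weight starts at 1 and bias at 0) in a recent PyTorch release. Each module keeps one running statistic per index of dim 1, and in training mode the output has roughly zero mean per channel over the remaining dimensions. The [batch_size, channels, length] input for BatchNorm1d is included for comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 4  # number of features / channels (dim 1)

cases = [
    (nn.BatchNorm1d(C), torch.randn(16, C)),           # [batch_size, features]
    (nn.BatchNorm1d(C), torch.randn(16, C, 10)),       # [batch_size, channels, length]
    (nn.BatchNorm2d(C), torch.randn(16, C, 6, 6)),     # [batch_size, channels, height, width]
    (nn.BatchNorm3d(C), torch.randn(16, C, 2, 6, 6)),  # [batch_size, channels, depth, height, width]
]

for bn, x in cases:
    bn.train()  # training mode: normalize with the current batch statistics
    y = bn(x)
    # Every dimension except dim 1 is pooled into the statistics.
    reduced = tuple(d for d in range(x.dim()) if d != 1)
    per_channel_mean = y.mean(dim=reduced)  # shape (C,)
    print(type(bn).__name__, tuple(x.shape),
          "running_mean:", tuple(bn.running_mean.shape),
          "max |mean|:", per_channel_mean.abs().max().item())
```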

Thanks in advance.