How to pass gamma and beta to batch_norm in torch?


I am trying to convert some code that involves conditional batch normalization from TensorFlow to PyTorch. In TF, you can call tf.nn.batch_normalization(), which accepts the input, mean, variance, scale, and shift (gamma and beta). However, in PyTorch, torch.nn.functional.batch_norm does not seem to have any parameters for mean, variance, gamma, and beta. Or have I got this wrong?

Here is the parameter list:

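return F.batch_norm(
    input,
    # If buffers are not to be tracked, ensure that they won't be updated
    self.running_mean if not self.training or self.track_running_stats else None,
    self.running_var if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor, self.eps,
)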

BTW, I don’t want to use BatchNorm2d because it does not have an option to pass the above-mentioned values.

The mean and var will be calculated from the input activation (when training=True), so I don’t know why you would want to calculate them beforehand.
Every other parameter is accepted as an argument: weight and bias correspond to gamma and beta, respectively.
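
To make the mapping concrete, here is a minimal sketch (not taken from the original code) that passes gamma and beta to F.batch_norm as weight and bias and checks the result against the formula written out by hand. The tensor shapes and the helper name conditional_batch_norm are assumptions for illustration. Since weight and bias must have shape (C,), the helper covers the conditional case where each sample has its own gamma and beta, by normalizing without an affine term and applying the scale and shift afterwards.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(8, 3, 4, 4)      # assumed layout: (N, C, H, W)
gamma = torch.rand(3) + 0.5      # per-channel scale, passed as `weight`
beta = torch.randn(3)            # per-channel shift, passed as `bias`
eps = 1e-5

# training=True: mean and variance are computed from x itself,
# so the running statistics can be None.
out = F.batch_norm(x, None, None, weight=gamma, bias=beta, training=True, eps=eps)

# Hand-written reference using the biased batch variance.
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
ref = (x - mean) / torch.sqrt(var + eps)
ref = ref * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)

matches = torch.allclose(out, ref, atol=1e-4)


def conditional_batch_norm(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: gamma and beta have shape (N, C), one pair per sample."""
    # Normalize without an affine term, then apply the per-sample scale and shift.
    normalized = F.batch_norm(x, None, None, training=True, eps=eps)
    return normalized * gamma[:, :, None, None] + beta[:, :, None, None]


# Example: a different gamma/beta for every sample in the batch.
cond_gamma = torch.rand(8, 3) + 0.5
cond_beta = torch.randn(8, 3)
cond_out = conditional_batch_norm(x, cond_gamma, cond_beta)
```

If gamma and beta are the same for the whole batch, the first call is enough; the helper is only needed when they vary per sample.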

I see, I thought I should calculate them myself. Thank you for your reply @ptrblck!
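
For the case where precomputed statistics are needed after all, as with tf.nn.batch_normalization, a rough sketch follows. It relies on the fact that F.batch_norm with training=False normalizes with the running_mean and running_var it is given instead of computing them from the batch. The wrapper name batch_norm_with_stats and the tensor shapes are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F


def batch_norm_with_stats(x, mean, var, gamma, beta, eps=1e-5):
    """Hypothetical wrapper: normalize x with explicitly supplied statistics."""
    # training=False: the given mean/var are used as-is and are not updated.
    return F.batch_norm(x, mean, var, weight=gamma, bias=beta, training=False, eps=eps)


torch.manual_seed(0)

x = torch.randn(8, 3, 4, 4)                   # assumed layout: (N, C, H, W)
mean = x.mean(dim=(0, 2, 3))                  # shape (C,)
var = x.var(dim=(0, 2, 3), unbiased=False)    # shape (C,)
gamma = torch.full((3,), 2.0)                 # scale
beta = torch.full((3,), 0.5)                  # shift

out = batch_norm_with_stats(x, mean, var, gamma, beta)

# Same computation written out by hand.
ref = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + 1e-5)
ref = ref * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)

stats_match = torch.allclose(out, ref, atol=1e-4)
```

Note that tf.nn.batch_normalization takes offset (beta) before scale (gamma), while F.batch_norm takes weight (gamma) before bias (beta), so the argument order needs care when porting.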