I am trying to convert some code that involves conditional batch normalization from TensorFlow to PyTorch. In TF, you can call tf.nn.batch_normalization(), which accepts the input, mean, variance, scale, and offset (gamma and beta). However, when I call torch.nn.functional.batch_norm in PyTorch, it is not clear to me whether I can pass in my own mean, variance, gamma, and beta. Or have I got this wrong?
Here is the call as it appears inside the module's forward method:
return F.batch_norm(
    input,
    # If buffers are not to be tracked, ensure that they won't be updated
    self.running_mean if not self.training or self.track_running_stats else None,
    self.running_var if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor,
    self.eps,
)
BTW, I don't want to use nn.BatchNorm2d, because it does not let me pass in the above-mentioned values myself.
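For reference, here is a minimal sketch of what I am trying to achieve, assuming that F.batch_norm's running_mean, running_var, weight, and bias arguments play the roles of TF's mean, variance, scale, and offset (the tensor shapes and values below are just made up for illustration):

```python
import torch
import torch.nn.functional as F

# Conditional batch norm: gamma and beta come from outside
# (e.g. produced by a conditioning network), not from a module's own parameters.
x = torch.randn(4, 8, 16, 16)   # (N, C, H, W)
mean = torch.zeros(8)           # per-channel mean
var = torch.ones(8)             # per-channel variance
gamma = torch.rand(8)           # scale (TF: scale / PyTorch: weight)
beta = torch.rand(8)            # shift (TF: offset / PyTorch: bias)

# training=False so the supplied statistics are used as-is
# instead of being recomputed from the batch.
out = F.batch_norm(x, mean, var, weight=gamma, bias=beta,
                   training=False, momentum=0.1, eps=1e-5)

# Manual computation matching tf.nn.batch_normalization:
#   gamma * (x - mean) / sqrt(var + eps) + beta
ref = (gamma.view(1, -1, 1, 1) * (x - mean.view(1, -1, 1, 1))
       / torch.sqrt(var.view(1, -1, 1, 1) + 1e-5)
       + beta.view(1, -1, 1, 1))
print(torch.allclose(out, ref, atol=1e-5))
```

If the two agree, the functional form would give the same behavior as the TF call without needing a BatchNorm2d module.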