Hello,
I am trying to convert some code that uses conditional batch normalization from TensorFlow to PyTorch. In TF, you can call `tf.nn.batch_normalization()`, which accepts the input, mean, variance, scale, and shift (gamma and beta). However, in PyTorch, `torch.nn.functional.batch_norm` does not seem to have any parameters for the mean, variance, gamma, and beta. Or have I got this wrong?
Here is the parameter list:
```python
return F.batch_norm(
    input,
    # If buffers are not to be tracked, ensure that they won't be updated
    self.running_mean
    if not self.training or self.track_running_stats
    else None,
    self.running_var if not self.training or self.track_running_stats else None,
    self.weight,
    self.bias,
    bn_training,
    exponential_average_factor,
    self.eps,
)
```
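For reference, the signature is `torch.nn.functional.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05)`, so the slots in the call above appear to line up with TF's arguments: `running_mean`/`running_var` play the role of mean/variance, and `weight`/`bias` are gamma/beta. Below is a minimal sketch, assuming NCHW input and a hypothetical helper name `tf_style_batch_norm`. It passes precomputed statistics with `training=False`, which makes `F.batch_norm` use the statistics it is given rather than the batch's own, and it checks the result against the TF formula written out by hand.

```python
import torch
import torch.nn.functional as F


def tf_style_batch_norm(x, mean, variance, offset, scale, variance_epsilon):
    # Hypothetical helper mirroring tf.nn.batch_normalization's arguments for
    # NCHW input. mean, variance, offset (beta) and scale (gamma) are 1-D, size C.
    # training=False: normalize with the statistics passed in; they are not modified.
    return F.batch_norm(
        x,
        mean,          # running_mean slot
        variance,      # running_var slot
        weight=scale,  # gamma
        bias=offset,   # beta
        training=False,
        eps=variance_epsilon,
    )


torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)  # N, C, H, W
c = (1, -1, 1, 1)            # reshape target so a per-channel vector broadcasts over N, H, W
mean = x.mean(dim=(0, 2, 3))
variance = ((x - mean.view(c)) ** 2).mean(dim=(0, 2, 3))  # biased variance, like tf.nn.moments
gamma = torch.rand(3) + 0.5
beta = torch.randn(3)

out = tf_style_batch_norm(x, mean, variance, beta, gamma, 1e-3)

# (x - mean) / sqrt(variance + eps) * scale + offset, computed by hand.
manual = (x - mean.view(c)) / torch.sqrt(variance.view(c) + 1e-3) * gamma.view(c) + beta.view(c)
print(torch.allclose(out, manual, atol=1e-5))  # True
```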
BTW, I don't want to use `BatchNorm2d` because it does not have an option to pass in these values.
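For conditional batch norm specifically, `F.batch_norm`'s `weight` and `bias` are per-channel (shape `C`), so they cannot differ from sample to sample. A common pattern is to let `F.batch_norm` do only the normalization (`weight=None, bias=None`) and then apply per-sample gamma and beta yourself. Below is a minimal sketch, assuming gamma and beta arrive as `(N, C)` tensors from some conditioning input such as a class embedding. The class name `ConditionalBatchNorm2d` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalBatchNorm2d(nn.Module):
    # Hypothetical module: shared normalization statistics, per-sample gamma/beta.
    # Assumes gamma and beta are supplied by the caller with shape (N, C).

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Buffers follow .to(device) and are saved in state_dict, but are not trained.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x, gamma, beta):
        # Normalize only (no per-channel affine). In training mode F.batch_norm
        # uses batch statistics and updates the running buffers in place;
        # in eval mode it normalizes with the running buffers instead.
        x_hat = F.batch_norm(
            x,
            self.running_mean,
            self.running_var,
            weight=None,
            bias=None,
            training=self.training,
            momentum=self.momentum,
            eps=self.eps,
        )
        # Per-sample affine: (N, C) -> (N, C, 1, 1) to broadcast over H and W.
        return gamma[:, :, None, None] * x_hat + beta[:, :, None, None]


torch.manual_seed(0)
cbn = ConditionalBatchNorm2d(3)
x = torch.randn(4, 3, 8, 8)
gamma = torch.ones(4, 3)
beta = torch.zeros(4, 3)
y = cbn(x, gamma, beta)
print(y.shape)  # torch.Size([4, 3, 8, 8])
```

Keeping the statistics as buffers means `cbn.eval()` switches to the stored running statistics, as the built-in modules do. If it turns out to be acceptable, `nn.BatchNorm2d(num_features, affine=False)` performs the same normalization step, with gamma and beta applied afterwards in the same way.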