Parameters difference between nn.BatchNorm2d() and F.batch_norm()

Dear all:
I noticed that the parameters of nn.BatchNorm2d() and F.batch_norm() differ:

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

torch.nn.functional.batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05)

To be concise: in nn.BatchNorm2d you can set affine=False to not learn w, b, and track_running_stats=False to use only the current batch statistics.
In F.batch_norm, however, there is ONLY one parameter, training, and I don't know its exact behavior when set to True or False with respect to w, b and running_mean, running_var.


I think the training parameter tells the batch norm function how to behave, since it should behave differently when running inference than when training. The running_mean and running_var arguments are the running estimates of the data's mean and variance, which are updated as:

running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_var  = (1 - momentum) * running_var  + momentum * batch_var
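
Note that in PyTorch momentum is the weight on the *new* batch statistic, so with the default momentum=0.1 the running estimates move 10% of the way toward the batch statistics on each training-mode forward pass. A minimal sketch checking this with a constant input (values chosen so the expected results are easy to compute by hand):

```python
import torch
import torch.nn as nn

# One channel, constant input: batch mean = 5.0, batch variance = 0.0.
bn = nn.BatchNorm2d(1, momentum=0.1)  # running_mean starts at 0, running_var at 1
x = torch.full((2, 1, 2, 2), 5.0)

bn.train()
bn(x)  # a forward pass in training mode updates the running stats

# running_mean = (1 - 0.1) * 0.0 + 0.1 * 5.0 = 0.5
# running_var  = (1 - 0.1) * 1.0 + 0.1 * 0.0 = 0.9
print(bn.running_mean)  # tensor([0.5000])
print(bn.running_var)   # tensor([0.9000])
```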

I think the reason torch.nn.BatchNorm2d does not need these arguments is that, being part of the model, it can keep track of the statistics of the data the model has "seen" and of whether it is running in training or inference mode.
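To illustrate the correspondence (a minimal sketch; the tensor shapes and buffer values are made up, chosen to match a freshly initialized module): a new nn.BatchNorm2d in training mode should produce the same output as F.batch_norm called with the equivalent buffers and parameters passed in explicitly and training=True, and both update their running_mean/running_var in place the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)

# Module form: buffers and affine parameters live inside the module.
bn = nn.BatchNorm2d(3)
bn.train()
out_module = bn(x)

# Functional form: pass the same state explicitly.
# These match a freshly initialized BatchNorm2d: mean=0, var=1, weight=1, bias=0.
running_mean = torch.zeros(3)
running_var = torch.ones(3)
weight = torch.ones(3)
bias = torch.zeros(3)
out_func = F.batch_norm(x, running_mean, running_var,
                        weight=weight, bias=bias,
                        training=True, momentum=0.1)

print(torch.allclose(out_module, out_func, atol=1e-6))  # True
print(torch.allclose(bn.running_mean, running_mean))    # True: same in-place update
```

With training=False, both forms would instead normalize using the running estimates rather than the batch statistics, and would leave running_mean/running_var untouched.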

P.S.: I am by no means an expert, so take my explanation with a big grain of salt. Cheers