How does PyTorch’s batch norm know if the forward pass it’s doing is for inference or training?

I am evaluating the test performance of my net, but I realized that I’m not sure how my net knows whether it’s in the training phase or the inference phase. How does PyTorch handle this?



In the source code, the function F.batch_norm has a training parameter.

When you train the model, you use model.train().
When you test, you use model.eval().

model.train() sets the training flag to True.
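As a minimal sketch of how the flag propagates (layer sizes chosen arbitrarily), model.train() and model.eval() toggle the training attribute on the module and on all its submodules:

```python
import torch.nn as nn

# Tiny example model; the sizes are arbitrary.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

model.train()
print(model.training, model[1].training)  # True True

model.eval()
print(model.training, model[1].training)  # False False
```

So you never pass the flag yourself; the BatchNorm layer reads its own self.training when it runs its forward pass.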


I see. I tried running net.train() on a model with no batch norm and no error was thrown… I guess I need to redo everything, because I wasn’t using that flag and I was evaluating on both test and train at every epoch… so I need to switch that flag quite often.

What are the consequences of staying in train mode for both training and inference?

It will not be both train and inference at once.

When you call model.train(), the training flag becomes True, and it keeps that value until you call model.eval().

When you then call model.eval(), the training flag becomes False.


Being in train mode while evaluating on test data… is what I meant.

Sorry, I don’t exactly understand what you mean.

Do you mean that you train and test alternately?

What I mean is that I call net.train() and then do the forward pass on both test and train images. I only train on the train images, but by accident I didn’t change the flag for the test images…
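The usual way to avoid that mistake is to toggle the flag inside the epoch loop. A sketch, with a hypothetical tiny model and random data just to show the pattern:

```python
import torch
import torch.nn as nn

# Hypothetical model and data, purely for illustration.
model = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x_train, y_train = torch.randn(16, 4), torch.randn(16, 2)
x_test = torch.randn(8, 4)

for epoch in range(2):
    model.train()                      # batch norm uses batch statistics
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x_train), y_train)
    loss.backward()
    opt.step()

    model.eval()                       # batch norm uses running statistics
    with torch.no_grad():
        test_out = model(x_test)
```

Calling model.eval() before the test forward pass (and model.train() again before the next training step) keeps the two phases separate.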

Btw, do you know when the running mean and variance, for example, are updated? What happens to them during inference mode? Are they just fixed?

I see.

As far as I know, the flag only makes a difference in the Dropout and BatchNorm layers,
so if you don’t use these two types of layers, it may have no influence.
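For example, a quick check with Dropout (sizes and probability chosen arbitrarily) shows the flag’s effect:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)   # roughly half the entries zeroed, survivors scaled to 2.0

drop.eval()
y_eval = drop(x)    # identity: dropout is disabled in eval mode
print(torch.equal(y_eval, x))  # True
```

In eval mode Dropout is a no-op, while in train mode it zeroes entries at random and rescales the rest by 1 / (1 - p).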

But I am using batch norm :joy: oops!

It will not keep the mean and variance fixed when training is True,

but the correct way is to keep them fixed when evaluating.

I think in training mode the batch mean & var are used for normalization, while the running mean & var are updated but never actually used (they are only used in eval mode).
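A small sanity check along these lines (numbers arbitrary): in train mode the running stats move, and in eval mode they stay fixed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(32, 3) * 2 + 5   # batch with mean ≈ 5

# Fresh running stats start at mean 0, var 1.
print(bn.running_mean)  # tensor([0., 0., 0.])

bn.train()
_ = bn(x)               # normalizes with the *batch* stats and
                        # updates the running stats (default momentum 0.1)
after_train = bn.running_mean.clone()

bn.eval()
_ = bn(x)               # normalizes with the *running* stats; no update
assert torch.equal(bn.running_mean, after_train)
```

After the train-mode pass the running mean has moved toward the batch mean; the eval-mode pass leaves it untouched.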
