How does PyTorch’s batch norm know if the forward pass it’s doing is for inference or training?

I am evaluating the test performance of my net but realized that I’m not sure how my net knows whether it’s in the training phase or the inference phase. How does PyTorch handle this?


Hi,

In the source code, the batch norm module passes self.training as an argument to F.batch_norm.

When you train the model, you call model.train().
When you test, you call model.eval().

model.train() sets self.training to True, and model.eval() sets it to False.
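To illustrate the point about the flag, here is a minimal sketch that checks the training attribute after each call. The small network (net) is a made-up example, not the poster's model; the only thing it relies on is that train() and eval() set the flag on the module and on every submodule.

```python
import torch.nn as nn

# Hypothetical toy network, used only to inspect the flag.
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())

net.train()
# modules() yields the container itself plus its three layers.
train_flags = [m.training for m in net.modules()]

net.eval()
eval_flags = [m.training for m in net.modules()]

print(train_flags)  # [True, True, True, True]
print(eval_flags)   # [False, False, False, False]
```

So the batch norm layer does not detect anything by itself: it just reads whatever value the last train() or eval() call left in self.training.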


I see. I tried running net.train() when there is no batch norm and no error was thrown… I guess I need to redo everything because I wasn’t using that and I was evaluating test and train at every epoch… so I need to switch that flag quite often.

What are the consequences of staying in train mode for both training and inference?

It will not be in both modes at once.

When you call model.train(), self.training becomes True, and it keeps that value until you call model.eval().

Then, when you call model.eval(), self.training becomes False.


Being in train mode when evaluating on test data… is what I meant.

Sorry, I don’t quite understand what you mean.

Do you mean that you alternate between training and testing?

What I mean is that I call net.train() and then do the forward pass on both test and train images. I only train on train images, but by accident I didn’t change the flag for the test images…

Btw, do you know when the mean and variance, for example, are updated? What happens to them during inference mode? Are they just fixed?

I see.

As far as I know, the self.training flag only makes a difference in the Dropout and Batch Norm layers,
so if you don’t use these two types of layers, it may have no effect.
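A quick way to see this is to run the same input through a layer in both modes and compare. The sketch below is only an illustration with made-up sizes: it assumes a plain nn.Linear as the "unaffected" layer and nn.Dropout with p=0.5 as the "affected" one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.ones(1000)

linear = nn.Linear(1000, 2)   # has no mode-dependent behaviour
drop = nn.Dropout(p=0.5)      # behaves differently in train and eval

linear.train()
drop.train()
lin_train = linear(x)
drop_train = drop(x)          # each entry is zeroed or scaled by 1 / (1 - p)

linear.eval()
drop.eval()
lin_eval = linear(x)
drop_eval = drop(x)           # dropout is the identity in eval mode

print(torch.equal(lin_train, lin_eval))  # linear output is the same in both modes
print(torch.equal(drop_eval, x))         # eval-mode dropout returns the input
print(drop_train.unique())               # train-mode dropout changed the values
```

If the net contains neither kind of layer, forgetting to call eval() should give the same outputs; with Dropout or Batch Norm in the net, it will not.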

but I am using batch norm :joy: oooops!

The mean and variance are not kept fixed when self.training is True,

but the correct way is to keep them fixed when evaluating.

I think in training mode the batch mean & var are used for normalization; the running mean & var are updated but not actually used until eval mode.
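To make that concrete, here is a rough pure-Python sketch of the logic for a single feature. It is not PyTorch's implementation: TinyBatchNorm is a hypothetical class, and the momentum of 0.1, the eps of 1e-5, and the use of the unbiased variance for the running update are assumptions chosen to mirror what I understand the PyTorch defaults to be. The learnable scale and shift are left out.

```python
class TinyBatchNorm:
    """Illustrative single-feature batch norm (hypothetical, not PyTorch's code)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.training = True
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0

    def forward(self, batch):
        n = len(batch)  # needs n > 1 in training mode
        if self.training:
            # Training: normalize with the statistics of this batch...
            mean = sum(batch) / n
            var = sum((x - mean) ** 2 for x in batch) / n
            # ...and update the running estimates as a side effect.
            unbiased_var = var * n / (n - 1)
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * unbiased_var
        else:
            # Inference: use the stored estimates and leave them untouched.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


bn = TinyBatchNorm()
batch = [10.0, 12.0, 14.0]

out_train = bn.forward(batch)            # uses the batch mean 12 and batch variance
stats_after_train = (bn.running_mean, bn.running_var)

bn.training = False
out_eval = bn.forward(batch)             # uses the running statistics instead
stats_after_eval = (bn.running_mean, bn.running_var)

print(out_train)
print(out_eval)
print(stats_after_train == stats_after_eval)  # True: eval does not update them
```

This is also why running the test images through the net in train mode matters: the test batches get normalized by their own statistics, and they leak into the running mean and variance.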
