Are batch norm statistics maintained when switching between train and val during training?

I am using batch norm in my network. After each training epoch, I evaluate the loss on the validation set. Will the batch norm statistics from train() mode be maintained into the next epoch? I want these statistics to keep being trained. Thanks

for epoch in range(100):
    # Evaluate on the validation set
    with torch.no_grad():
        for val in valloader:
            net.eval()
            images, targets = val

The running statistics will be updated once your model is set to .train() again.
Your code snippet looks fine. You could move the net.eval() before the loop through your validation set, but it’s not a problem if you call .eval() repeatedly.
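To make the train/eval switching concrete, here is a minimal sketch of one epoch. The model `net` and the data are stand-ins (a tiny random-input setup, not the poster's actual network or loaders):

```python
import torch
import torch.nn as nn

# Toy model containing a batch norm layer; stands in for the poster's `net`
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())

for epoch in range(2):  # shortened from 100 for the sketch
    net.train()  # running stats are updated on every forward pass
    for _ in range(3):
        images = torch.randn(4, 3, 24, 24)
        output = net(images)

    net.eval()  # running stats are frozen and merely applied
    with torch.no_grad():
        for _ in range(2):
            images = torch.randn(4, 3, 24, 24)
            output = net(images)
```

Calling `net.train()` at the top of the next epoch is what resumes the updates; nothing needs to be restored manually.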


Thanks. So the running statistics are stored somewhere when I call .eval(). Then if I call .train(), are they recovered from that place and updated again? Am I right?

They will just not be updated. The running stats are already stored in bn.running_mean and bn.running_var. If you set this layer to eval, the running stats will just be applied without being updated. Have a look at this small example:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(10, 3, 24, 24)

# Print initial stats
print(bn.running_mean, bn.running_var)

# Update once and print stats
output = bn(x)
print(bn.running_mean, bn.running_var)

# Set to eval; the stats should stay the same
bn.eval()
output = bn(x)
print(bn.running_mean, bn.running_var)

# Set to train again; the stats should be changed now
bn.train()
output = bn(x)
print(bn.running_mean, bn.running_var)
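For completeness: in train mode the running stats are updated as an exponential moving average controlled by the momentum argument (default 0.1), i.e. new_stat = (1 - momentum) * old_stat + momentum * batch_stat. A small check of that formula for the running mean, starting from the freshly initialized stats (running_mean is all zeros):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)  # default momentum=0.1
x = torch.randn(10, 3, 24, 24)

bn.train()
bn(x)

# Expected update: (1 - 0.1) * initial running_mean (zeros) + 0.1 * batch mean
batch_mean = x.mean(dim=(0, 2, 3))
expected = 0.1 * batch_mean
print(torch.allclose(bn.running_mean, expected, atol=1e-6))
```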