I am using batch norm in my network. After each training epoch, I evaluate the loss on the validation set. Will the batch norm statistics collected in train() mode be kept for the next epoch? I want these statistics to keep being updated during training. Thanks
for epoch in range(100):
    net.train()
    optimizer.zero_grad()
    loss = ...
    loss.backward()
    optimizer.step()
    # Evaluate on the validation set
    with torch.no_grad():
        for val in valloader:
            images, targets = val
            net.eval()
            ...
The running statistics will be updated again once your model is set back to .train().
Your code snippet looks fine. You could move the net.eval() call before the loop through your validation set, but it's not a problem if you call .eval() repeatedly.
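Since the snippet in the question wraps validation in torch.no_grad(), it may help to note that no_grad() and eval() do different things: no_grad() only disables gradient tracking, while eval() is what stops the running stats from being updated. A minimal sketch, assuming a standalone BatchNorm1d layer fed with random input (the layer size and data are illustrative only):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)   # freshly created layers start in train mode
x = torch.randn(8, 4)

# 1) Train mode + no_grad(): autograd is off, but the running stats still move.
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
updated_in_train = not torch.equal(before, bn.running_mean)

# 2) Eval mode + no_grad(): the running stats are only read, never written.
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(x)
updated_in_eval = not torch.equal(before, bn.running_mean)

print(updated_in_train, updated_in_eval)  # True False
```

So in the validation loop, it is the net.eval() call that protects the running stats; torch.no_grad() is there to save memory and compute.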
Thanks. So when I call .eval(), the running statistics are stored somewhere, and when I call .train() again, they are restored from there so they can be updated. Am I right?
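They will just not be updated. The running stats are always stored in the layer's buffers bn.running_mean and bn.running_var. If you set this layer to eval, the running stats will just be applied without updating them; in train mode, the layer normalizes with the statistics of the current batch and updates the running stats. Have a look at this small example: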
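To make "applied without updating them" concrete, the following sketch recomputes both modes by hand for a BatchNorm2d layer, assuming the default momentum=0.1 and eps=1e-5 and random input. In train mode the output uses the per-channel batch mean and biased batch variance, and the running stats are updated as an exponential moving average, running = (1 - momentum) * running + momentum * batch_stat, where the stored variance uses the unbiased estimate. In eval mode the output uses running_mean and running_var directly. The helper names per_channel and affine are made up for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # defaults: momentum=0.1, eps=1e-5, affine=True
x = torch.randn(10, 3, 24, 24)
m, eps = bn.momentum, bn.eps

def per_channel(t):
    # Hypothetical helper: reshape a (C,) tensor to broadcast against (N, C, H, W).
    return t.detach().view(1, -1, 1, 1)

def affine(z):
    # Hypothetical helper: apply the learnable scale (weight) and shift (bias).
    return z * per_channel(bn.weight) + per_channel(bn.bias)

# --- Train mode: normalize with the current batch statistics ---
old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()
with torch.no_grad():
    out_train = bn(x)  # no_grad() does not stop the running-stat update

batch_mean = x.mean(dim=(0, 2, 3))  # per channel, over N, H, W
batch_var_biased = x.var(dim=(0, 2, 3), unbiased=False)
batch_var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)

expected_train = affine((x - per_channel(batch_mean)) / torch.sqrt(per_channel(batch_var_biased) + eps))
train_ok = torch.allclose(out_train, expected_train, atol=1e-5)

# Running stats: exponential moving average; the variance uses the unbiased estimate.
mean_ok = torch.allclose(bn.running_mean, (1 - m) * old_mean + m * batch_mean, atol=1e-5)
var_ok = torch.allclose(bn.running_var, (1 - m) * old_var + m * batch_var_unbiased, atol=1e-5)

# --- Eval mode: normalize with the stored running stats, no update ---
bn.eval()
with torch.no_grad():
    out_eval = bn(x)
expected_eval = affine((x - per_channel(bn.running_mean)) / torch.sqrt(per_channel(bn.running_var) + eps))
eval_ok = torch.allclose(out_eval, expected_eval, atol=1e-5)

print(train_ok, mean_ok, var_ok, eval_ok)
```

All four checks are expected to print True; the tolerances are chosen loosely to absorb float32 rounding differences.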
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(10, 3, 24, 24)
# Print initial stats
print(bn.running_mean, bn.running_var)
# Update once and print stats
output = bn(x)
print(bn.running_mean, bn.running_var)
# Set to eval; the stats should stay the same
bn.eval()
output = bn(x)
print(bn.running_mean, bn.running_var)
# Set to train again; the stats should be changed now
bn.train()
output = bn(x)
print(bn.running_mean, bn.running_var)
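To tie this back to the original question, here is a sketch of the full epoch loop with a hypothetical toy model (the Sequential network, random tensors, and SGD settings are made up for illustration). It checks that each validation pass leaves the running stats untouched and that num_batches_tracked keeps counting across epochs, i.e. each training epoch continues from the stats left by the previous one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and random data, only to illustrate the train/eval cycle.
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
bn = net[1]
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
criterion = nn.MSELoss()
train_batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(5)]
val_batches = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(2)]

tracked_per_epoch = []
val_left_stats_unchanged = True

for epoch in range(3):
    net.train()
    for inputs, targets in train_batches:
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()

    # Snapshot the running stats right after training.
    mean_before = bn.running_mean.clone()
    var_before = bn.running_var.clone()

    net.eval()
    with torch.no_grad():
        val_loss = sum(criterion(net(inputs), targets).item() for inputs, targets in val_batches)

    # Validation should not have touched the running stats.
    if not (torch.equal(mean_before, bn.running_mean) and torch.equal(var_before, bn.running_var)):
        val_left_stats_unchanged = False
    tracked_per_epoch.append(int(bn.num_batches_tracked))
    print(f"epoch {epoch}: val_loss={val_loss:.4f}, batches tracked={tracked_per_epoch[-1]}")
```

With five training batches per epoch, num_batches_tracked should read 5, 10 and 15 after the three epochs, since only forward passes in train mode increment it.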