Hi,
Suppose there is a case like this:
x1, x2, lb1, lb2 = next(dataloader)
logits1 = model(x1)
logits2 = model(x2)
loss = criteria1(logits1, lb1) + criteria2(logits2, lb2)
...
The problem is that I need to update the batch norm running statistics (the moving averages) only on the forward pass for x1, and not on the other forward computations. How could I do this with PyTorch?
You could call .eval() on all batch norm layers in your model after passing x1 to the model and before using x2. Afterwards, just reset these layers to .train() again.
If you don’t use any dropout layers, you could also simply call model.eval()/model.train().
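A minimal sketch of this suggestion (the helper names set_bn_eval/set_bn_train are my own, not a PyTorch API):

```python
import torch
import torch.nn as nn

def set_bn_eval(model):
    # Put only the batch norm layers into eval mode; all other layers keep training.
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

def set_bn_train(model):
    # Switch the batch norm layers back to train mode.
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.train()

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()

# forward pass for x1 updates the running stats as usual:
# logits1 = model(x1)
set_bn_eval(model)
# forward pass for x2 now uses the running stats and does not update them:
# logits2 = model(x2)
set_bn_train(model)
```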
Thanks for replying!!
If I call model.eval(), the input tensor will be normalized with the running mean and running var rather than the batch statistics. Can this be avoided?
You could disable the running stats completely by setting track_running_stats=False.
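For illustration, a small sketch of what track_running_stats=False does (behavior as in current PyTorch; the layer then keeps no running buffers at all):

```python
import torch
import torch.nn as nn

# With track_running_stats=False, the layer registers no running buffers
# and always normalizes with the statistics of the current batch.
bn = nn.BatchNorm2d(3, track_running_stats=False)
print(bn.running_mean)  # None -- no buffers are registered

x = torch.randn(4, 3, 8, 8)
out = bn(x)  # normalized with batch statistics, in train and eval mode alike
```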
However, your use case seems to be:
- for x1: use batch stats and update the running estimates
- for x2: just use batch stats

What about the affine parameters? Should they also be updated using the loss from x2 or just from x1?
In the latter case, you could use two different batch norm layers, pass a flag to forward, and switch between these layers depending on whether x1 or x2 was passed.
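A rough sketch of that idea (the DualBN module name and the use_main flag are made up for illustration):

```python
import torch
import torch.nn as nn

class DualBN(nn.Module):
    # Two separate batch norm layers: one that maintains running estimates
    # (for the x1 path) and one that always uses plain batch statistics
    # (for the x2 path).
    def __init__(self, num_features):
        super().__init__()
        self.bn_main = nn.BatchNorm2d(num_features)  # updated by x1
        self.bn_aux = nn.BatchNorm2d(num_features, track_running_stats=False)  # batch stats only

    def forward(self, x, use_main=True):
        return self.bn_main(x) if use_main else self.bn_aux(x)

block = DualBN(3)
block.train()
x = torch.randn(4, 3, 8, 8)
out1 = block(x, use_main=True)   # uses and updates the running estimates
out2 = block(x, use_main=False)  # pure batch statistics, no running buffers
```

Note that the affine parameters (weight/bias) are separate per layer in this naive version.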
Thanks for replying!!
The affine parameters are trained from both x1 and x2. Two different batch norm layers would solve the running-estimates problem, but the affine parameters would not be shared between the two layers that way. Any suggestions?
You could use a hacky way of setting track_running_stats=False for the x2 input and resetting the running stats manually.
Here is a small example:
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, track_running_stats=True)
bn.eval()
print(bn.running_mean) # zeros
print(bn.running_var) # ones
x = torch.randn(2, 3, 4, 4)
out = bn(x)
print(out.mean()) # should NOT be perfectly normal, since running stats used
print(out.std())
print(bn.running_mean) # not updated
print(bn.running_var)
# Disable running stats
# (note: this relies on BatchNorm internals, which may differ between PyTorch versions)
# internally the buffers are still valid, so we need to reset them manually
bn.track_running_stats = False
out2 = bn(x)
print(out2.mean()) # should be normal now
print(out2.std())
print(bn.running_mean) # unfortunately updated
print(bn.running_var)
# Reset running stats
with torch.no_grad():
    bn.running_mean = (bn.running_mean - x.mean([0, 2, 3]) * bn.momentum) / (1 - bn.momentum)
    bn.running_var = (bn.running_var - x.var([0, 2, 3]) * bn.momentum) / (1 - bn.momentum)
print(bn.running_mean) # back to values before last update
print(bn.running_var)
# enable running stats again
bn.track_running_stats = True
out3 = bn(x)
print(out3.mean() == out.mean()) # compare to initial output
print(out3.std() == out.std())
I would recommend testing this approach in your model and making sure you get the desired outputs, gradients, and updates.
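For reference, in recent PyTorch versions a batch norm layer in train mode with track_running_stats temporarily set to False passes its running buffers as None internally, so it normalizes with batch statistics and skips the running-stat update entirely, and no manual reset is needed. A sketch of that variant as a context manager (frozen_bn_stats is my own name, not a PyTorch API):

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def frozen_bn_stats(model):
    # Temporarily stop all batch norm layers from updating their running
    # estimates. In recent PyTorch versions, a BN layer in train mode with
    # track_running_stats=False normalizes with batch statistics and leaves
    # the running buffers untouched.
    bns = [m for m in model.modules()
           if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))]
    for m in bns:
        m.track_running_stats = False
    try:
        yield
    finally:
        for m in bns:
            m.track_running_stats = True

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
model.train()
x1 = torch.randn(4, 3, 8, 8)
x2 = torch.randn(4, 3, 8, 8)

logits1 = model(x1)  # normal forward: running estimates are updated
with frozen_bn_stats(model):
    logits2 = model(x2)  # batch stats used, running estimates untouched
```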
Thanks a lot, I’ll try it.