How could I fix BN moving-average updates in some conditions but not in others?


Suppose there is a case like this:

x1, x2, lb1, lb2 = ...  # inputs and labels for the two forward passes
logits1 = model(x1)
logits2 = model(x2)
loss = criteria1(logits1, lb1) + criteria2(logits2, lb2)

The problem is that I need to update the BN moving-average statistics only on x1, and not on any of the other forward passes. How could I do this with PyTorch?

You could call .eval() on all batch norm layers in your model after passing x1 to the model and before passing x2. Afterwards, just reset these layers to .train() again.

If you don’t use any dropout layers, you could also just call model.eval()/.train().
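A minimal sketch of switching only the batch norm layers between the two forward passes. The helper name `set_bn_train`, the `BN_TYPES` tuple, and the toy model are hypothetical illustrations, not part of the original answer. Note that x2 is then normalized with the running estimates rather than its own batch statistics.

```python
import torch
import torch.nn as nn

# Hypothetical tuple of batch norm types to toggle; extend it if your model
# uses other variants (e.g. nn.SyncBatchNorm).
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def set_bn_train(model: nn.Module, mode: bool) -> None:
    """Put only the batch norm layers of `model` into train (True) or eval (False) mode."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.train(mode)


torch.manual_seed(0)
# Toy model for illustration only
model = nn.Sequential(nn.Linear(10, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
criterion = nn.CrossEntropyLoss()
x1, x2 = torch.randn(16, 10), torch.randn(16, 10)
lb1, lb2 = torch.randint(0, 2, (16,)), torch.randint(0, 2, (16,))

model.train()
logits1 = model(x1)                # batch stats used, running stats updated
stats_after_x1 = model[1].running_mean.clone()

set_bn_train(model, False)         # BN layers now use running stats and do not update them
logits2 = model(x2)
set_bn_train(model, True)          # restore training behavior for the next iteration

loss = criterion(logits1, lb1) + criterion(logits2, lb2)
loss.backward()
```

Dropout and any other layers keep their training behavior because only the batch norm modules are toggled.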

Thanks for replying!

If I call model.eval(), the input tensor will be normalized with the running mean and running variance rather than the batch statistics. Can this be avoided?

You could disable the running stats completely by setting track_running_stats=False.
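A small sketch of what track_running_stats=False does when it is set at construction time: no running buffers are registered, so the layer normalizes with batch statistics even in eval mode. The input scaling is an arbitrary choice to make the effect visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, track_running_stats=False)
print(bn.running_mean)  # None: no running buffers are registered

bn.eval()  # even in eval mode, batch statistics are used
x = torch.randn(4, 3, 8, 8) * 5 + 2
out = bn(x)
print(out.mean(dim=(0, 2, 3)))  # close to zero for every channel
```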
However, your use case seems to be:

  • for x1 use batch stats and update running estimates
  • for x2 just use batch stats
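The two cases above can be written out explicitly with the functional API, `torch.nn.functional.batch_norm`, reusing one module's buffers and affine parameters. This is a swapped-in illustration of the desired behavior, not the approach used later in this thread; the placeholder loss is only there to show that both passes contribute gradients to the shared `weight`/`bias`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # holds running stats plus the affine weight/bias
x1 = torch.randn(8, 3, 4, 4)
x2 = torch.randn(8, 3, 4, 4)

# x1: normalize with batch stats and update the running estimates in place.
# (The functional call does not increment the module's num_batches_tracked counter.)
out1 = F.batch_norm(x1, bn.running_mean, bn.running_var, bn.weight, bn.bias,
                    training=True, momentum=bn.momentum, eps=bn.eps)
mean_after_x1 = bn.running_mean.clone()

# x2: batch stats only; passing None for the running buffers means nothing is updated
out2 = F.batch_norm(x2, None, None, bn.weight, bn.bias,
                    training=True, momentum=bn.momentum, eps=bn.eps)

loss = out1.pow(2).mean() + out2.pow(2).mean()  # placeholder loss
loss.backward()  # the shared affine parameters receive gradients from both passes
```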

What about the affine parameters?
Should they also be updated using the loss from x2 or just x1?
In the latter case, you could use two different batchnorm layers, pass a flag to forward, and switch between these layers depending on whether x1 or x2 was passed.
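A minimal sketch of the two-layer idea. The class name `DualBNNet`, the flag `is_x1`, and the architecture are hypothetical. Making the x2 layer use track_running_stats=False is an assumption here, so that x2 is always normalized with its own batch statistics; that layer also has its own affine parameters, so they are not shared with the x1 layer.

```python
import torch
import torch.nn as nn


class DualBNNet(nn.Module):
    """Hypothetical toy model: conv -> (bn_x1 or bn_x2) -> ReLU -> pooling -> linear."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn_x1 = nn.BatchNorm2d(8)                             # tracks running stats
        self.bn_x2 = nn.BatchNorm2d(8, track_running_stats=False)  # batch stats only
        self.head = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor, is_x1: bool) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn_x1(x) if is_x1 else self.bn_x2(x)
        x = torch.relu(x).mean(dim=(2, 3))  # global average pooling
        return self.head(x)


torch.manual_seed(0)
model = DualBNNet()
criterion = nn.CrossEntropyLoss()
x1, x2 = torch.randn(8, 3, 16, 16), torch.randn(8, 3, 16, 16)
lb1, lb2 = torch.randint(0, 2, (8,)), torch.randint(0, 2, (8,))

logits1 = model(x1, is_x1=True)
stats_after_x1 = model.bn_x1.running_mean.clone()
logits2 = model(x2, is_x1=False)  # does not touch bn_x1's running stats
loss = criterion(logits1, lb1) + criterion(logits2, lb2)
loss.backward()
```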

Thanks for replying!!

The affine parameters are trained from both x1 and x2. Two different batchnorm layers would solve the running-estimates problem, but then the affine parameters are not shared between the two layers. Any suggestions?

You could use a hacky approach: set track_running_stats=False for the x2 input and reset the running stats manually afterwards.
Here is a small example:

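import torch
import torch.nn as nn

# Note: this relies on the BatchNorm behavior of some older PyTorch releases, where an
# eval-mode layer with track_running_stats=False normalized with batch statistics and still
# updated its running buffers. Recent releases keep using the running buffers in eval mode
# whenever they exist, so check the printed values on your version.
bn = nn.BatchNorm2d(3, track_running_stats=True)
bn.eval()  # the first forward pass should use (and not update) the running stats
print(bn.running_mean) # zeros
print(bn.running_var) # ones

x = torch.randn(2, 3, 4, 4)
out = bn(x)
print(out.mean()) # should NOT be perfectly normal, since running stats used
print(bn.running_mean)  # not updated

# Disable running stats
# internally the buffers are still valid, so we need to reset them manually
bn.track_running_stats = False
out2 = bn(x)
print(out2.mean()) # should be normal now
print(bn.running_mean) # unfortunately updated

# Reset running stats by inverting the update new = (1 - momentum) * old + momentum * batch_stat
# (running_var is updated with the unbiased batch variance, which x.var() computes by default)
with torch.no_grad():
    bn.running_mean = (bn.running_mean - x.mean([0, 2, 3]) * bn.momentum) / (1 - bn.momentum)
    bn.running_var = (bn.running_var - x.var([0, 2, 3]) * bn.momentum) / (1 - bn.momentum)
print(bn.running_mean) # back to values before last update (up to floating point error)

# enable running stats again
bn.track_running_stats = True
out3 = bn(x)
# compare to the initial output; exact == can fail because the reset is not bit-exact
print(torch.allclose(out3, out, atol=1e-6))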

I would recommend testing this approach in your model and making sure you get the desired outputs, gradients, and updates.
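A version-independent sketch of the same idea, assuming the model stays in train mode for both passes: inside a context manager, the BN layers get track_running_stats=False (so they normalize with batch statistics while still using the shared affine parameters), and their buffers are snapshotted and restored on exit, so it does not matter whether your PyTorch version updates them in that mode. The helper name `frozen_bn_stats` and the toy model are hypothetical. The buffers are restored by reassignment rather than in-place copies, so tensors already saved by the x1 forward pass are left untouched.

```python
import contextlib

import torch
import torch.nn as nn

# Hypothetical tuple of batch norm types to handle
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


@contextlib.contextmanager
def frozen_bn_stats(model: nn.Module):
    """Inside the block, train-mode BN layers use batch stats; their buffers are restored on exit."""
    saved = []
    for m in model.modules():
        if isinstance(m, BN_TYPES) and m.track_running_stats:
            saved.append((m, m.running_mean.clone(), m.running_var.clone(),
                          m.num_batches_tracked.clone()))
            m.track_running_stats = False
    try:
        yield
    finally:
        for m, mean, var, count in saved:
            # Reassign the buffers instead of copying in place
            m.running_mean = mean
            m.running_var = var
            m.num_batches_tracked = count
            m.track_running_stats = True


torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
criterion = nn.CrossEntropyLoss()
x1, x2 = torch.randn(8, 3, 16, 16), torch.randn(8, 3, 16, 16)
lb1, lb2 = torch.randint(0, 2, (8,)), torch.randint(0, 2, (8,))

model.train()
logits1 = model(x1)               # batch stats + running-stat update
stats_after_x1 = model[1].running_mean.clone()
with frozen_bn_stats(model):
    logits2 = model(x2)           # batch stats, shared affine params, stats restored on exit
loss = criterion(logits1, lb1) + criterion(logits2, lb2)
loss.backward()
```

If the model is in eval mode, the layers use their running estimates regardless, so this helper only makes sense during training.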

Thanks a lot, I'll try it.