Hi, I was wondering: if I use torch.utils.checkpoint.checkpoint on a module that includes BatchNorm, how will it deal with the running mean/variance? If BatchNorm is computed twice (once during the forward pass and once during recomputation in the backward pass), then I see two problems:

The running mean/variance gets updated twice… although this effect may be negligible?

During the backward pass, the result of BatchNorm is different from that of the forward pass… it seems to me that this would make the gradient computation wrong?

To me, a solution would be to either make a copy of the running mean/variance in the forward pass and use this copy in the backward pass, or only update the running mean/variance during the backward pass, but implementing either of these seems cumbersome. Any ideas?

Note: I did not try it yet as I have to migrate to 0.4 first. Also I found LayerNorm not working well in my situation.

If BN is in training mode, then both the forward pass and the recomputation use batch stats, so the result will be the same. If BN is in eval mode, then both use the fixed running stats, so the result will still be the same.

Thanks for your reply.
I have tried comparing BN inside checkpoint with the original BN.
It's true that the grads are correct, but running_mean and running_var were updated twice.
I can't figure out why the grads don't end up wrong.
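The double update is easy to observe directly. Here is a minimal sketch (my own illustration, not code from this thread) that checkpoints a single BatchNorm layer and inspects num_batches_tracked, which counts how many times BN updated its running stats; the use_reentrant argument assumes a reasonably recent PyTorch and is not available in old versions like 0.4:

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(4)            # training mode by default
x = torch.randn(8, 4, requires_grad=True)

# Forward runs BN once; backward recomputes it, running BN a second time.
out = checkpoint(bn, x, use_reentrant=False)
out.sum().backward()

# num_batches_tracked is incremented on every training-mode forward,
# so it typically reads 2 here: the running stats were updated twice.
print(bn.num_batches_tracked.item())
```

The gradients themselves still come out correct, since both BN evaluations use the same batch statistics of x.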

Hi, @SimonW of course, you are right. I was confused thinking that the running averages were used in the computation during training, but that is only during evaluation (in which case you need no grad). So, if you ask me, the grad will be correct as the batch statistics are calculated twice and won’t change.

However, the running average will indeed be updated twice. This means, if we substitute alpha = 1 - momentum:
Single update: x_avg <- x_avg * (1 - momentum) + x_hat * momentum = x_avg * alpha + x_hat * (1 - alpha)
Double update: x_avg <- (x_avg * alpha + x_hat * (1 - alpha)) * alpha + x_hat * (1 - alpha) = x_avg * alpha^2 + x_hat * (1 - alpha) * (alpha + 1) = x_avg * alpha^2 + x_hat * (1 - alpha^2)
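A quick numeric check of this algebra (the values here are arbitrary, chosen just to exercise the formula):

```python
alpha, x_avg, x_hat = 0.9, 1.0, 5.0

def step(avg):
    # One exponential-moving-average update with decay factor alpha
    return avg * alpha + x_hat * (1 - alpha)

double = step(step(x_avg))                          # two successive updates
closed = x_avg * alpha**2 + x_hat * (1 - alpha**2)  # closed form derived above

assert abs(double - closed) < 1e-12
```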

I therefore think we can correct for the double update by taking alpha_corr = sqrt(alpha), i.e. momentum_corr = 1 - alpha_corr = 1 - sqrt(alpha) = 1 - sqrt(1 - momentum), but correct me if I'm wrong. To achieve an effective momentum of 0.1 (the default), this means you should use 1 - sqrt(0.9) ≈ 0.051, although I assume the effect will be very small. A solution like @11173 may work as well, although it may be easier to just pass a dummy running_mean/running_var placeholder the second time, instead of copying back and forth?
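As a sanity check on that correction (my own sketch, not verified on a real model): applying the corrected decay factor twice should match applying the original one once.

```python
import math

momentum = 0.1                          # PyTorch's default BN momentum
alpha = 1 - momentum                    # decay factor of the running average
momentum_corr = 1 - math.sqrt(alpha)    # ~0.0513

# Two updates with the corrected factor equal one update with the original:
alpha_corr = 1 - momentum_corr
assert abs(alpha_corr**2 - alpha) < 1e-12
```

So setting momentum=1 - math.sqrt(1 - momentum) on the checkpointed BN layers would, under this reasoning, restore the intended effective momentum.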