Checkpoint with BatchNorm running averages

Hi, I was wondering: if I use torch.utils.checkpoint.checkpoint on a module that includes BatchNorm, how will it deal with the running mean/variance? Since BatchNorm is computed twice (once during the forward pass and once during recomputation in the backward pass), I see two problems:

  • The running mean/variance gets updated twice, although this effect may be negligible?
  • During the backward pass, the result of BatchNorm differs from the result during the forward pass; it seems to me that this would make the gradient computation wrong?

To me, a solution would be to either make a copy of the running mean/variance in the forward pass and use this copy in the backward pass, or to only update the running mean/variance during the backward pass, but implementing this seems cumbersome. Any ideas?

Note: I have not tried it yet, as I have to migrate to 0.4 first. Also, LayerNorm did not work well in my situation.

If BN is in training mode, then both the forward pass and the recomputation use the batch statistics, so the results are the same. If BN is in eval mode, then both use the fixed running stats, so the results are again the same.

Hi, I think his point is that the running mean/variance gets updated twice because of the moving average,
so the grad may be wrong during backprop.

Yes, it may be updated twice. My point is that even if it is updated twice, the grad won't be wrong.

Thanks for your reply.
I have tried comparing BN inside checkpoint against the original BN.
It's true that the grad is correct, but running_mean and running_var were updated twice.
I can't figure out why the grads aren't wrong.
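
For reference, here is a minimal sketch of that comparison (the module choice, shapes, and seed are made up for illustration, and this assumes the default reentrant checkpoint behavior):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn_plain = nn.BatchNorm2d(4)
bn_ckpt = nn.BatchNorm2d(4)
bn_ckpt.load_state_dict(bn_plain.state_dict())

x = torch.randn(8, 4, 16, 16, requires_grad=True)
y = x.detach().clone().requires_grad_(True)

bn_plain(x).sum().backward()
checkpoint(bn_ckpt, y).sum().backward()

# Gradients agree, because both passes normalize with the same batch stats...
print(torch.allclose(x.grad, y.grad))                               # True
# ...but the running stats were updated twice under checkpoint.
print(torch.allclose(bn_plain.running_mean, bn_ckpt.running_mean))  # False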

import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchNorm3dFix(nn.BatchNorm3d):
    """BatchNorm3d that survives checkpoint recomputation.

    The running stats are saved before each forward pass; during the
    recomputation in the backward pass they are restored first, so they
    end up being updated only once per iteration.
    """
    def __init__(self, *args, **kwargs):
        super(BatchNorm3dFix, self).__init__(*args, **kwargs)
        # Buffers holding the stats as they were before the forward pass.
        self.register_buffer('prev_running_mean', self.running_mean.clone())
        self.register_buffer('prev_running_var', self.running_var.clone())

    def forward(self, input, in_backward=False):
        self._check_input_dim(input)
        if in_backward:
            # Recomputation: restore the stats saved in the forward pass,
            # so the second update starts from the same values as the first.
            self.running_mean.copy_(self.prev_running_mean)
            self.running_var.copy_(self.prev_running_var)
        else:
            # Forward pass: remember the stats before F.batch_norm updates them.
            self.prev_running_mean.copy_(self.running_mean)
            self.prev_running_var.copy_(self.running_var)

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            self.momentum, self.eps)


class _ConvFunction(object):
    """Function for checkpoint: Concat (if needed) -> BatchNorm -> ReLU -> Conv.

    Assumes each forward pass is followed by exactly one checkpoint
    recomputation, so in_backward is toggled: False on the forward pass,
    True on the recomputation, then back to False for the next iteration.
    """
    def __init__(self, norm, relu, conv):
        self.norm = norm
        self.relu = relu
        self.conv = conv
        self.in_backward = False

    def __call__(self, *inputs):
        concated_features = torch.cat(inputs, 1) if len(inputs) > 1 else inputs[0]
        output = self.conv(self.relu(self.norm(concated_features, self.in_backward)))
        # Toggle so the recomputation sees in_backward=True and the next
        # forward pass sees in_backward=False again.
        self.in_backward = not self.in_backward
        return output
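
For completeness, a hedged sketch of how the two pieces above might be wired together with checkpoint (only BatchNorm3dFix and _ConvFunction come from the snippet; the layer sizes and surrounding code are made up):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

norm = BatchNorm3dFix(8)
relu = nn.ReLU(inplace=True)
conv = nn.Conv3d(8, 16, kernel_size=3, padding=1)
fn = _ConvFunction(norm, relu, conv)

x = torch.randn(2, 8, 4, 16, 16, requires_grad=True)
# checkpoint calls fn twice: in the forward pass (in_backward=False, stats
# are saved, then updated) and in the recomputation (in_backward=True,
# the saved stats are restored before the update).
out = checkpoint(fn, x)
out.sum().backward()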

Hi, @SimonW, of course you are right. I was confused, thinking that the running averages were used in the computation during training, but they are only used during evaluation (in which case you need no grad). So, if you ask me, the grad will be correct, as the batch statistics are computed twice from the same batch and won't change.

However, the running average will indeed be updated twice. This means, if we substitute alpha = 1 - momentum:

Single update: x_avg <- x_avg * (1 - momentum) + x_hat * momentum
             = x_avg * alpha + x_hat * (1 - alpha)
Double update: x_avg <- (x_avg * alpha + x_hat * (1 - alpha)) * alpha + x_hat * (1 - alpha)
             = x_avg * alpha^2 + x_hat * (1 - alpha) * (alpha + 1)
             = x_avg * alpha^2 + x_hat * (1 - alpha^2)

I think we can therefore correct for the double update by taking alpha_corr = sqrt(alpha), i.e. momentum_corr = 1 - alpha_corr = 1 - sqrt(1 - momentum), but correct me if I'm wrong. To achieve an effective momentum of 0.1 (the default), this means you should use 1 - sqrt(0.9) ≈ 0.051, although I assume the effect will be very small. A solution like @11173's may work as well, although it may be easier to just pass a dummy running_mean/running_var placeholder the second time, instead of copying back and forth?
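
A quick numerical check of that correction (plain arithmetic, nothing framework-specific; the variable names are mine):

import math

momentum = 0.1
alpha = 1 - momentum
# Two updates with the corrected momentum should decay the old average by
# exactly alpha, i.e. behave like a single update with the original momentum.
momentum_corr = 1 - math.sqrt(alpha)
alpha_corr = 1 - momentum_corr

x_avg, x_hat = 1.0, 0.0  # toy values
single = x_avg * alpha + x_hat * (1 - alpha)
double = (x_avg * alpha_corr + x_hat * (1 - alpha_corr)) * alpha_corr \
         + x_hat * (1 - alpha_corr)

print(momentum_corr)         # ~0.0513
print(abs(single - double))  # ~0.0 (up to floating point)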