Batchnorm and back-propagation

Hi All,

I have what I hope is a simple question: when mu and var are calculated in the batchnorm layer, are gradients propagated through that scaling? I.e., are the mu and var in y = (x - mu) / sqrt(var + eps) plain numbers or gradient-tracked tensors?

I’m asking because I want to implement a modified version of batchnorm that uses the variance of the dimension computed before dropout is applied, so I need to know whether I have to call .detach() on mu and var.

Thanks!
Andy

The running stats are registered as buffers and do not require gradients, so you shouldn’t need to call .detach() on them.
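A quick sanity check (a minimal sketch using the standard nn.BatchNorm2d; the buffer names shown are those of the current PyTorch implementation):

import torch.nn as nn

# The running statistics are registered as buffers, not Parameters,
# so they do not track gradients.
bn = nn.BatchNorm2d(3)
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
print(bn.running_mean.requires_grad)  # False
print(bn.running_var.requires_grad)   # False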

Does batch norm in training mode normalize to the current batch statistics, or to the current running stats? I assume the former, because I have a case where batch statistics are actually tied to a bag of instances, so they aren’t comparable across bags (which is why I’m using track_running_stats=False).

Yes, during training the batch statistics will be used to normalize the current input.
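A small sketch to illustrate this (my own check, assuming affine=False so only the normalization is applied): in train() mode the output matches a manual normalization with the current batch statistics.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 5, 5)
bn = nn.BatchNorm2d(3, affine=False)
bn.train()

out = bn(x)

# Normalize manually with the statistics of this batch
mean = x.mean(dim=[0, 2, 3], keepdim=True)
var = x.var(dim=[0, 2, 3], unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(out, manual, atol=1e-6))  # True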

Thanks for the responses so far.

Hopefully this makes my initial question clearer: say you just want to normalize to zero mean. The operation is y = x - mu.

Is the ‘batchnorm equivalent’ approach like:

y = x - x.mean(dim=0).detach()

or

y = x - x.mean(dim=0)

While the value of y is the same, the gradient propagation is clearly different.

I wasn’t sure, but based on this small code snippet, it seems the latter approach is used:


import torch
import torch.nn as nn

# Manual, without detaching the batch statistics
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
mean = x.mean(dim=[0, 2, 3], keepdim=True)
std = torch.sqrt(x.var([0, 2, 3], unbiased=False, keepdim=True) + 1e-5)
y = (x - mean) / std
y.abs().sum().backward()
print(x.grad)
x1_grad = x.grad.clone()
y1 = y.clone()

# Manual, with the batch statistics detached
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
mean = x.mean(dim=[0, 2, 3], keepdim=True).detach()
std = torch.sqrt(x.var([0, 2, 3], unbiased=False, keepdim=True).detach() + 1e-5)
y = (x - mean) / std
y.abs().sum().backward()
print(x.grad)
x2_grad = x.grad.clone()
y2 = y.clone()

# nn.BatchNorm2d (affine=False, so only the normalization is applied)
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
bn = nn.BatchNorm2d(3, affine=False)
y = bn(x)
y.abs().sum().backward()
print(x.grad)
x3_grad = x.grad.clone()
y3 = y.clone()

# Compare outputs and input gradients
print((y3-y1).abs().max())
> tensor(2.3842e-07, grad_fn=<MaxBackward1>)
print((y3-y2).abs().max())
> tensor(2.3842e-07, grad_fn=<MaxBackward1>)

print((x3_grad - x1_grad).abs().max())
> tensor(8.3447e-07)
print((x3_grad - x2_grad).abs().max())
> tensor(2.6167)

Awesome! This makes sense to me; I figured this was the case, and I’m glad to know for sure. Thank you so much for helping clear this up!


The gradients should not be detached. This is actually explained on the second page of the original batchnorm paper. Imagine the loss function “wants” to increase the value of a batchnormed activation because of a bias in the targets (i.e., independent of the input to the network). If you detach the mean, the gradients will push the pre-normed activation up all across the batch, so the difference between a sample’s value and the batch mean stays constant, causing a non-stop drift of the pre-normed activations. They explain it better in the paper: “As the training continues, b will grow indefinitely while the loss remains fixed. This problem can get worse if the normalization not only centers but also scales the activations. We have observed this empirically in initial experiments, where the model blows up when the normalization parameters are computed outside the gradient descent step.”
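To make the argument concrete, here is a toy sketch of that drift (my own illustration, not code from the paper): a single bias b is added to the input and the result is mean-centered, while the loss pushes the centered activations upward. If the batch mean is detached, the gradient w.r.t. b never sees that increasing b also increases the mean, so b drifts upward step after step while the loss stays fixed; with the mean kept in the graph, the gradient w.r.t. b is zero and b stays put.

import torch

def run(detach_mean, steps=200, lr=0.1):
    torch.manual_seed(0)
    x = torch.randn(64)
    b = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([b], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        a = x + b
        # Detaching the mean mimics computing the normalization
        # statistics outside the gradient descent step.
        mean = a.mean().detach() if detach_mean else a.mean()
        y = a - mean
        # The loss "wants" the centered activations to be larger.
        loss = (1.0 - y).pow(2).mean()
        loss.backward()
        opt.step()
    return b.item(), loss.item()

print(run(detach_mean=True))   # b keeps growing (roughly steps * lr * 2) while the loss stays fixed
print(run(detach_mean=False))  # b stays at zero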