Batchnorm and back-propagation

Hi All,

I have what I hope is a simple question: when the mean and variance are calculated in the batchnorm layer, are gradients propagated through them? I.e., are the mu and var in y = (x - mu) / sqrt(var + eps) plain numbers or gradient-tracked tensors?

I’m asking because I want to implement a modified version of batchnorm that normalizes using the variance of the activations computed before dropout is applied, so I need to know whether or not I have to call .detach() on mu and var.
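
Roughly, the kind of thing I have in mind is sketched below (just an illustration with a placeholder name, simplified to 2D inputs and without affine parameters or running stats):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sketch, not an existing PyTorch layer: normalize the
# post-dropout activations with the mean/variance of the pre-dropout
# activations.
class PreDropoutNorm(nn.Module):
    def __init__(self, p=0.5, eps=1e-5):
        super().__init__()
        self.p = p
        self.eps = eps

    def forward(self, x):
        # batch statistics of the *pre-dropout* activations, per feature
        mu = x.mean(dim=0, keepdim=True)
        var = x.var(dim=0, unbiased=False, keepdim=True)
        x = F.dropout(x, p=self.p, training=self.training)
        # should mu and var be detached here? that's exactly my question
        return (x - mu) / torch.sqrt(var + self.eps)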

Thanks!
Andy

The running stats are registered as buffers and do not require gradients, so you shouldn’t need to call .detach() on them.
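
For example (illustrative check, not from the original discussion), you can verify that the running stats are registered as buffers and don't track gradients:

import torch
import torch.nn as nn

# The running statistics are buffers, not parameters, and do not require gradients.
bn = nn.BatchNorm2d(3)
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
print(bn.running_mean.requires_grad, bn.running_var.requires_grad)
# False False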

Does batch norm in training mode normalize to the current batch statistics, or to the current running stats? I assume the former, because I have a case where batch statistics are actually tied to a bag of instances, so they aren’t comparable across bags (which is why I’m using track_running_stats=False).

Yes, during training the batch statistics will be used to normalize the current input.
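
As a quick illustrative check (not from the original post): in train() mode the output is normalized with the statistics of the current batch, so each channel ends up with roughly zero mean and unit variance regardless of any running stats:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, affine=False, track_running_stats=False)
bn.train()

# deliberately shifted and scaled input
x = torch.randn(10, 3, 4, 4) * 5 + 2
y = bn(x)

print(y.mean(dim=[0, 2, 3]))                  # ~0 per channel
print(y.var(dim=[0, 2, 3], unbiased=False))   # ~1 per channel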

Thanks for the responses so far.

Hopefully this makes my initial question clearer: say you just want to normalize to zero mean, so the operation is y = x - mu.

Is the ‘batchnorm equivalent’ approach something like:

y = x - x.mean(dim=0).detach()

or

y = x - x.mean(dim=0)

While the value of y is the same, the gradient propagation is clearly different.
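
For instance, here is a tiny illustrative check for the mean-only case (the forward values match, but the gradients clearly don't):

import torch

# Without detach: the mean stays in the graph, and the centering exactly
# cancels the gradient of the sum.
x = torch.randn(4, 3, requires_grad=True)
y = x - x.mean(dim=0)
y.sum().backward()
print(x.grad)   # all zeros

# With detach: the mean is treated as a constant, so dy/dx is the identity.
x = torch.randn(4, 3, requires_grad=True)
y = x - x.mean(dim=0).detach()
y.sum().backward()
print(x.grad)   # all ones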

I wasn’t sure, but based on the small code snippet below, it seems the latter approach (no detach) is used:


import torch
import torch.nn as nn

# Manual without detach
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
# note: neither the mean nor the std is detached, so the batch stats stay in the autograd graph
mean = x.mean(dim=[0, 2, 3], keepdim=True)
std = torch.sqrt(x.var([0, 2, 3], unbiased=False, keepdim=True) + 1e-5)
y = (x - mean) / std
y.abs().sum().backward()
print(x.grad)
x1_grad = x.grad.clone()
y1 = y.clone()

# Manual with detach
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
# here both statistics are detached and treated as constants
mean = x.mean(dim=[0, 2, 3], keepdim=True).detach()
std = torch.sqrt(x.var([0, 2, 3], unbiased=False, keepdim=True).detach() + 1e-5)
y = (x - mean) / std
y.abs().sum().backward()
print(x.grad)
x2_grad = x.grad.clone()
y2 = y.clone()

# BN
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
bn = nn.BatchNorm2d(3, affine=False)
y = bn(x)
y.abs().sum().backward()
print(x.grad)
x3_grad = x.grad.clone()
y3 = y.clone()

# Compare
print((y3-y1).abs().max())
> tensor(2.3842e-07, grad_fn=<MaxBackward1>)
print((y3-y2).abs().max())
> tensor(2.3842e-07, grad_fn=<MaxBackward1>)

print((x3_grad - x1_grad).abs().max())
> tensor(8.3447e-07)
print((x3_grad - x2_grad).abs().max())
> tensor(2.6167)
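
As an extra sanity check (not part of the original comparison, and reusing x, y1 and x1_grad from the snippet above), the well-known closed-form gradient for normalization by batch statistics reproduces the no-detach/BN gradient; for the y.abs().sum() loss the upstream gradient is simply sign(y):

# Closed-form backward for y = (x - mean) / std with batch statistics:
#   dx = (dy - mean(dy) - y_hat * mean(dy * y_hat)) / std
with torch.no_grad():
    dims = [0, 2, 3]
    y_hat = y1.detach()
    dy = torch.sign(y_hat)   # gradient of y.abs().sum() w.r.t. y
    std = torch.sqrt(x.var(dims, unbiased=False, keepdim=True) + 1e-5)
    dx = (dy
          - dy.mean(dim=dims, keepdim=True)
          - y_hat * (dy * y_hat).mean(dim=dims, keepdim=True)) / std
    print((dx - x1_grad).abs().max())   # expected to be ~0 up to float precision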

Awesome! This makes sense to me; I figured this was the case, and I'm glad to know for sure. Thank you so much for helping clear this up!
