# Batchnorm and back-propagation

Hi All,

I have what I hope is a simple question: when mu and var are calculated in the batchnorm layer, are gradients propagated through them? I.e., are the mu and var in y = (x - mu) / sqrt(var + eps) plain numbers or gradient-tracked tensors?

I’m asking because I want to implement a modified version of batchnorm that uses the variance of the dimension before dropout is applied, and I need to know whether I have to call `.detach()` on mu and var.

Thanks!
Andy

The running stats are registered as buffers and do not require gradients, so you shouldn’t need to call `.detach()` on them.
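
As a quick check (a minimal sketch of my own, not taken from the batchnorm internals), you can verify that the running stats are registered as buffers and don't track gradients:

```
import torch.nn as nn

bn = nn.BatchNorm2d(3)
# running_mean / running_var are registered as buffers, not parameters
print([name for name, _ in bn.named_buffers()])
print(bn.running_mean.requires_grad)  # False
print(bn.running_var.requires_grad)   # False
```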

Does batch norm in training mode normalize with the current batch statistics or with the current running stats? I assume the former, because I have a case where the batch statistics are tied to a bag of instances, so they aren’t comparable across bags (which is why I’m setting `track_running_stats=False`).

Yes, during training the batch statistics will be used to normalize the current input.
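
A small sanity check to see this (my own sketch, assuming the default momentum): in train mode the output is normalized to roughly zero mean using the batch stats, while in eval mode the running stats are used instead:

```
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, affine=False)
x = torch.randn(10, 3, 4, 4) * 5 + 2  # stats far from N(0, 1)

bn.train()
y_train = bn(x)
print(y_train.mean().abs())  # ~0: the batch stats were used

bn.eval()
y_eval = bn(x)
print(y_eval.mean().abs())   # clearly non-zero: the running stats were used
```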

Thanks for the responses so far.

Hopefully this makes my initial question clearer: say you just want to normalize to zero mean, so the operation is y = x - mu.

Is the ‘batchnorm equivalent’ approach like:

```
y = x - x.mean(dim=0).detach()
```

or

```
y = x - x.mean(dim=0)
```

While the value of y is the same, the gradient propagation is clearly different.

I wasn’t sure, but based on this small code snippet, it seems the latter approach is used:

```
import torch
import torch.nn as nn

# Manual, without detaching the batch statistics
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
mean = x.mean(dim=[0, 2, 3], keepdim=True)
std = torch.sqrt(x.var([0, 2, 3], unbiased=False, keepdim=True) + 1e-5)
y = (x - mean) / std
y.abs().sum().backward()
y1, grad1 = y.clone(), x.grad.clone()

# Manual, with detached batch statistics
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
mean = x.mean(dim=[0, 2, 3], keepdim=True).detach()
std = torch.sqrt(x.var([0, 2, 3], unbiased=False, keepdim=True).detach() + 1e-5)
y = (x - mean) / std
y.abs().sum().backward()
y2, grad2 = y.clone(), x.grad.clone()

# BatchNorm2d in training mode, no affine parameters
torch.manual_seed(2809)
x = torch.randn(10, 3, 4, 4, requires_grad=True)
bn = nn.BatchNorm2d(3, affine=False)
y = bn(x)
y.abs().sum().backward()
y3, grad3 = y.clone(), x.grad.clone()

# Compare outputs and input gradients
print((y3 - y1).abs().max())
print((grad3 - grad1).abs().max())
print((grad3 - grad2).abs().max())
```
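
Running this, the output difference should be ~0, and the gradient difference should be ~0 only for the non-detached version ((grad3 - grad1) ≈ 0 while (grad3 - grad2) is not), which suggests batch norm keeps the batch statistics in the autograd graph, i.e. the latter approach without `.detach()`.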