Broadcasting in PyTorch

To better understand how nn.BatchNorm2d works, I wanted to recreate the output of the following lines of code:

    import torch
    import torch.nn as nn

    input = torch.randint(1, 5, size=(2, 2, 3, 3)).float()
    batch_norm = nn.BatchNorm2d(2)
    output = batch_norm(input)
    print(input)
    print(output)

    tensor([[[[3., 3., 2.],
              [3., 2., 2.],
              [4., 2., 1.]],
    
             [[1., 1., 2.],
              [3., 2., 1.],
              [1., 1., 1.]]],
    
    
            [[[4., 3., 3.],
              [4., 1., 4.],
              [1., 3., 2.]],
    
             [[2., 1., 4.],
              [4., 2., 1.],
              [4., 1., 3.]]]])
    tensor([[[[ 0.3859,  0.3859, -0.6064],
              [ 0.3859, -0.6064, -0.6064],
              [ 1.3783, -0.6064, -1.5988]],
    
             [[-0.8365, -0.8365,  0.0492],
              [ 0.9349,  0.0492, -0.8365],
              [-0.8365, -0.8365, -0.8365]]],
    
    
            [[[ 1.3783,  0.3859,  0.3859],
              [ 1.3783, -1.5988,  1.3783],
              [-1.5988,  0.3859, -0.6064]],
    
             [[ 0.0492, -0.8365,  1.8206],
              [ 1.8206,  0.0492, -0.8365],
          [ 1.8206, -0.8365,  0.9349]]]])

To achieve this, I first calculated the mean and variance for each channel:

    my_mean = torch.mean(input, dim=[0, 2, 3])
    my_var = torch.var(input, dim=[0, 2, 3])
    print(my_mean, my_var)
    tensor([2.8333, 2.5556])
    tensor([1.3235, 1.3203])

This seems reasonable: I have the mean and variance for each channel across the whole batch. Then I wanted to simply subtract the mean from the input and divide by the variance. This is where problems arise, since I do not know how to properly set up the shapes of the mean and variance. PyTorch does not broadcast them the way I expected:

    my_output = (input - my_mean) / my_var
    RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 3
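
Checking the shapes shows the problem: my_mean has shape (2,) while input has shape (2, 2, 3, 3). Broadcasting aligns trailing dimensions, so the size-2 mean is matched against the size-3 last dimension of the input, which is exactly the error above:

    print(input.shape)    # torch.Size([2, 2, 3, 3])
    print(my_mean.shape)  # torch.Size([2])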

I then wanted to reshape the mean and variance into an appropriate shape, such that each channel's value is repeated 9 times in a 3x3 block.

First try:

    my_mean.repeat(9).reshape(2, 3, 3)

But this does not work either: repeat tiles the whole tensor, so the two channel means end up interleaved instead of each filling its own 3x3 block. What is the best way to achieve my goal?

You could unsqueeze the stats tensors directly, or add the missing dimensions by indexing with None, as seen here:

    import torch
    import torch.nn as nn

    input = torch.randint(1, 5, size=(2, 2, 3, 3)).float()
    batch_norm = nn.BatchNorm2d(2)
    output = batch_norm(input)
    print(input)
    print(output)

    # BatchNorm2d normalizes with the biased variance, hence unbiased=False
    mean = input.mean([0, 2, 3])
    var = input.var([0, 2, 3], unbiased=False)

    # indexing with None turns the shape-(2,) stats into shape (1, 2, 1, 1),
    # which broadcasts against the (2, 2, 3, 3) input;
    # note: BatchNorm2d divides by sqrt(var + eps), so adding eps outside the
    # sqrt is only an approximation, which is why a tiny difference remains
    out = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None]) + batch_norm.eps)
    print((out - output).abs().max())
    # tensor(8.5831e-06, grad_fn=<MaxBackward1>)
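
For completeness, here is the unsqueeze version mentioned above; each call inserts a size-1 dimension at the given position (mean4d and var4d are just illustrative names):

    # shape (2,) -> (1, 2) -> (1, 2, 1) -> (1, 2, 1, 1)
    mean4d = mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
    var4d = var.unsqueeze(0).unsqueeze(2).unsqueeze(3)
    out2 = (input - mean4d) / (torch.sqrt(var4d) + batch_norm.eps)
    print(torch.allclose(out, out2))
    # True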

Sorry for my late reply, and thanks for your fast response :slight_smile:

So, indexing with ‘None’ basically lets me expand any scalar, vector, or matrix to an arbitrary number of dimensions?

Yes, indexing with None will add a new dimension, as it’s an alias for np.newaxis:

    import numpy as np

    np.newaxis is None
    # True

and PyTorch adopted this API.
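
A quick shape check makes the effect visible:

    import torch

    t = torch.tensor([1., 2.])
    print(t[None].shape)                 # torch.Size([1, 2])
    print(t[None, :, None, None].shape)  # torch.Size([1, 2, 1, 1])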