To better understand how nn.BatchNorm2d
works, I wanted to reproduce the output of the following lines of code:
input = torch.randint(1, 5, size=(2, 2, 3, 3)).float()
batch_norm = nn.BatchNorm2d(2)
output = batch_norm(input)
print(input)
print(output)
tensor([[[[3., 3., 2.],
          [3., 2., 2.],
          [4., 2., 1.]],

         [[1., 1., 2.],
          [3., 2., 1.],
          [1., 1., 1.]]],


        [[[4., 3., 3.],
          [4., 1., 4.],
          [1., 3., 2.]],

         [[2., 1., 4.],
          [4., 2., 1.],
          [4., 1., 3.]]]])
tensor([[[[ 0.3859,  0.3859, -0.6064],
          [ 0.3859, -0.6064, -0.6064],
          [ 1.3783, -0.6064, -1.5988]],

         [[-0.8365, -0.8365,  0.0492],
          [ 0.9349,  0.0492, -0.8365],
          [-0.8365, -0.8365, -0.8365]]],


        [[[ 1.3783,  0.3859,  0.3859],
          [ 1.3783, -1.5988,  1.3783],
          [-1.5988,  0.3859, -0.6064]],

         [[ 0.0492, -0.8365,  1.8206],
          [ 1.8206,  0.0492, -0.8365],
          [ 1.8206, -0.8365,  0.9349]]]])
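Before recreating this by hand, it helps to pin down what nn.BatchNorm2d actually guarantees: each channel of the output, taken across the batch and spatial dimensions, has roughly zero mean and unit variance. A quick sanity check (a sketch with a fresh seeded input, not the exact tensors above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randint(1, 5, size=(2, 2, 3, 3)).float()
y = nn.BatchNorm2d(2)(x)

# Each channel of the output is normalized across the batch and spatial
# dims, so its mean is ~0 and its (biased) standard deviation is ~1.
per_channel_mean = y.mean(dim=[0, 2, 3])
per_channel_std = y.std(dim=[0, 2, 3], unbiased=False)
print(per_channel_mean)  # ~[0., 0.]
print(per_channel_std)   # ~[1., 1.]
```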
To achieve this, I first calculated the mean and variance for each channel:
my_mean = torch.mean(input, dim=[0, 2, 3])
my_var = torch.var(input, dim=[0, 2, 3])
print(my_mean, my_var)
tensor([2.8333, 2.5556])
tensor([1.3235, 1.3203])
This seems reasonable: I now have the mean and variance for each channel across the whole batch. Then I wanted to simply subtract the mean from the input and divide by the variance. This is where the problems arise, since I do not know how to properly set up the mean and variance, and PyTorch
does not seem to broadcast them properly:
my_output = (input - my_mean) / my_var
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 3
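The error comes from how broadcasting aligns shapes starting at the trailing dimension: my_mean has shape (2,), which gets matched against the last dimension of the (2, 2, 3, 3) input (size 3) instead of the channel dimension. Inserting singleton dimensions so the statistics have shape (1, 2, 1, 1) fixes the alignment; a minimal sketch:

```python
import torch

x = torch.randint(1, 5, size=(2, 2, 3, 3)).float()
mean = torch.mean(x, dim=[0, 2, 3])  # shape (2,): one value per channel
var = torch.var(x, dim=[0, 2, 3])

# Reshape to (1, C, 1, 1) so each statistic lines up with the channel dim.
mean_r = mean.reshape(1, -1, 1, 1)
var_r = var.reshape(1, -1, 1, 1)

centered = x - mean_r                # now broadcasts without error
print(centered.shape)                # torch.Size([2, 2, 3, 3])
```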
I then wanted to reshape the mean and variance into an appropriate shape, such that each value is repeated 9 times in a 3x3 shape. First try:
my_mean.repeat(9).reshape(2, 3, 3)
But this does not work either. What is the best way to achieve my goal?
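For completeness, here is one way the whole computation can be recreated (a sketch, not the only option). Two details matter beyond the reshape: at training time BatchNorm2d normalizes with the biased variance (unbiased=False, unlike torch.var's default), and it divides by the standard deviation sqrt(var + eps) rather than the variance itself; keepdim=True keeps the statistics broadcastable without a manual reshape:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randint(1, 5, size=(2, 2, 3, 3)).float()

bn = nn.BatchNorm2d(2)  # fresh module: weight=1, bias=0, training mode
expected = bn(x)

# keepdim=True leaves the stats with shape (1, 2, 1, 1), so they broadcast.
mean = x.mean(dim=[0, 2, 3], keepdim=True)
var = x.var(dim=[0, 2, 3], unbiased=False, keepdim=True)  # biased variance
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(manual, expected, atol=1e-5))  # True
```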