The shape of mean and variance in torch.nn.functional.instance_norm does not make sense

Consider the following code:

import torch

x = torch.rand((2, 3, 4, 5))  # N, C, H, W
x1_running_mean = torch.zeros(3)
x1_running_var = torch.ones(3)
x2_running_mean = torch.zeros((2, 3))
x2_running_var = torch.ones((2, 3))
res1 = torch.nn.functional.instance_norm(x, running_mean=x1_running_mean, running_var=x1_running_var, use_input_stats=False)
# raises an error: the running stats are expected to have size (C,) = (3,)
res2 = torch.nn.functional.instance_norm(x, running_mean=x2_running_mean, running_var=x2_running_var, use_input_stats=False)

res1 runs without errors, while res2 fails with an error saying that the function expected a tensor of size (3).
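A minimal sketch of what the working call computes, assuming current PyTorch behaviour: with `use_input_stats=False`, each channel is normalized with its single running statistic, broadcast over N, H and W, which is why a tensor of size (C,) is expected. The names `rm`, `rv` and `manual` are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 5)  # N, C, H, W
rm = torch.rand(3)          # one running mean per channel, shape (C,)
rv = torch.rand(3) + 0.5    # one running variance per channel, shape (C,)
eps = 1e-5                  # default eps of F.instance_norm

out = F.instance_norm(x, running_mean=rm, running_var=rv,
                      use_input_stats=False, eps=eps)

# The same result by hand: reshape the (C,) statistics to (1, C, 1, 1) and broadcast
manual = (x - rm.view(1, -1, 1, 1)) / torch.sqrt(rv.view(1, -1, 1, 1) + eps)
print(torch.allclose(out, manual, atol=1e-5))  # True
```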

For instance norm, the mean and variance are calculated for each (N, C) pair, i.e. across all the spatial dimensions, so the mean and variance should have size (N, C) rather than (C). Does PyTorch handle instance norm differently?
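For comparison, a sketch of instance norm with input statistics (the default `use_input_stats=True`): the mean and biased variance are taken over H and W separately for each (N, C) pair, so the per-batch statistics do have shape (N, C). The names `mean`, `var` and `manual` are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 5)  # N, C, H, W
eps = 1e-5

# Statistics over the spatial dims only, one value per (n, c) pair
mean = x.mean(dim=(2, 3))                # shape (N, C) = (2, 3)
var = x.var(dim=(2, 3), unbiased=False)  # biased variance, shape (2, 3)

manual = (x - mean[:, :, None, None]) / torch.sqrt(var[:, :, None, None] + eps)
out = F.instance_norm(x, eps=eps)        # uses input statistics by default

print(mean.shape)                              # torch.Size([2, 3])
print(torch.allclose(out, manual, atol=1e-5))  # True
```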

Wouldn’t this add a dependency on the batch size, which would then crash if the number of samples changes during training or eval?
Using the module, I see these shapes, which match your working example:

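import torch
import torch.nn as nn

m = nn.InstanceNorm2d(100, affine=True, track_running_stats=True)
input = torch.randn(20, 100, 35, 45)
output = m(input)

print(m.weight.shape)        # torch.Size([100])
print(m.bias.shape)          # torch.Size([100])
print(m.running_mean.shape)  # torch.Size([100])
print(m.running_var.shape)   # torch.Size([100])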
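The batch-size point can be checked directly. This is a sketch assuming the current ATen implementation, which repeats the (C,) buffers across the batch internally and averages them back afterwards: the running buffers keep shape (C,) for any batch size, and one training step updates `running_mean` with the batch average of the per-instance means. The channel count, batch sizes and spatial sizes below are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m = nn.InstanceNorm2d(4, affine=True, track_running_stats=True)  # momentum=0.1 by default

for batch_size in (20, 7):       # different batch sizes reuse the same buffers
    old_mean = m.running_mean.clone()
    inp = torch.randn(batch_size, 4, 5, 6)
    m(inp)                       # training mode: normalizes with input statistics
    print(m.running_mean.shape)  # torch.Size([4]) for both batch sizes

    # Per-instance means, shape (batch_size, 4), then averaged over the batch
    per_instance = inp.mean(dim=(2, 3))
    expected = (1 - m.momentum) * old_mean + m.momentum * per_instance.mean(0)
    print(torch.allclose(m.running_mean, expected, atol=1e-6))  # True
```

Because the statistics are folded back to one value per channel, nothing stored in the module depends on N.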

You are right about the dependency on the batch size; I hadn’t realized this. I found the following NumPy implementation in the PyTorch tests, which helped me understand how the mean and variance are calculated for each (N, C) pair:

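import numpy as np

def ref_nchw(x, scale, bias, epsilon=1e-5):
    # x has layout (N, C, H, W); scale and bias have shape (C,); epsilon default assumed
    batch_size, input_channels, size_a, size_b = x.shape
    # One row per (n, c) instance, with all spatial positions flattened
    x = x.reshape(batch_size * input_channels, size_a * size_b)
    # Normalize each row with its own mean and (biased) variance
    y = (x - x.mean(1)[:, np.newaxis])
    y /= np.sqrt(x.var(1) + epsilon)[:, np.newaxis]
    y = y.reshape(batch_size, input_channels, size_a, size_b)
    # Per-channel affine transform, broadcast over N, H and W
    y = y * scale.reshape(1, input_channels, 1, 1)
    y = y + bias.reshape(1, input_channels, 1, 1)
    return y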
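The same computation can be written without the reshape by reducing over the spatial axes with `keepdims=True`. The sketch below (the name `instance_norm_np` is hypothetical) does that in NumPy and checks it against `F.instance_norm` with per-channel `weight` and `bias`; `epsilon=1e-5` is assumed to match PyTorch's default `eps`.

```python
import numpy as np
import torch
import torch.nn.functional as F

def instance_norm_np(x, scale, bias, epsilon=1e-5):
    # Mean and biased variance over H and W, one value per (n, c): shape (N, C, 1, 1)
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    y = (x - mean) / np.sqrt(var + epsilon)
    # scale and bias have shape (C,) and broadcast over N, H and W
    return y * scale[None, :, None, None] + bias[None, :, None, None]

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 5)).astype(np.float32)
scale = rng.random(3).astype(np.float32)
bias = rng.random(3).astype(np.float32)

expected = instance_norm_np(x, scale, bias)
actual = F.instance_norm(torch.from_numpy(x), weight=torch.from_numpy(scale),
                         bias=torch.from_numpy(bias), eps=1e-5).numpy()
print(np.allclose(actual, expected, atol=1e-5))  # True
```

Only the affine parameters are per channel; the normalization statistics are per (N, C) pair and are never stored.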
