Your assumption seems to be correct, based on these two small experiments:
import torch
import torch.nn as nn

N, H, W = 16, 224, 224

# experiment 1: both channels in each group share the same distribution
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 2, H, W) * 2 + 5,
    torch.randn(N, 2, H, W) * 4 + 10,
    torch.randn(N, 2, H, W) * 6 + 15,
), dim=1)
out = norm(x)
for i, o in enumerate(out.split(1, dim=1)):
    print("channel {}: mean: {:.5f}, std: {:.5f}".format(i, o.mean(), o.std()))
> channel 0: mean: 0.00021, std: 0.99910
> channel 1: mean: -0.00021, std: 1.00089
> channel 2: mean: -0.00071, std: 0.99964
> channel 3: mean: 0.00071, std: 1.00036
> channel 4: mean: -0.00172, std: 1.00064
> channel 5: mean: 0.00172, std: 0.99936
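As a quick sanity check, the same output can be reproduced by computing the statistics per group manually (a minimal sketch; it assumes the affine parameters are still at their init values of weight=1 and bias=0):

xg = x.view(N, 3, -1)  # flatten each group of 2 channels together with H*W
mean = xg.mean(dim=2, keepdim=True)
var = xg.var(dim=2, unbiased=False, keepdim=True)  # GroupNorm uses the biased variance
manual = ((xg - mean) / (var + norm.eps).sqrt()).view_as(x)
print(torch.allclose(manual, out, atol=1e-5))  # True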
# experiment 2: the two channels in each group follow different distributions
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 1, H, W) * 2 + 5,
    torch.randn(N, 1, H, W) * 1/2 - 5,
    torch.randn(N, 1, H, W) * 4 + 10,
    torch.randn(N, 1, H, W) * 1/4 - 10,
    torch.randn(N, 1, H, W) * 6 + 15,
    torch.randn(N, 1, H, W) * 1/6 - 15,
), dim=1)
out = norm(x)
for i, o in enumerate(out.split(1, dim=1)):
    print("channel {}: mean: {:.5f}, std: {:.5f}".format(i, o.mean(), o.std()))
> channel 0: mean: 0.96000, std: 0.38414
> channel 1: mean: -0.96000, std: 0.09611
> channel 2: mean: 0.96217, std: 0.38453
> channel 3: mean: -0.96217, std: 0.02406
> channel 4: mean: 0.96211, std: 0.38544
> channel 5: mean: -0.96211, std: 0.01068
As you can see, the first approach yields a standardized (zero mean, unit variance) output for every channel, while the second approach does not: since GroupNorm computes the mean and std over all channels in a group, mixing channels with different statistics inside one group leaves the individual channels unnormalized.
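The numbers in the second experiment can also be reproduced analytically (roughly, ignoring the eps term and sampling noise). The first group, for example, mixes N(5, 2^2) and N(-5, 0.5^2) in equal parts, so the group mean is 0 and the group variance is E[x^2] - E[x]^2 = ((2**2 + 5**2) + (0.5**2 + 5**2)) / 2 = 27.125:

group_std = 27.125 ** 0.5  # ~5.208, std over the first group
print(5 / group_std)       # ~0.96003 -> per-channel means after normalization
print(2 / group_std)       # ~0.38401 -> std of channel 0
print(0.5 / group_std)     # ~0.09600 -> std of channel 1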