# Logic of Group Formation in Group Normalization

For my purposes I want to apply Group Normalization to three specific channel groups of a tensor. These groups are contiguous, so for a tensor with dimensions (N, C, H, W), group 1 would be (N, :C/3, H, W), group 2 (N, C/3:2C/3, H, W), and group 3 (N, 2C/3:C, H, W). I looked through the code base but could not find the logic behind group creation.

Can someone explain to me the logic for the formation of groups in GroupNorm? Does it match my contiguous scenario?

Your assumption seems to be correct based on this small experiment:

```python
import torch
import torch.nn as nn

N, H, W = 16, 224, 224

# increasing stats per group of two channels
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 2, H, W) * 2 + 5,
    torch.randn(N, 2, H, W) * 4 + 10,
    torch.randn(N, 2, H, W) * 6 + 15,
), dim=1)

out = norm(x)
for i, o in enumerate(out.split(1, dim=1)):
    print("channel {}: mean: {:.5f}, {:.5f}".format(i, o.mean(), o.std()))
```

> channel 0: mean: 0.00021, 0.99910
> channel 1: mean: -0.00021, 1.00089
> channel 2: mean: -0.00071, 0.99964
> channel 3: mean: 0.00071, 1.00036
> channel 4: mean: -0.00172, 1.00064
> channel 5: mean: 0.00172, 0.99936
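The same point can also be checked per group rather than per channel: reshaping the output into (N, groups, elements_per_group) should give ~zero mean and ~unit std for every (sample, group) pair if the grouping is contiguous. A small sketch of that check (smaller spatial size here just to keep it quick):

```python
import torch
import torch.nn as nn

N, H, W = 16, 32, 32
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 2, H, W) * 2 + 5,
    torch.randn(N, 2, H, W) * 4 + 10,
    torch.randn(N, 2, H, W) * 6 + 15,
), dim=1)
out = norm(x)

# If each contiguous pair of channels forms one group, the normalized
# output should be standardized per (sample, group).
grouped = out.view(N, 3, 2 * H * W)
print(grouped.mean(dim=2).abs().max())  # close to 0
print(grouped.std(dim=2).mean())        # close to 1
```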

```python
# mix channels
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 1, H, W) * 2 + 5,
    torch.randn(N, 1, H, W) * 1/2 - 5,
    torch.randn(N, 1, H, W) * 4 + 10,
    torch.randn(N, 1, H, W) * 1/4 - 10,
    torch.randn(N, 1, H, W) * 6 + 15,
    torch.randn(N, 1, H, W) * 1/6 - 15,
), dim=1)

out = norm(x)
for i, o in enumerate(out.split(1, dim=1)):
    print("channel {}: mean: {:.5f}, {:.5f}".format(i, o.mean(), o.std()))
```

> channel 0: mean: 0.96000, 0.38414
> channel 1: mean: -0.96000, 0.09611
> channel 2: mean: 0.96217, 0.38453
> channel 3: mean: -0.96217, 0.02406
> channel 4: mean: 0.96211, 0.38544
> channel 5: mean: -0.96211, 0.01068

As you can see, the first approach gives a standardized (zero mean, unit variance) output per channel, while the second approach, which mixes channels with very different statistics inside each contiguous group, does not. This would not happen if the grouping were anything other than contiguous blocks of channels.
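If helpful, the contiguous grouping can also be verified directly by reimplementing the normalization with a reshape and comparing against `nn.GroupNorm`. A minimal sketch (assuming the default `eps` and the default affine parameters, which are initialized to weight 1 and bias 0):

```python
import torch
import torch.nn as nn

# Reproduce nn.GroupNorm by standardizing contiguous channel groups.
N, C, H, W = 4, 6, 8, 8
G = 3  # groups cover channels [0:2], [2:4], [4:6]

norm = nn.GroupNorm(num_groups=G, num_channels=C)
x = torch.randn(N, C, H, W)

# Reshape so each contiguous block of C // G channels forms one group,
# then normalize over everything inside that group.
xg = x.view(N, G, -1)
mean = xg.mean(dim=2, keepdim=True)
var = xg.var(dim=2, unbiased=False, keepdim=True)
manual = ((xg - mean) / torch.sqrt(var + norm.eps)).view(N, C, H, W)

print(torch.allclose(norm(x), manual, atol=1e-5))  # True
```

If the two outputs agree, the module must be grouping exactly as in your contiguous slicing scenario.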

Thank you for your reply. Yes, it seems almost certain that the groups are formed from contiguous blocks of channels.