Logic of Group Formation in Group Normalization

For my purposes, I want to apply Group Normalization to three specific channel groups of a tensor. These groups are contiguous, so for a tensor with dimensions (N, C, H, W), group 1 would be (N, 0:C/3, H, W), group 2 (N, C/3:2C/3, H, W), and group 3 (N, 2C/3:C, H, W). I looked through the codebase but could not find the logic that determines how the groups are formed.
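In PyTorch indexing, the three groups described above are the channel slices `x[:, :C//3]`, `x[:, C//3:2*C//3]` and `x[:, 2*C//3:]`. Below is a minimal sketch that normalizes each slice with its own per-sample statistics (no affine weight or bias). It assumes C is divisible by 3, and `normalize_contiguous_groups` is a hypothetical helper name, not a PyTorch function:

```python
import torch


def normalize_contiguous_groups(x: torch.Tensor, num_groups: int = 3, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: standardize contiguous channel slices of an (N, C, H, W) tensor."""
    c = x.shape[1]
    assert c % num_groups == 0, "C must be divisible by num_groups"
    size = c // num_groups
    outs = []
    for g in range(num_groups):
        # Group g covers channels g*size .. (g+1)*size - 1, e.g. x[:, :C//3] for g = 0
        grp = x[:, g * size:(g + 1) * size]
        # Per-sample statistics over the group's channels and all spatial positions
        mean = grp.mean(dim=(1, 2, 3), keepdim=True)
        var = ((grp - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)  # biased variance
        outs.append((grp - mean) / torch.sqrt(var + eps))
    return torch.cat(outs, dim=1)


x = torch.randn(4, 6, 8, 8)
y = normalize_contiguous_groups(x)  # same shape as x: (4, 6, 8, 8)
```

Each slice of `y` (for example `y[:, :2]`) then has roughly zero mean and unit variance per sample.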

Can someone explain how GroupNorm forms its groups? Does it match my contiguous scenario?

Your assumption seems to be correct based on this small experiment:

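import torch
import torch.nn as nn

N, H, W = 16, 224, 224

# Experiment 1: each contiguous pair of channels shares one distribution
# (channels 0-1: mean 5, std 2; channels 2-3: mean 10, std 4; channels 4-5: mean 15, std 6)
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 2, H, W) * 2 + 5,
    torch.randn(N, 2, H, W) * 4 + 10,
    torch.randn(N, 2, H, W) * 6 + 15,
), dim=1)

out = norm(x)
for i, o in enumerate(out.split(1, dim=1)):
    print("channel {}: mean: {:.5f}, std: {:.5f}".format(i, o.mean(), o.std()))
# Output (values vary slightly between runs, since no seed is set):
# channel 0: mean: 0.00021, std: 0.99910
# channel 1: mean: -0.00021, std: 1.00089
# channel 2: mean: -0.00071, std: 0.99964
# channel 3: mean: 0.00071, std: 1.00036
# channel 4: mean: -0.00172, std: 1.00064
# channel 5: mean: 0.00172, std: 0.99936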

# Experiment 2: each contiguous pair of channels mixes two different distributions
# (e.g. channel 0: mean 5, std 2; channel 1: mean -5, std 0.5)
norm = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.cat((
    torch.randn(N, 1, H, W) * 2 + 5,
    torch.randn(N, 1, H, W) * 1/2 - 5,
    torch.randn(N, 1, H, W) * 4 + 10,
    torch.randn(N, 1, H, W) * 1/4 - 10,
    torch.randn(N, 1, H, W) * 6 + 15,
    torch.randn(N, 1, H, W) * 1/6 - 15,
), dim=1)

out = norm(x)
for i, o in enumerate(out.split(1, dim=1)):
    print("channel {}: mean: {:.5f}, std: {:.5f}".format(i, o.mean(), o.std()))
# Output (values vary slightly between runs, since no seed is set):
# channel 0: mean: 0.96000, std: 0.38414
# channel 1: mean: -0.96000, std: 0.09611
# channel 2: mean: 0.96217, std: 0.38453
# channel 3: mean: -0.96217, std: 0.02406
# channel 4: mean: 0.96211, std: 0.38544
# channel 5: mean: -0.96211, std: 0.01068

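As you can see, the first experiment gives a standardized output (zero mean, unit variance) for every channel: each group of two contiguous channels contains samples from a single distribution, so the group statistics are also the per-channel statistics. In the second experiment each contiguous pair mixes two different distributions, so the pooled group statistics do not standardize the individual channels. For example, channels 0 and 1 form one group with mean ≈ 0 and standard deviation ≈ sqrt((2² + 0.5²)/2 + 5²) ≈ 5.21, so channel 0 ends up with mean ≈ 5/5.21 ≈ 0.96 and std ≈ 2/5.21 ≈ 0.38, matching the printed values.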
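The contiguous behavior can also be checked against an explicit reference. A common way to express GroupNorm is to reshape the (N, C, H, W) input to (N, G, C // G, H * W) and standardize over the last two dimensions; because the reshape splits the channel dimension in order, group g receives channels g*C/G through (g+1)*C/G - 1. The sketch below is a hypothetical reference (`group_norm_reference` is not a PyTorch function) rather than PyTorch's actual kernel, and it compares its output with `nn.GroupNorm`:

```python
import torch
import torch.nn as nn


def group_norm_reference(x, num_groups, weight=None, bias=None, eps=1e-5):
    """Hypothetical reference GroupNorm for (N, C, *) inputs, for comparison only."""
    n, c = x.shape[:2]
    # Splitting C into (G, C // G) in row-major order puts channels
    # g*C/G .. (g+1)*C/G - 1 into group g, i.e. contiguous groups
    grouped = x.reshape(n, num_groups, c // num_groups, -1)
    mean = grouped.mean(dim=(2, 3), keepdim=True)
    var = ((grouped - mean) ** 2).mean(dim=(2, 3), keepdim=True)  # biased variance
    out = ((grouped - mean) / torch.sqrt(var + eps)).reshape(x.shape)
    # Optional per-channel affine parameters, broadcast over the trailing dims
    shape = (1, c) + (1,) * (x.dim() - 2)
    if weight is not None:
        out = out * weight.reshape(shape)
    if bias is not None:
        out = out + bias.reshape(shape)
    return out


x = torch.randn(16, 6, 32, 32)
norm = nn.GroupNorm(num_groups=3, num_channels=6)
with torch.no_grad():
    ref = group_norm_reference(x, 3, norm.weight, norm.bias, norm.eps)
    print(torch.allclose(norm(x), ref, atol=1e-4))  # expected: True
```

A match here is consistent with the experiments above, although it only shows that the outputs agree, not how the native kernel is written.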

Thank you for your reply. Yes, it seems almost certain that the groups are formed from contiguous channels.