Description of Conv2d groups parameter seems inconsistent with results

The PyTorch docs for the groups parameter of nn.Conv2d state that:

groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,

At groups=1, all inputs are convolved to all outputs.

At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

At groups= in_channels, each input channel is convolved with its own set of filters, of size: out_channels / in_channels
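
Taken at face value, I read the groups=2 bullet as the following sketch (x here is just a hypothetical 16-channel input I made up):

import torch
import torch.nn as nn

x = torch.randn(1, 16, 5, 5)  # hypothetical input: one sample, 16 channels

# "two conv layers side by side", each seeing half the input channels
half1 = nn.Conv2d(8, 8, 1, bias=False)
half2 = nn.Conv2d(8, 8, 1, bias=False)
out = torch.cat((half1(x[:, :8]), half2(x[:, 8:])), dim=1)  # "subsequently concatenated"
print(out.shape)  # torch.Size([1, 16, 5, 5])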

However, this description seems inconsistent with the actual behaviour of nn.Conv2d.

For example:

import torch
import torch.nn as nn
conv_layer = nn.Conv2d(16, 16, 1, groups=2, bias=False)
conv_layer.weight.shape

Returns torch.Size([16, 8, 1, 1])

But based on my interpretation of:

At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

Shouldn’t there be two weights, each of size [8, 8, 1, 1]?

In my view, these inconsistencies carry over to other values of the groups parameter as well.
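
A quick check over several values (my own sketch) shows the same pattern, a single weight of shape [16, 16 // groups, 1, 1] rather than groups separate weights:

import torch.nn as nn

for g in (1, 2, 4, 8, 16):
    w = nn.Conv2d(16, 16, 1, groups=g, bias=False).weight
    print(g, w.shape)

# 1  torch.Size([16, 16, 1, 1])
# 2  torch.Size([16, 8, 1, 1])
# 4  torch.Size([16, 4, 1, 1])
# 8  torch.Size([16, 2, 1, 1])
# 16 torch.Size([16, 1, 1, 1])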

I must be missing something - could someone please clarify?

There are two groups of that size ([8, 8, 1, 1]) concatenated in dim0 of the weight tensor.
Have a look at this small example:

import torch
import torch.nn as nn

x = torch.randn(1, 16, 24, 24)  # example input: one sample, 16 channels

conv_grouped = nn.Conv2d(16, 16, 1, groups=2, bias=False)
output_grouped = conv_grouped(x)
print(output_grouped.shape)  # torch.Size([1, 16, 24, 24])

# two plain conv layers, each seeing half of the input channels
conv1 = nn.Conv2d(8, 8, 1, bias=False)
conv2 = nn.Conv2d(8, 8, 1, bias=False)
with torch.no_grad():
    conv1.weight.copy_(conv_grouped.weight[:8])  # first 8 filters
    conv2.weight.copy_(conv_grouped.weight[8:])  # last 8 filters

output1 = conv1(x[:, :8])  # first half of the input channels
output2 = conv2(x[:, 8:])  # second half of the input channels
output_manual = torch.cat((output1, output2), dim=1)

print((output_manual == output_grouped).all())
> tensor(1, dtype=torch.uint8)

As you can see, the first half of the conv_grouped weights is applied to the first half of the channels in x, and the second half of the weights to the second half of the channels.
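
For completeness, the same layout extends to groups=in_channels (a depthwise convolution), matching the last bullet of the docs. A minimal sketch (the input x and kernel size are made up for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 16, 5, 5)
depthwise = nn.Conv2d(16, 16, 3, padding=1, groups=16, bias=False)
print(depthwise.weight.shape)  # torch.Size([16, 1, 3, 3]): one filter per input channel

# output channel 0 depends only on input channel 0
out = depthwise(x)
out_ch0 = F.conv2d(x[:, :1], depthwise.weight[:1], padding=1)
print(torch.allclose(out[:, :1], out_ch0))  # True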


Ah, thanks so much - that's perfectly clear now!