Conv2d groups=2 on single GPU

I searched the forum, but couldn’t find a topic describing exactly the same problem as below.

I am currently doing research on a computer vision deep learning topic.
I noticed that “Conv2d” has an option to specify a “groups” value.
According to the PyTorch docs:
At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

If I’m right, groups can also be used to split the load across multiple GPUs: with groups=2, split across 2 GPUs (the original AlexNet was trained that way on two less powerful GPUs).

My question is: what happens if I specify groups=2, but have only a single GPU? Is half of the channels from the output of the previous layer dropped in that case, so information is lost? Or does PyTorch concatenate the two groups and feed them to the next layer as if there were no groups (i.e. groups=1)?

I currently train my network on Google Colab with a GPU, and as far as I know, Colab assigns just one GPU to the session.

Excerpt from the model’s state dict (conv2, conv4, conv5 use groups=2):

conv.conv1_s1.weight torch.Size([96, 3, 25, 7])
conv.conv1_s1.bias torch.Size([96])
conv.conv2_s1.weight torch.Size([192, 48, 3, 3])
conv.conv2_s1.bias torch.Size([192])
conv.conv3_s1.weight torch.Size([256, 192, 3, 3])
conv.conv3_s1.bias torch.Size([256])
conv.conv4_s1.weight torch.Size([256, 128, 3, 3])
conv.conv4_s1.bias torch.Size([256])
conv.conv5_s1.weight torch.Size([192, 128, 3, 3])
conv.conv5_s1.bias torch.Size([192])
fc6.fc6_s1.weight torch.Size([512, 6336])
fc6.fc6_s1.bias torch.Size([512])
fc7.fc7.weight torch.Size([256, 2048])
fc7.fc7.bias torch.Size([256])
classifier.fc8.weight torch.Size([24, 256])
classifier.fc8.bias torch.Size([24])
Total parameters: 4868056

Many Thanks

You could achieve the same output by splitting the input across two GPUs, calling each conv individually, and concatenating the outputs, as with groups=2.
However, groups=2 will not split the conv across two GPUs. Instead, it will apply the described method as if you were using two conv layers side by side. The operation will still run on a single GPU.
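For illustration, here is a minimal sketch of that manual split (it assumes two CUDA devices are available and uses made-up layer sizes), which produces the same kind of output as a single grouped conv:

import torch
import torch.nn as nn

x = torch.randn(1, 6, 24, 24)

# One "half" conv per device, each seeing 3 of the 6 input channels
conv_a = nn.Conv2d(3, 3, 3, padding=1, bias=False).to("cuda:0")
conv_b = nn.Conv2d(3, 3, 3, padding=1, bias=False).to("cuda:1")

o_a = conv_a(x[:, :3].to("cuda:0"))
o_b = conv_b(x[:, 3:].to("cuda:1"))

# Gather both halves on one device and concatenate along the channel dim,
# which is what groups=2 does internally on a single device
out = torch.cat((o_a, o_b.to("cuda:0")), dim=1)
print(out.shape)  # torch.Size([1, 6, 24, 24])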

Actually, my goal is not to split across 2 GPUs. I tested my network with groups=2 and groups=1, and I am wondering if I lose some information when applying groups=2, because printing the state dict (if I do it correctly) shows that the second layer only sees half of the channels from the first layer.

That means with groups=2:

  1st layer output channels: 96 --> 2nd layer input channels: 48 (see first post)

But with groups=1:

  1st layer output channels: 96 --> 2nd layer input channels: 96 (as expected)

I print the state dict this way:

    for param_tensor in net.state_dict():
      print(param_tensor, "\t", net.state_dict()[param_tensor].size())

So the main question from my perspective is: do I lose some information when using groups=2?

If you are using groups=2, then each filter will have 0.5 * in_channels input channels, as the operation is comparable to two separate convolutions, each using half of the input.

This also means that the number of parameters in the conv layer will be lowered.

The input will still be processed by one group or the other, so the convolution will still use the complete input information.
If you refer to the capacity (number of parameters) as information, then the grouped conv will have less capacity.
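As a quick check, here is a small sketch (the layer sizes just mirror the second conv from the first post and are otherwise arbitrary) comparing the weight shape and parameter count for groups=1 and groups=2:

import torch.nn as nn

conv_g1 = nn.Conv2d(96, 192, 3, bias=False)            # groups=1
conv_g2 = nn.Conv2d(96, 192, 3, groups=2, bias=False)  # groups=2

print(conv_g1.weight.shape)  # torch.Size([192, 96, 3, 3])
print(conv_g2.weight.shape)  # torch.Size([192, 48, 3, 3]) -> each filter sees half the input channels

print(sum(p.numel() for p in conv_g1.parameters()))  # 165888
print(sum(p.numel() for p in conv_g2.parameters()))  # 82944 -> half the parameters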

Do you really mean “or”? How can that be if the convolution uses the complete input information? Just to be on the safe side: both groups are processed, right? (Sorry for the repeated question, I’m a newbie.)

Thanks

Sorry for the misleading message.
The input will be completely processed. Each group sees half the input channels and produces half the output channels, and both are subsequently concatenated (from the docs).

This code demonstrates the behavior better:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Grouped conv: 6 input channels, 6 output channels, split into two groups
conv = nn.Conv2d(6, 6, 3, 1, 1, groups=2, bias=False)
x = torch.randn(1, 6, 24, 24)
out = conv(x)

# First group: the first 3 filters are applied to the first 3 input channels
filters1 = conv.weight[:3]
o1 = F.conv2d(x[:, :3], filters1, padding=1)

# Second group: the last 3 filters are applied to the last 3 input channels
filters2 = conv.weight[3:]
o2 = F.conv2d(x[:, 3:], filters2, padding=1)

# Concatenating both halves reproduces the grouped conv's output
out_manual = torch.cat((o1, o2), dim=1)

print((out - out_manual).abs().max())
> tensor(3.5763e-07, grad_fn=<MaxBackward1>)