Convolution operator with groups for taking each convolved result

I recently noticed that Conv2d has a groups option, but I don’t understand how to use it.

Specifically, I want to obtain each individual convolution output.

In a traditional convolutional network, assuming 3 input channels (R, G, B) and 2 output channels (J1, J2):

J1 = F1 * R + F2 * G + F3 * B
J2 = F4 * R + F5 * G + F6 * B

where F1, F2, … denote filters (say, 3x3 filters) and * denotes the convolution operator.
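As a sketch of the equations above (assuming a recent PyTorch with `torch.nn.functional.conv2d`), a standard convolution with `groups=1` produces each output channel as the sum of per-input-channel convolutions. The random input and filters are illustrative only, and PyTorch's `conv2d` actually computes cross-correlation, which does not change the channel bookkeeping.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One RGB image (N=1, C=3, H=W=5) and 2 output filters,
# each holding one 3x3 kernel per input channel: shape (2, 3, 3, 3).
x = torch.randn(1, 3, 5, 5)
w = torch.randn(2, 3, 3, 3)

# Standard convolution (groups=1): every output channel mixes R, G and B.
J = F.conv2d(x, w)

def single(c, o):
    # Convolve input channel c alone with the kernel of output o for channel c.
    return F.conv2d(x[:, c:c + 1], w[o:o + 1, c:c + 1])

# J1 = F1*R + F2*G + F3*B and J2 = F4*R + F5*G + F6*B
J1 = single(0, 0) + single(1, 0) + single(2, 0)
J2 = single(0, 1) + single(1, 1) + single(2, 1)

print(torch.allclose(J[:, 0:1], J1, atol=1e-5))  # True
print(torch.allclose(J[:, 1:2], J2, atol=1e-5))  # True
```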

In my case, I want to manipulate each convolved result individually. So instead of the output images, I want to obtain the following:

F1 * R, F2 * G, F3 * B,
F4 * R, F5 * G, F6 * B

Is this possible by using group option? Thanks.


Yes, nn.Conv2d(3, 6, groups=6) will do the operation you want. Each output channel will be a result of a convolution over a single input channel.


Thank you for the kind reply!

Minor question: shouldn't it be nn.Conv2d(3, 6, groups=3)? I ask because I saw the following in the ConvNd module:

self.weight = Parameter(torch.Tensor(
    out_channels, in_channels // groups, *kernel_size))
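The `in_channels // groups` term in that line is the key. A quick shape check (a sketch assuming a recent PyTorch) shows why `groups=3` is needed here and why `groups=6` is rejected:

```python
import torch.nn as nn

# groups=3: each of the 3 groups sees 3 // 3 = 1 input channel
# and produces 6 // 3 = 2 output channels.
conv = nn.Conv2d(3, 6, kernel_size=3, groups=3)
print(conv.weight.shape)  # torch.Size([6, 1, 3, 3])

# groups=1 (the default): every filter spans all 3 input channels.
print(nn.Conv2d(3, 6, kernel_size=3).weight.shape)  # torch.Size([6, 3, 3, 3])

# groups=6 fails because in_channels (3) is not divisible by 6.
try:
    nn.Conv2d(3, 6, kernel_size=3, groups=6)
except ValueError as err:
    print("ValueError:", err)
```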

Yes, you’re right. It should be groups=3.


This Q&A helped me understand how the groups option works.


@apaszke, thank you for confirming this feature. As a more general question, when is it useful to convolve channels separately like this, as opposed to using the default groups=1?

In a training (backprop) context, you’d use grouped convolution either to make the model splittable across more than one GPU (since the groups never have to exchange data), or to force training to split the channels into separate groups of features (there are some good explanations of this online). More generally, there are many possible applications. As a simple example, suppose you want to blur the R, G and B channels separately with a blur kernel and get an RGB output (you definitely don’t want to blur R together with G, etc.).
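The per-channel blur case can be sketched like this, assuming a simple 3x3 box filter (any per-channel kernel would work the same way):

```python
import torch
import torch.nn as nn

# in_channels=3, out_channels=3, groups=3: one single-channel 3x3 filter per color.
blur = nn.Conv2d(3, 3, kernel_size=3, padding=1, groups=3, bias=False)
with torch.no_grad():
    blur.weight.fill_(1.0 / 9.0)  # box blur; weight shape is (3, 1, 3, 3)

img = torch.zeros(1, 3, 8, 8)
img[:, 0, 4, 4] = 1.0  # a single bright pixel in the R channel only

out = blur(img)

# The pixel is spread over its 3x3 neighborhood within R (sum stays ~1.0),
# while G and B remain exactly zero because the groups never mix channels.
print(out[0, 0].sum().item())
print(out[0, 1].abs().max().item(), out[0, 2].abs().max().item())  # 0.0 0.0
```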

Quick follow-up: What is the order of the channels in the output image?

Assuming the input is an RGB image and the output has 6 channels (as in the original post), is the order:

Guess 1:
J1 = F1 * R
J2 = F2 * G
J3 = F3 * B
J4 = F4 * R
J5 = F5 * G
J6 = F6 * B

Or
Guess 2:
J1 = F1 * R
J2 = F2 * R
J3 = F3 * G
J4 = F4 * G
J5 = F5 * B
J6 = F6 * B


Let’s create a small example and have a look at the order:

import torch
from torch import nn

# Set R=100, G=200, B=300 (shape: N=1, C=3, H=1, W=1)
x = torch.tensor([100., 200., 300.]).view(1, -1, 1, 1)

conv = nn.Conv2d(in_channels=3,
                 out_channels=6,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 groups=3,
                 bias=False)

# Set the conv weights (shape (6, 1, 1, 1)) to [0, 1, 2, 3, 4, 5]
with torch.no_grad():
    conv.weight.copy_(torch.arange(6, dtype=torch.float32).view(-1, 1, 1, 1))

output = conv(x)  # shape (1, 6, 1, 1)
print(output.detach().flatten().tolist())
> [0.0, 100.0, 400.0, 600.0, 1200.0, 1500.0]

Only the second guess produces this result:

J1 = F1 * R = 0 * 100 = 0
J2 = F2 * R = 1 * 100 = 100
J3 = F3 * G = 2 * 200 = 400
J4 = F4 * G = 3 * 200 = 600
J5 = F5 * B = 4 * 300 = 1200
J6 = F6 * B = 5 * 300 = 1500
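With this ordering known, the original goal can be sketched end to end: take a standard `(2, 3, k, k)` weight, rearrange it into the `(6, 1, k, k)` layout that `groups=3` expects, get the six separated results (each filter applied to a single color channel), and sum them back into J1 and J2. The permute/reshape rearrangement is an illustration written for this example, not a built-in PyTorch helper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 5, 5)  # R, G, B
w = torch.randn(2, 3, 3, 3)  # w[o, c]: kernel for output o and input channel c

# Grouped layout: group c (= input channel c) owns output channels 2c and 2c+1,
# so the grouped weight at index 2c + o must be w[o, c].
w_grouped = w.permute(1, 0, 2, 3).reshape(6, 1, 3, 3)

# Six separated results, in the order found above:
# [w[0,0]*R, w[1,0]*R, w[0,1]*G, w[1,1]*G, w[0,2]*B, w[1,2]*B]
parts = F.conv2d(x, w_grouped, groups=3)
print(parts.shape)  # torch.Size([1, 6, 3, 3])

# (Each separated result could be manipulated here before summing.)

# Summing over the 3 groups recovers the ordinary convolution (J1, J2).
n, _, h, w_out = parts.shape
J = parts.view(n, 3, 2, h, w_out).sum(dim=1)
print(torch.allclose(J, F.conv2d(x, w), atol=1e-5))  # True
```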


Thanks so much, this is super helpful!

I tried the following code:

import torch
from torch import nn

# Two stacked RGB inputs: (R, G, B) = (0, 100, 200) and (300, 400, 500)
x = (torch.arange(6, dtype=torch.double)*100).view(1, -1, 1, 1)

conv = nn.Conv2d(in_channels=6,
                 out_channels=2,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 groups=2,
                 bias=False)

# Set all conv weights to 1; shape is (out_channels, in_channels // groups, 1, 1) = (2, 3, 1, 1)
conv.weight.data = torch.ones(6, dtype=torch.double).view(2, 3, 1, 1)

output = conv(x)

print(output)

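output = (300, 1200): the first output sums input channels 0-2 (0 + 100 + 200) and the second sums channels 3-5 (300 + 400 + 500). So each group takes a contiguous block of channels, consistent with the ordering in your Guess 2 (RR GG BB).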
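The same contiguous rule can be checked on the input side with random weights: with `in_channels=6, groups=2`, group 0 should see input channels 0-2 and group 1 channels 3-5. This sketch (assumed shapes, not from the thread) runs the two halves as separate convolutions and compares:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 6, 4, 4)  # two stacked RGB images: (R, G, B, R, G, B)
conv = nn.Conv2d(6, 4, kernel_size=3, groups=2, bias=False)
# weight shape: (out_channels, in_channels // groups, 3, 3) = (4, 3, 3, 3)

out = conv(x)

# Group g uses input channels [3g, 3g + 3) and output channels [2g, 2g + 2).
w = conv.weight.detach()
half0 = F.conv2d(x[:, 0:3], w[0:2])
half1 = F.conv2d(x[:, 3:6], w[2:4])
print(torch.allclose(out, torch.cat([half0, half1], dim=1), atol=1e-5))  # True
```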