Convolution by groups of channels

In my work, I need to split the filter and the input feature map into groups of 4 along the channel depth and run the convolution on each group individually.
I want to ask how I can make this procedure faster (maybe get rid of the "for loop").
I have tried

  1. depth-wise convolution, but I cannot reshape the weight into the correct order to produce the correct result.
  2. torch.multiprocessing, but it seems to be applicable only to tensors without gradients?

My question is how to get every temp_output faster. Thanks!
My code looks like this:

def forward(self, input, weight, padding=0, stride=1):
    number = 4                               # channels per group
    number_group = weight.size(1) // number  # number of channel groups
    temp_output_group = [None] * number_group
    for i in range(number_group):
        channel_start = i * number
        channel_end = (i + 1) * number
        # convolve only this sub-group of input channels with the matching filter channels
        temp_output = F.conv2d(input[:, channel_start:channel_end], weight[:, channel_start:channel_end], padding=padding, stride=stride)
        temp_output_group[i] = temp_output
    # do something with temp_output_group
    return ...

Can you take a look at the groups param of F.conv2d() and see if it helps your scenario?
Perhaps you could call F.conv2d(input, weight, padding=padding, stride=stride, groups=4)?
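
For reference, a minimal sketch of how the groups parameter behaves (the sizes here are made up for illustration, not taken from your code): with groups=4, the weight must have shape (out_channels, in_channels // 4, kh, kw), and each block of out_channels // 4 output channels only sees its own 4 input channels.

import torch
import torch.nn.functional as F

# Illustrative sizes only: in_channels = 16 split into 4 groups of 4 channels.
x = torch.randn(1, 16, 8, 8)              # (N, in_channels, H, W)
w = torch.randn(8, 4, 3, 3)               # (out_channels, in_channels // groups, kh, kw)
y = F.conv2d(x, w, padding=1, groups=4)   # each block of 2 output channels uses only its 4 input channels
print(y.shape)                            # torch.Size([1, 8, 8, 8])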

Thanks for your reply.
I have tried the groups parameter, and it is a good direction.
But it does not give me the correct output, because reshaping the weight does not produce the weight order this scenario needs.
Maybe I need to define my own reshape function to support it.

Would you be able to provide a small script showing how you verified it?

import torch
from torch import nn
import torch.nn.functional as F
input = torch.ones((2,4,2,2))
weight = torch.ones((2,4,2,2))
input[0][1] = input[0][1]*2
input[1][0] = input[1][0]*3
input[1][1] = input[1][1]*4
weight[0][1] = weight[0][1]*2
weight[1][0] = weight[1][0]*3
weight[1][1] = weight[1][1]*4
weight = weight.reshape((4,1,2,2)) 
result = F.conv2d(input, weight, padding=1, stride=1, groups=2)
print(input)
print(weight)
print(result)
print(result.shape)

I just wrote this to verify that the multiplied numbers are not what I want.
A plain reshape splits every filter while keeping the original filter order,
but I need the weights reshaped so that the same channel depths stay together within each sub-group.
Sorry, maybe I cannot explain it clearly.

Please refer to the documentation of conv2d's groups parameter here.
According to it, the weight parameter should have the dimensions out_channels x (in_channels / groups) x h x w.

In your example above,

input size = 2 x 4 x 2 x 2 (in_channels = 4)
weight size = 2 x 4 x 2 x 2 (out_channels = 2, in_channels / groups = 4)

With in_channels=4 and groups=2, the weight should have size out_channels x 2 x h x w (since in_channels / groups = 4 / 2 = 2).

import torch
from torch import nn
import torch.nn.functional as F

input = torch.ones((2,4,2,2))
weight = torch.ones((2,2,2,2))

input[0][1] = input[0][1]*2
input[1][0] = input[1][0]*3
input[1][1] = input[1][1]*4

weight[0][1] = weight[0][1]*2
weight[1][0] = weight[1][0]*3
weight[1][1] = weight[1][1]*4

result = F.conv2d(input, weight, padding=1, stride=1, groups=2)

print(input)
print(weight)
print(result)
print(result.shape)

import torch
from torch import nn
import torch.nn.functional as F
input = torch.ones((2,4,2,2))
weight = torch.ones((2,4,2,2))
input[0][1] = input[0][1]*2
input[1][0] = input[1][0]*3
input[1][1] = input[1][1]*4
weight[0][1] = weight[0][1]*2
weight[1][0] = weight[1][0]*3
weight[1][1] = weight[1][1]*4
weight = weight.reshape((4,2,2,2))  # this is the corrected reshape
result = F.conv2d(input, weight, padding=1, stride=1, groups=2)
print(input)
print(weight)
print(result)
print(result.shape)

Sorry, I wrote the code wrongly.
I mean it can run, but it cannot output the result I want.
After reshaping the weights, they are not in the order I need for this scenario.
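
To make the ordering concrete, here is a toy illustration (my own, with an arange weight): after the plain reshape, row 1 of the grouped weight is filter 0's channels 2:4, but with groups=2 that row sits in group 0 and is convolved with input channels 0:2, which is not the pairing I want.

import torch

weight = torch.arange(2 * 4 * 2 * 2, dtype=torch.float32).reshape(2, 4, 2, 2)
reshaped = weight.reshape(4, 2, 2, 2)
# Row 1 is filter 0, channels 2:4, yet with groups=2 it falls into group 0
# and gets paired with input channels 0:2.
assert torch.equal(reshaped[1], weight[0, 2:4])
# Row 2 is filter 1, channels 0:2, yet it falls into group 1 (input channels 2:4).
assert torch.equal(reshaped[2], weight[1, 0:2])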

I am not sure what the expected output is in your case.
Either way, if you want to take advantage of the parallelization in F.conv2d, you just have to organize your weights accordingly (with an appropriate reshape), I guess.
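
Here is a rough sketch of one way to organize the weights (my own arrangement, not tested against your real use case): view the weight's channel dimension as (groups, group_size), move the group axis in front of the output-channel axis, and only then flatten. That transpose is the step a plain reshape misses; with it, a single grouped conv reproduces the per-group loop from your first post.

import torch
import torch.nn.functional as F

N, C, H, W = 2, 8, 5, 5        # illustrative sizes
O, kh, kw = 3, 3, 3
group_size = 4
G = C // group_size

inp = torch.randn(N, C, H, W)
weight = torch.randn(O, C, kh, kw)

# Loop version from the original post: one small conv per channel sub-group.
loop_outputs = [
    F.conv2d(inp[:, g * group_size:(g + 1) * group_size],
             weight[:, g * group_size:(g + 1) * group_size],
             padding=1)
    for g in range(G)
]

# Grouped version: view the channel dim as (G, group_size), swap the group axis
# with the output-channel axis, then flatten to (G * O, group_size, kh, kw).
w_grouped = (weight.view(O, G, group_size, kh, kw)
                   .transpose(0, 1)
                   .reshape(G * O, group_size, kh, kw))
out = F.conv2d(inp, w_grouped, padding=1, groups=G)   # (N, G * O, H', W')

# Output channels g*O:(g+1)*O should match the g-th loop result.
for g in range(G):
    assert torch.allclose(out[:, g * O:(g + 1) * O], loop_outputs[g], atol=1e-5)
print("grouped conv matches the per-group loop")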

Thanks for your patient responses.