Convolution that only takes channel-wise summation?

I basically want to do an element-wise product between a filter and the feature map, but only sum channel-wise. That is, I have a k*k*c filter; for each sliding window, summing only over the channels gives a k*k map. The total result is a k*k*n feature map, where n is the number of sliding windows.

Can I do this with the current PyTorch package, or do I need to write my own “Convolution” layer?

Thanks!


Have a look at the groups argument of nn.Conv2d.
If you want to use a separate set of filters for each input channel, you can set groups=in_channels.
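To make the `groups` suggestion concrete, here is a minimal sketch (the sizes are arbitrary). With `groups=in_channels`, each input channel gets its own k x k filter and there is no summation across channels; each filter still sums over its spatial window, so the output keeps the input's spatial layout rather than the filter's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels = 5
# groups=in_channels: one 3x3 filter per input channel, no cross-channel sum
conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, groups=in_channels, bias=False)
print(conv.weight.shape)  # torch.Size([5, 1, 3, 3])

x = torch.randn(1, in_channels, 12, 12)
out = conv(x)
print(out.shape)  # torch.Size([1, 5, 10, 10])

# output channel 0 depends only on input channel 0
first = F.conv2d(x[:, :1], conv.weight[:1])
print(torch.allclose(out[:, :1], first, atol=1e-5))
```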

Not exactly… I want the resulting feature map to be the same spatial size as the filter, not the input. And I only have ONE (k*k*c) filter.

Say we have a (k*k*c) filter. For each sliding window, we first take the element-wise product between the filter and a (k*k*c) region of the input feature map.

Then we sum only over the channels, which results in a (k*k) matrix (whereas a real convolution sums over both the spatial and the channel dimensions, resulting in a single value).

Next, we move the filter to the next (k*k*c) region in the input feature map, getting another (k*k) matrix.

This procedure goes on till the sliding window covers all the input feature map. We concatenate all the resulting (k*k) matrices to be a (k*k*N) tensor, where N is the number of sliding windows.
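As a reference for the procedure above, a naive loop version might look like this. The helper name `channelwise_window_product` is hypothetical, the input is assumed to be a single unbatched (c, H, W) tensor, and the maps are stacked as (N, k, k) rather than k*k*N:

```python
import torch


def channelwise_window_product(x, filt, stride):
    """Hypothetical reference helper.

    x:    (c, H, W) input feature map
    filt: (c, k, k) single filter
    Returns (N, k, k): out[n, i, j] = sum_c filt[c, i, j] * window_n[c, i, j]
    """
    c, H, W = x.shape
    k = filt.shape[-1]
    maps = []
    for top in range(0, H - k + 1, stride):
        for left in range(0, W - k + 1, stride):
            region = x[:, top:top + k, left:left + k]   # (c, k, k) window
            maps.append((region * filt).sum(dim=0))     # product, then sum over channels only
    return torch.stack(maps)


x = torch.randn(5, 12, 12)
filt = torch.randn(5, 3, 3)
out = channelwise_window_product(x, filt, stride=3)
print(out.shape)  # torch.Size([16, 3, 3])
```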

Is this doable in current PyTorch framework?

Thank you!

Thanks for the detailed explanation of the procedure.
It’s possible and I created a small example (probably you could still optimize something):

```python
import torch

batch_size = 1
channels = 5
h, w = 12, 12
image = torch.randn(batch_size, channels, h, w)  # input image

kh, kw = 3, 3  # kernel size
dh, dw = 3, 3  # stride

filt = torch.randn(channels, kh, kw)  # filter (create this as nn.Parameter if you want to train it)

patches = image.unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape)
# torch.Size([1, 5, 4, 4, 3, 3])  -> batch_size, channels, h_windows, w_windows, kh, kw

patches = patches.contiguous().view(batch_size, channels, -1, kh, kw)
print(patches.shape)
# torch.Size([1, 5, 16, 3, 3])  -> batch_size, channels, windows, kh, kw

# Now we have to shift the windows into the batch dimension.
# Maybe there is another way without .permute, but this should work.
# Use .reshape instead of .view: the permuted tensor is not contiguous,
# and .view can fail for batch_size > 1.
patches = patches.permute(0, 2, 1, 3, 4)
patches = patches.reshape(-1, channels, kh, kw)
print(patches.shape)
# torch.Size([16, 5, 3, 3])  -> batch_size * windows, channels, kh, kw

# Now we can use our filter and sum over the channels
patches = patches * filt
patches = patches.sum(1)
print(patches.shape)
# torch.Size([16, 3, 3])  -> batch_size * windows, kh, kw
```

I think this should do it.
The shapes are a bit complicated, so I tried to comment all steps.
Let me know if that’s your use case.
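An equivalent, more compact variant of the steps above (a sketch, assuming the same kernel/stride setup but with batch_size=2 to exercise the batch dimension) fuses the multiplication and the channel sum into a single `torch.einsum` call on the unfolded tensor. The window order matches the permute/reshape version:

```python
import torch

batch_size, channels, h, w = 2, 5, 12, 12
kh, kw = 3, 3  # kernel size
dh, dw = 3, 3  # stride

image = torch.randn(batch_size, channels, h, w)
filt = torch.randn(channels, kh, kw)

# (batch, channels, h_windows, w_windows, kh, kw)
patches = image.unfold(2, kh, dh).unfold(3, kw, dw)
# element-wise product with the filter and sum over channels (c) in one step
out = torch.einsum('bcijkl,ckl->bijkl', patches, filt)
# flatten batch and window grid: (batch_size * windows, kh, kw)
out = out.reshape(-1, kh, kw)
print(out.shape)  # torch.Size([32, 3, 3])
```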


Sorry for bumping this old thread, but can you explain how this is done…

Also, does implementing this whole thing in the forward function of an nn.Module class ensure proper gradient flow and proper weight updates?

You could wrap the filt in an nn.Parameter, which would create gradients in it after a backward() call:

```python
import torch
import torch.nn as nn

filt = nn.Parameter(torch.randn(channels, kh, kw))

patches = image.unfold(...)
...

patches.sum().backward()
print(filt.grad)
```

This should be the case. The small example shows that the computation graph is properly created.
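A small self-contained check of this, using the same unfold steps as the example earlier in the thread (the sizes are assumptions). Because `out.sum()` is linear in `filt`, its gradient should equal the sum of all windows, which gives a simple way to verify the graph:

```python
import torch
import torch.nn as nn

channels, kh, kw = 5, 3, 3
dh, dw = 3, 3  # stride
image = torch.randn(1, channels, 12, 12)
filt = nn.Parameter(torch.randn(channels, kh, kw))

patches = image.unfold(2, kh, dh).unfold(3, kw, dw)
patches = patches.reshape(1, channels, -1, kh, kw).permute(0, 2, 1, 3, 4)
patches = patches.reshape(-1, channels, kh, kw)  # (windows, channels, kh, kw)
out = (patches * filt).sum(1)                      # (windows, kh, kw)

out.sum().backward()
print(filt.grad.shape)  # torch.Size([5, 3, 3])

# d(out.sum()) / d(filt) is the sum of all windows
expected = patches.sum(0)
print(torch.allclose(filt.grad, expected, atol=1e-5))
```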


Just for the sake of reassurance: this would generate random filters which are different for different channels. That is not what happens in nn.Conv2d, but it is done here for demonstration?

Hey @ptrblck, just for the sake of knowledge: performing a 1x1 convolution (Network in Network) over a feature map would do a channel-wise summation, right? If not, could you please elaborate?

nn.Conv2d initializes the weight randomly in the shape [out_channels, in_channels, height, width].
In my example filt is a single filter, so out_channels would correspond to 1.

Non-grouped convolutions all perform a channel-wise summation as the last step.
The 1x1 kernel would multiply each channel “pixel” with the corresponding weight and apply the summation afterwards.
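To illustrate, a 1x1 convolution with a single output channel is a weighted sum over the input channels at every pixel (the sizes below are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 4, 8, 8)  # (batch, in_channels, h, w)
conv1x1 = nn.Conv2d(4, 1, kernel_size=1, bias=False)
out = conv1x1(x)              # (2, 1, 8, 8)

# same result by hand: weight of shape (1, 4, 1, 1) broadcasts over the pixels,
# each channel "pixel" is multiplied by its weight, then summed over channels
manual = (x * conv1x1.weight).sum(dim=1, keepdim=True)
print(torch.allclose(out, manual, atol=1e-5))
```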


Initializing a convolution with out_channels=1 in that way would mean that random values are set up per in_channel row. In other words, we would see that different randomized filters are being convolved with different input channels.

Filter A is used on channel A, filter B is used on channel B, and all the results are summed up. On the other hand, having out_channels as 1 means that we are using a single filter for all channels and adding the results up channel-wise.
If I use:

```python
import torch

out_channels=1
in_channels=2
height=3
width=3
print(torch.rand([out_channels,in_channels,height,width]))
```

This gives me:

```
tensor([[[[0.8728, 0.5557, 0.4790],
          [0.8028, 0.7029, 0.3034],
          [0.6854, 0.5710, 0.0458]],

         [[0.7219, 0.7814, 0.6355],
          [0.2254, 0.4197, 0.9144],
          [0.7522, 0.8480, 0.2152]]]])
```

As you can see, two filters are generated, and they will be used on the two channels. But as out_channels=1, we should have a single filter which is convolved with the two channels, and the outputs of those convolutions are added to create a single filter.

This, however, seems to give the required filter:

```python
import torch

out_channels, in_channels, height, width = 1, 2, 3, 3
print(torch.rand([out_channels,height,width]).unsqueeze(1).repeat(1, in_channels, 1,1))
```

The output:

```
tensor([[[[0.4394, 0.8586, 0.6954],
          [0.2975, 0.1665, 0.0886],
          [0.5199, 0.8727, 0.7965]],

         [[0.4394, 0.8586, 0.6954],
          [0.2975, 0.1665, 0.0886],
          [0.5199, 0.8727, 0.7965]]]])
```
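A quick way to see what a weight built with `repeat` actually computes (a sketch with an arbitrary input size): when all channel slices of the kernel are identical, the convolution equals a single-channel convolution of the channel-summed input.

```python
import torch
import torch.nn.functional as F

out_channels, in_channels, height, width = 1, 2, 3, 3
single = torch.rand([out_channels, height, width])
repeated = single.unsqueeze(1).repeat(1, in_channels, 1, 1)  # (1, 2, 3, 3), identical slices

x = torch.randn(1, in_channels, 6, 6)
out = F.conv2d(x, repeated)

# identical slices: one 2D kernel applied to the sum of the input channels
manual = F.conv2d(x.sum(dim=1, keepdim=True), single.unsqueeze(1))
print(torch.allclose(out, manual, atol=1e-5))
```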

We might be using different terminology, since I see

```python
import torch

out_channels=1
in_channels=2
height=3
width=3
print(torch.rand([out_channels,in_channels,height,width]))
```

as a single filter with multiple input channels.
What you’ve described is how each channel of this single filter is interacting with the channels of the input activation. These are generally not considered to be different filters.
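To illustrate that view, a minimal sketch: a `[1, 2, 3, 3]` weight is one filter, and `F.conv2d` with it equals correlating each input channel with the matching channel slice of that filter and summing the per-channel results (input size is arbitrary):

```python
import torch
import torch.nn.functional as F

out_channels, in_channels, height, width = 1, 2, 3, 3
weight = torch.rand([out_channels, in_channels, height, width])  # ONE filter, 2 input channels
x = torch.randn(1, in_channels, 6, 6)

out = F.conv2d(x, weight)  # (1, 1, 4, 4)

# per-channel correlations with the matching slice of the same filter, then summed
per_channel = [F.conv2d(x[:, c:c + 1], weight[:, c:c + 1]) for c in range(in_channels)]
manual = torch.stack(per_channel).sum(dim=0)
print(torch.allclose(out, manual, atol=1e-5))
```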

CS231n - Convolution layer gives a good description of the kernel shapes and how they are used.


Okay, I’m terribly sorry. It seems I was gravely mistaken. Anyone walking down my path, please look here for an in-depth analysis. Thanks @ptrblck, I would probably have carried this mistake long enough to make more serious ones!

No need to be sorry and I’m glad we figured it out. :slight_smile:
If you’re coming from the “classical” computer vision domain, filters with multiple channels sound a bit weird at the beginning.

Yeah, exactly! I started out with image processing, where filters had a single channel, which is why I got thrown off and subconsciously started believing that the same single-channel filter is multiplied with all the input channels.

Hi, sorry to disturb you.
Overview: I want to apply the convolution operation only on a specific portion of the image (some ROI).
Summary: let’s say the original image is:
[image: original frame]
and the image with the unwanted area removed is:
[image: modified frame with the unwanted area removed]
I want to run the detection network on the modified image, with the convolution applied only on a specific part, to speed up the network. Can you help me with this?
Or could you point me to the source code of how PyTorch implements convolutions, so that I can write my own convolution operation?
Thank you

You could implement your own custom convolution using unfold (the docs give an example code snippet for it). To do so you could mask the operation, which would create your desired operation, but wouldn’t yield any speedup. I’m also unsure how you would speedup access and processing of such a “random” pattern.
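A sketch of the unfold-based convolution in the spirit of the `torch.nn.Unfold` docs example (shapes taken from a small arbitrary setup), followed by output masking with a hypothetical `roi` mask. As noted above, masking alone still computes every window:

```python
import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)           # (out_channels, in_channels, kh, kw)
oh, ow = 10 - 4 + 1, 12 - 5 + 1       # output size for stride 1, no padding

cols = F.unfold(inp, kernel_size=(4, 5))          # (1, 3*4*5, oh*ow): one column per window
out = (w.view(2, -1) @ cols).view(1, 2, oh, ow)   # (2, 60) @ (1, 60, 56) -> (1, 2, 56)
print(torch.allclose(out, F.conv2d(inp, w), atol=1e-4))

# hypothetical ROI over output positions; outside it the result is zeroed,
# but every window was still computed, so there is no speedup
roi = torch.zeros(oh, ow, dtype=torch.bool)
roi[2:5, 3:7] = True
masked_out = out * roi
```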

We get all the windows using unfold, right? So I will try to remove the windows whose center is not in the region of interest, so that the matrix multiplication uses a smaller matrix, which should lead to a speedup. If I am wrong, please correct me.
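A minimal sketch of that idea, assuming stride 1, no padding, and a boolean `roi` mask over output positions (in this setup, output position (i, j) is the window whose top-left corner is (i, j); shift the mask by (kh // 2, kw // 2) to select windows by their centers instead). The helper `roi_conv2d` is hypothetical. `F.unfold` still materializes every window here, so only the matmul shrinks; whether this is actually faster depends on the ROI size and the hardware:

```python
import torch
import torch.nn.functional as F


def roi_conv2d(inp, weight, roi):
    """Hypothetical helper: stride-1, unpadded conv evaluated only where roi is True."""
    n, c, h, w = inp.shape
    out_c, _, kh, kw = weight.shape
    oh, ow = h - kh + 1, w - kw + 1
    cols = F.unfold(inp, kernel_size=(kh, kw))      # (n, c*kh*kw, oh*ow)
    keep = roi.flatten().nonzero().squeeze(1)       # indices of windows inside the ROI
    cols = cols[:, :, keep]                         # (n, c*kh*kw, num_kept): smaller matrix
    vals = weight.view(out_c, -1) @ cols            # (n, out_c, num_kept)
    out = inp.new_zeros((n, out_c, oh * ow))        # positions outside the ROI stay 0
    out[:, :, keep] = vals
    return out.view(n, out_c, oh, ow)


inp = torch.randn(1, 3, 10, 12)
weight = torch.randn(2, 3, 4, 5)
roi = torch.zeros(7, 8, dtype=torch.bool)
roi[1:4, 2:6] = True
out = roi_conv2d(inp, weight, roi)
ref = F.conv2d(inp, weight) * roi                   # full conv, then masked
print(torch.allclose(out, ref, atol=1e-4))
```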

I have written the code below for a normal convolution from scratch, using previous discussions in this thread. I am thinking that if I can decrease the number of windows created, I will be able to decrease the running time of the matrix multiplication. Please correct me if I am wrong.

```python
import cv2
import matplotlib.pyplot as plt
# %matplotlib inline  (Jupyter magic; only valid inside a notebook cell)
import numpy as np
import torch
import torch.nn as nn
from torch.nn.grad import conv2d_input, conv2d_weight


class Myconv2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight)
        ctx.has_bias = bias is not None

        # required variables (single-channel input, 2D kernel of shape (kh, kw))
        batch_size, channels, h, w = input.shape
        kh, kw = weight.shape  # kernel size
        dh, dw = 1, 1  # stride
        ph, pw = 0, 0  # padding (not applied below, so keep it at 0)

        # create all windows, final size of patches: (batch_size * windows per image, channels * kh * kw)
        patches = input.unfold(2, kh, dh).unfold(3, kw, dw)  # use `input`, not the global `image`
        patches = patches.contiguous().view(batch_size, channels, -1, kh, kw)
        patches = patches.permute(0, 2, 1, 3, 4)
        patches = patches.reshape(-1, channels * kh * kw)

        # flatten the weight
        filters = weight.reshape(-1)

        # matrix multiplication: one dot product per window
        out = patches @ filters
        if bias is not None:
            out = out + bias

        # output size calculation (integer division instead of a C-style cast)
        oh = (h - kh + 2 * ph) // dh + 1
        ow = (w - kw + 2 * pw) // dw + 1

        # back to normal image view: (batch_size, 1 output channel, oh, ow)
        return out.reshape(batch_size, 1, oh, ow)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # The forward pass is a plain cross-correlation (like F.conv2d), so the
        # Linear-layer formulas (grad_output.mm(weight)) don't apply here;
        # use PyTorch's conv gradient helpers instead.
        w4d = weight.view(1, 1, *weight.shape)  # (out=1, in=1, kh, kw)
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = conv2d_input(input.shape, w4d, grad_output)
        if ctx.needs_input_grad[1]:
            grad_weight = conv2d_weight(input, w4d.shape, grad_output).view_as(weight)
        if ctx.has_bias and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        return grad_input, grad_weight, grad_bias


# model
class Myconv2D(torch.nn.Module):
    def __init__(self, input_features, output_features, bias=None):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        sobel = torch.tensor([[  1.0,   6,  15,  20,  15,   6,   1],
                              [  2,    12,  30,  40,  30,  12,   2],
                              [  3,    18,  45,  60,  45,  18,   3],
                              [  0,     0,   0,   0,   0,   0,   0],
                              [ -3,   -18, -45, -60, -45, -18,  -3],
                              [ -2,   -12, -30, -40, -30, -12,  -2],
                              [ -1,    -6, -15, -20, -15,  -6,  -1]])
        self.weight = nn.Parameter(sobel)
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features).uniform_(-0.1, 0.1))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        return Myconv2d.apply(input, self.weight, self.bias)


if __name__ == "__main__":
    path = "../input/motorroll-tracking/0001.jpg"
    img = cv2.imread(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    batch_size = 1
    channels = 1
    h, w = img.shape
    image = torch.from_numpy(img / 255.0).float()  # input image
    image = image.reshape(batch_size, channels, h, w)
    mymodel = Myconv2D(1, 1)
    output = mymodel(image)  # call the module instead of .forward() so hooks run
    for i in range(batch_size):
        plt.imshow(output[i, 0, :, :].detach().numpy(), cmap='gray')
        plt.show()
```

Yes, if you are able to reduce the matrix size used in the matmul, you would save computation and could thus speed it up. Depending on the actual memory layout and the logic used to remove parts of the matrix, you might be able to see a performance gain.

Hi, I am also trying something similar. Did this work for you? Were you able to boost the performance? If yes, could you please share the code here? It would be really helpful for my work. Thanks in advance.