Convolution that only performs a channel-wise summation?

Thanks for the detailed explanation of the procedure.
It’s possible, and I created a small example (you could probably still optimize it a bit):

import torch

batch_size = 1
channels = 5
h, w = 12, 12
image = torch.randn(batch_size, channels, h, w) # input image

kh, kw = 3, 3 # kernel size
dh, dw = 3, 3 # stride

filt = torch.randn(channels, kh, kw) # filter (create this as nn.Parameter if you want to train it)

patches = image.unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape)
> torch.Size([1, 5, 4, 4, 3, 3]) # batch_size, channels, h_windows, w_windows, kh, kw
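# each unfold creates (size - kernel) // stride + 1 windows per spatial dim, i.e. (12 - 3) // 3 + 1 = 4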

patches = patches.contiguous().view(batch_size, channels, -1, kh, kw)
print(patches.shape) 
> torch.Size([1, 5, 16, 3, 3]) # batch_size, channels, windows, kh, kw

# Now we have to shift the windows into the batch dimension.
# Maybe there is another way without .permute, but this should work
patches = patches.permute(0, 2, 1, 3, 4)
# .contiguous() is needed before .view here, since the permuted tensor is no longer contiguous (otherwise .view fails for batch_size > 1)
patches = patches.contiguous().view(-1, channels, kh, kw)
print(patches.shape)
> torch.Size([16, 5, 3, 3]) # batch_size * windows, channels, kh, kw

# Now we can use our filter and sum over the channels
patches = patches * filt
patches = patches.sum(1)
print(patches.shape)
> torch.Size([16, 3, 3]) # batch_size * windows, kh, kw
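If you need the windows grouped per sample again, a final view should restore the spatial window layout (just a quick sketch, assuming you want the result as [batch_size, h_windows, w_windows, kh, kw]):

h_windows = (h - kh) // dh + 1 # 4
w_windows = (w - kw) // dw + 1 # 4
out = patches.view(batch_size, h_windows, w_windows, kh, kw)
print(out.shape)
> torch.Size([1, 4, 4, 3, 3]) # batch_size, h_windows, w_windows, kh, kw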

I think this should do it.
The shapes are a bit complicated, so I tried to comment all the steps.
Let me know if that fits your use case.
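
In case you want to train the filter as mentioned above, you could wrap the same steps in a small module with the filter registered as an nn.Parameter (just a rough sketch; the module name and argument layout are my own choice, not anything fixed):

import torch
import torch.nn as nn

class ChannelSumConv(nn.Module):
    # hypothetical module: multiplies each window with the filter and sums over the channels only
    def __init__(self, channels, kernel_size=(3, 3), stride=(3, 3)):
        super().__init__()
        self.kh, self.kw = kernel_size
        self.dh, self.dw = stride
        self.filt = nn.Parameter(torch.randn(channels, self.kh, self.kw))

    def forward(self, x):
        batch_size, channels, h, w = x.shape
        patches = x.unfold(2, self.kh, self.dh).unfold(3, self.kw, self.dw)
        patches = patches.contiguous().view(batch_size, channels, -1, self.kh, self.kw)
        patches = patches.permute(0, 2, 1, 3, 4).contiguous()
        patches = patches.view(-1, channels, self.kh, self.kw)
        return (patches * self.filt).sum(1) # [batch_size * windows, kh, kw]

module = ChannelSumConv(channels=5)
out = module(torch.randn(1, 5, 12, 12))
print(out.shape)
> torch.Size([16, 3, 3])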
