# How to sum every k channels for a CNN feature map

Hey guys, I want to sum every k consecutive channels of a variable together. Assume the input has shape NxCxWxH before summing; the output after summing should then be Nx(C/k)xWxH. I implemented this with the code below. The function `sum_up` is called in the network's `forward` function, but it seems to be extremely inefficient. Is there a better way to implement this in PyTorch?

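```python
import torch

def sum_up(input, bottle_neck=2):
    output = []
    # sum each block of `bottle_neck` consecutive channels (dim 1)
    for ch in range(0, input.size(1), bottle_neck):
        output.append(torch.sum(input[:, ch:ch + bottle_neck, :, :], 1))
    # stack the per-group sums back along the channel dimension
    return torch.stack(output, 1)
```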

You can split the channel dimension with a view/reshape and then sum over the new axis (this assumes `c` is divisible by `k`):

```python
n, c, h, w = input.size()
res = input.reshape(n, c // k, k, h, w).sum(2)
```
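As a quick sanity check, here is a minimal sketch that compares the reshape approach with a plain loop over channel groups. The helper names `sum_up_loop` and `sum_up_reshape` are hypothetical, and the sketch assumes a PyTorch version that has `Tensor.reshape`:

```python
import torch


def sum_up_loop(x: torch.Tensor, k: int) -> torch.Tensor:
    # Reference version: sum each block of k consecutive channels in a Python loop.
    groups = [x[:, ch:ch + k].sum(dim=1) for ch in range(0, x.size(1), k)]
    return torch.stack(groups, dim=1)


def sum_up_reshape(x: torch.Tensor, k: int) -> torch.Tensor:
    # Vectorized version: split C into (C // k, k) and sum over the k axis.
    n, c, h, w = x.shape
    assert c % k == 0, "channel count must be divisible by k"
    return x.reshape(n, c // k, k, h, w).sum(2)


x = torch.randn(2, 6, 4, 5)
out = sum_up_reshape(x, k=2)
print(out.shape)  # torch.Size([2, 3, 4, 5])
print(torch.allclose(out, sum_up_loop(x, k=2)))  # True
```

The reshape does not copy data for a contiguous input; it only reinterprets the channel axis, so the whole reduction is a single `sum` kernel instead of one small kernel per group.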

Best regards

Thomas

Thanks! This works for me!

Btw, if you are trying to do a group sum, it will be available soon in the next release.

Could you also suggest an efficient way to add channels in an interleaved fashion? For example, if there are 6 channels [0-5], I want to add the 0th and 3rd, the 1st and 4th, and the 2nd and 5th.

I have not tried it, but I have a vague idea.

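```python
import torch

half = old_feature_map.size(1) // 2  # C is assumed to be even
new_channels = []
for i in range(half):
    # add channel i and channel i + C/2 (0+3, 1+4, 2+5 for C=6)
    new_channels.append(old_feature_map[:, i] + old_feature_map[:, i + half])
# stack the new channels one after another to get the new feature map
new_feature_map = torch.stack(new_channels, dim=1)
```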

I have implemented this, but it seems inefficient because it uses a Python for loop. I am looking for a more efficient way to do it. Anyway, thanks for your answer.

This should work with strided slicing like this (I haven't tested it yet):

```python
new_tensor = tensor[:, ::2] + tensor[:, 1::2]
```
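For reference, a small sketch showing what `tensor[:, ::2] + tensor[:, 1::2]` actually pairs up. The toy tensor `t` is illustrative only: each channel holds its own index, so the sums reveal which channels were added:

```python
import torch

# Shape (1, 6, 1, 1): channel i holds the value i.
t = torch.arange(6.0).reshape(1, 6, 1, 1)
adjacent = t[:, ::2] + t[:, 1::2]
print(adjacent.flatten().tolist())  # [1.0, 5.0, 9.0] -> channels 0+1, 2+3, 4+5
```

So this form adds neighbouring channels, not channels half the tensor apart.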

Edit: sorry, you would of course have to specify the middle channel like this:

```python
middle_channel = tensor.size(1) // 2
new_tensor = tensor[:, :middle_channel] + tensor[:, middle_channel:]
```
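A minimal check of the half-split version on a 6-channel toy tensor (illustrative only, channel i holds the value i), together with an equivalent reshape form that views C as `(2, C // 2)` and sums over the first of those axes:

```python
import torch

# Shape (1, 6, 1, 1): channel i holds the value i.
t = torch.arange(6.0).reshape(1, 6, 1, 1)
middle_channel = t.size(1) // 2
half_sum = t[:, :middle_channel] + t[:, middle_channel:]
print(half_sum.flatten().tolist())  # [3.0, 5.0, 7.0] -> channels 0+3, 1+4, 2+5

# Equivalent reshape form: view C as (2, C // 2) and sum over the "2" axis.
n, c, h, w = t.shape
half_sum_reshape = t.reshape(n, 2, c // 2, h, w).sum(1)
print(torch.equal(half_sum, half_sum_reshape))  # True
```

Both avoid the Python loop; the reshape form generalizes more easily when more than two channels should be added.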

Sorry for the first part, I misunderstood your goal.

Did you try it? Does it work?

Not yet. Will try it out tomorrow.

Well, you can extend the trick above: if you have `c` channels, with `c` divisible by `k*l`:

```python
n, c, h, w = input.size()
res = input.reshape(n, c // k // l, k, l, h, w).sum(2).view(n, c // k, h, w)
```

This sums k channels taken from an interleaved grid with spacing l. For your 6-channel example, k=2 and l=3 gives channels 0+3, 1+4, and 2+5.
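Written out as a function, the reshape uses the shape `(n, c // (k*l), k, l, h, w)` and sums over the `k` axis. Below is a sketch with a slow index-based reference for comparison; both function names are hypothetical, and `c` is assumed divisible by `k * l`:

```python
import torch


def interleaved_sum(x: torch.Tensor, k: int, l: int) -> torch.Tensor:
    # Sum groups of k channels spaced l apart; C must be divisible by k * l.
    n, c, h, w = x.shape
    assert c % (k * l) == 0, "channel count must be divisible by k * l"
    # Channel index = a * (k * l) + b * l + d; summing over b (dim 2) adds channels l apart.
    return x.reshape(n, c // (k * l), k, l, h, w).sum(2).reshape(n, c // k, h, w)


def interleaved_sum_loop(x: torch.Tensor, k: int, l: int) -> torch.Tensor:
    # Slow reference built from explicit channel indices, in the same output order.
    c = x.size(1)
    outs = []
    for base in range(0, c, k * l):
        for off in range(l):
            idx = [base + off + i * l for i in range(k)]
            outs.append(x[:, idx].sum(dim=1))
    return torch.stack(outs, dim=1)


t = torch.arange(6.0).reshape(1, 6, 1, 1)  # channel i holds the value i
print(interleaved_sum(t, k=2, l=3).flatten().tolist())  # [3.0, 5.0, 7.0]

x = torch.randn(2, 12, 3, 3)
print(torch.allclose(interleaved_sum(x, 2, 3), interleaved_sum_loop(x, 2, 3)))  # True
```

The final `.reshape` could equally be `.view`, because the result of `sum` is contiguous.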

Best regards

Thomas
