How to sum every k channels for a CNN feature map

Hey guys, I want to sum every k consecutive channels of a variable together. Assume the input has shape NxCxWxH before summing, so the output after the summation should be Nx(C/k)xWxH. I implemented this with the code below. The function sum_up is called in the network's forward function, but it seems to be extremely inefficient. Is there a better way to implement this in PyTorch?

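import torch

def sum_up(input, bottle_neck=2):
    output = []
    # sum each group of `bottle_neck` consecutive channels, one slice at a time
    for ch in range(0, input.shape[1], bottle_neck):
        output.append(torch.sum(input[:, ch:ch+bottle_neck, :, :], 1))
    # stack puts the groups on dim 0, so permute back to N x (C/k) x W x H
    return torch.stack(output).permute(1, 0, 2, 3).cuda()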

You can split the channel dimension into (c//k, k) with a view/reshape and sum over the new axis of size k:

n,c,h,w = input.size()
res = input.reshape(n, c//k, k, h, w).sum(2)
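
As a quick sanity check, here is a minimal sketch comparing the reshape-and-sum approach with a slice-by-slice loop. The function names `sum_every_k` and `sum_every_k_loop` are made up for this illustration, and the sketch assumes the channel count is divisible by k.

```python
import torch

def sum_every_k(x, k):
    # Split the channel axis into (c // k, k) and sum over the size-k axis.
    n, c, h, w = x.shape
    assert c % k == 0, "channel count must be divisible by k"
    return x.reshape(n, c // k, k, h, w).sum(2)

def sum_every_k_loop(x, k):
    # Reference version: sum each group of k consecutive channels in a loop.
    chunks = [x[:, ch:ch + k].sum(1) for ch in range(0, x.shape[1], k)]
    return torch.stack(chunks, dim=1)

# Small integer-valued input so both versions can be compared exactly.
x = torch.arange(2 * 6 * 3 * 4, dtype=torch.float32).reshape(2, 6, 3, 4)
out = sum_every_k(x, 2)
print(out.shape)
print(torch.equal(out, sum_every_k_loop(x, 2)))
```

The reshape version runs a single reduction instead of one slice-and-sum per group, which is why it avoids the overhead of the Python loop.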

Best regards

Thomas

Thanks! This works for me!

Btw, if you are trying to do group sum, it will be available in the next release soon.

Could you also suggest an efficient way of adding channels in an interleaved fashion? For example, if there are 6 channels [0-5], I want to add the 0th and 3rd, the 1st and 4th, and the 2nd and 5th.

I have not tried it, but I have a vague idea.

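half = No_of_channels // 2
new_channels = []
for i in range(half):
    new_channels.append(old_feature_map[:, i] + old_feature_map[:, i + half])
# now stack the new channels one after another along dim 1 to get a new feature map.
new_feature_map = torch.stack(new_channels, dim=1)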


I have implemented this, but it seems to be inefficient because it uses a for loop. I am actually looking for a more efficient way to do that. Anyway, thanks for your answer.

This should work with strided slicing like this (haven't tested it yet):

new_tensor = tensor[:, ::2] + tensor[:, 1::2]

Edit: sorry, you would of course have to specify the middle channel like this:

middle_channel = tensor.size(1)//2
new_tensor = tensor[:, :middle_channel] + tensor[:, middle_channel:] 

Sorry for the first part, I misunderstood your goal.
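
To make the difference between the two snippets concrete, here is a small sketch on a toy tensor with 6 channels, where channel i simply holds the value i. The toy input and the name `adjacent` are assumptions for illustration only.

```python
import torch

# 1 sample, 6 channels, 1x1 spatial size; channel i holds the value i.
tensor = torch.arange(6, dtype=torch.float32).reshape(1, 6, 1, 1)

# Half-split: adds channel i to channel i + C/2 (0+3, 1+4, 2+5).
middle_channel = tensor.size(1) // 2
new_tensor = tensor[:, :middle_channel] + tensor[:, middle_channel:]

# Strided slicing: adds adjacent pairs instead (0+1, 2+3, 4+5).
adjacent = tensor[:, ::2] + tensor[:, 1::2]

print(new_tensor.flatten().tolist())  # [3.0, 5.0, 7.0]
print(adjacent.flatten().tolist())    # [1.0, 5.0, 9.0]
```

Both versions are pure slicing plus one addition, so neither needs a Python loop over channels.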

Thanks @justusschock for the answer.

Did you try it? Does it work?

Not yet. Will try it out tomorrow.

Well, you can extend the trick above, if you have c channels, c divisible by k*l:

n,c,h,w = input.size()
res = input.reshape(n, c//k//l, k, l, h, w).sum(2).view(n, c//k, h, w)

This will sum k channels from an interleaving grid of spacing l.
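
Here is a minimal sketch of this interleaved sum on a toy input, assuming the intended reshape splits the channel axis into (c // (k*l), k, l) so that summing over the size-k axis adds channels that are l apart. The function name `interleaved_sum` and the 12-channel toy tensor are made up for this example.

```python
import torch

def interleaved_sum(x, k, l):
    # Channel index decomposes as block * (k * l) + j * l + offset,
    # so summing over j adds k channels spaced l apart within each block.
    n, c, h, w = x.shape
    assert c % (k * l) == 0, "channel count must be divisible by k * l"
    grouped = x.reshape(n, c // (k * l), k, l, h, w).sum(2)
    return grouped.reshape(n, c // k, h, w)

# 12 channels, channel i holds the value i; k=2 channels summed at spacing l=3.
x = torch.arange(12, dtype=torch.float32).reshape(1, 12, 1, 1)
out = interleaved_sum(x, k=2, l=3)
print(out.flatten().tolist())  # [3.0, 5.0, 7.0, 15.0, 17.0, 19.0]
```

With 6 channels, k=2 and l=3 this reduces to the 0+3, 1+4, 2+5 case asked about earlier in the thread.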

Best regards

Thomas

Thanks for the answer @tom

Has group sum been released yet? Thanks
