How to multiply each channel by the sum of its pixel values?

I have a standard NxCxWxH feature map, and I want to multiply each channel by the sum of all pixel values in that channel. How can I write a function that does this?

If I understood correctly, you want to sum all the values within each channel. For an RGB image, that means one sum over all the R values, a second over all the G values, and a third over all the B values. Each channel is then multiplied by its own sum, so if the R values sum to 50, every value in the R channel is multiplied by 50.

If this is correct, then the following code should do what you want.

If that is not what you want, please let me know.

import torch

img = torch.rand(1, 3, 10, 10)

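def mult_by_channel_sum(img):
    # Sum over the two spatial dims (3, then 2), giving shape (N, C);
    # unsqueeze twice to get (N, C, 1, 1) so it broadcasts against img
    sum_per_channel = img.sum(dim=3).sum(dim=2).unsqueeze(2).unsqueeze(3)
    return img * sum_per_channel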

print(mult_by_channel_sum(img).shape)
# Output:
# torch.Size([1, 3, 10, 10])
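
As a side note, the same reduction can probably be written more compactly by passing both spatial dims to sum with keepdim=True, which skips the two unsqueeze calls. This is only a sketch of an alternative, not a change to the function above; the name mult_by_channel_sum_keepdim is made up for illustration.

```python
import torch

def mult_by_channel_sum_keepdim(img):
    # Hypothetical variant: reduce dims 2 and 3 in one call.
    # keepdim=True leaves the result with shape (N, C, 1, 1),
    # so it broadcasts against img without any unsqueeze.
    sum_per_channel = img.sum(dim=(2, 3), keepdim=True)
    return img * sum_per_channel

x = torch.rand(2, 3, 4, 5, dtype=torch.float64)

# Reference: the chained sum/unsqueeze version from the answer above
reference = x * x.sum(dim=3).sum(dim=2).unsqueeze(2).unsqueeze(3)

same = torch.allclose(mult_by_channel_sum_keepdim(x), reference)
print(same)
```

Both versions should give the same result up to floating-point rounding, so which one to use is mostly a matter of taste.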

This seems to be what I want, but will it work for batch sizes other than 1?

Yes, this should work for any batch size.

It should also work for any number of channels and any H and W, because the sums only reduce dims 2 and 3 and leave the batch and channel dims untouched.
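
If you want to convince yourself, here is a quick check that compares the function against a slow loop over every (sample, channel) plane for a few batch sizes. The loop version and the chosen sizes are just for illustration, and float64 is used so rounding differences stay negligible.

```python
import torch

def mult_by_channel_sum(img):
    sum_per_channel = img.sum(dim=3).sum(dim=2).unsqueeze(2).unsqueeze(3)
    return img * sum_per_channel

def mult_by_channel_sum_loop(img):
    # Slow reference (illustration only): handle each plane separately
    out = torch.empty_like(img)
    for n in range(img.shape[0]):
        for c in range(img.shape[1]):
            out[n, c] = img[n, c] * img[n, c].sum()
    return out

results = {}
for batch_size in (1, 4, 16):
    x = torch.rand(batch_size, 3, 10, 10, dtype=torch.float64)
    results[batch_size] = torch.allclose(
        mult_by_channel_sum(x), mult_by_channel_sum_loop(x)
    )

print(results)
```

Each sample in the batch gets its own per-channel sums, so samples do not influence each other.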

Also, if your tensor has more dimensions and you only want to sum over the last two, you could change it to this.

def mult_by_channel_sum(img):
    sum_per_channel = img.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
    return img * sum_per_channel
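
For example, with a 5-D tensor (say a batch of short clips, with a layout of N x T x C x H x W that I am assuming here just for illustration), the negative-index version treats every H x W plane independently:

```python
import torch

def mult_by_channel_sum(img):
    sum_per_channel = img.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
    return img * sum_per_channel

# Hypothetical 5-D input: 2 clips, 5 frames, 3 channels, 8x8 pixels
video = torch.rand(2, 5, 3, 8, 8, dtype=torch.float64)
out = mult_by_channel_sum(video)
print(out.shape)  # torch.Size([2, 5, 3, 8, 8])

# Spot-check one plane: it should equal the plane times its own sum
plane = video[1, 4, 2]
plane_ok = torch.allclose(out[1, 4, 2], plane * plane.sum())
print(plane_ok)
```

All leading dimensions are kept as they are, and only the last two are reduced.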