I have a standard NxCxHxW feature map, and I want to multiply each channel by the sum of all pixel values in that channel. How can I write a function that does this?
If I understood correctly, you want to sum all the values within each channel. For an RGB image, that means adding all of the R values together, all of the G values for a second sum, and all of the B values for a third. Each channel's original values are then multiplied by that channel's sum, so if the R values sum to 50, every value in the R channel is multiplied by 50.
If this is correct, then the following code should do what you want.
If that is not what you want, please let me know.
import torch

img = torch.rand(1, 3, 10, 10)

def mult_by_channel_sum(img):
    # Sum over dims 3 and 2, then add them back as size-1 dims so the result broadcasts.
    sum_per_channel = img.sum(dim=3).sum(dim=2).unsqueeze(2).unsqueeze(3)
    return img * sum_per_channel

print(mult_by_channel_sum(img).shape)
# Output: torch.Size([1, 3, 10, 10])
That seems to be what I want, but will it also work for batch sizes other than 1?
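Yes, this should work for any batch size. The sums only reduce dimensions 2 and 3, so the batch and channel dimensions are left untouched.
For the same reason, it should also work for any number of channels and any H and W.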
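As a quick sanity check, here is a minimal sketch that runs the same function on a batch of 4 and compares it against an explicit loop over every sample and channel. The batch size, the `batch` and `expected` names, and the loop-based reference are just illustrative choices, not part of the original answer.

```python
import torch

def mult_by_channel_sum(img):
    # Sum over dims 3 and 2, then restore them as size-1 dims for broadcasting.
    sum_per_channel = img.sum(dim=3).sum(dim=2).unsqueeze(2).unsqueeze(3)
    return img * sum_per_channel

# Hypothetical input: batch of 4 samples, 3 channels, 10x10 spatial size.
batch = torch.rand(4, 3, 10, 10)
out = mult_by_channel_sum(batch)

# Slow reference: scale each (sample, channel) slice by its own sum.
expected = torch.empty_like(batch)
for n in range(batch.shape[0]):
    for c in range(batch.shape[1]):
        expected[n, c] = batch[n, c] * batch[n, c].sum()

print(out.shape)  # torch.Size([4, 3, 10, 10])
print(torch.allclose(out, expected, rtol=1e-4))
```

Each sample in the batch gets its own per-channel sums, so samples do not influence one another.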
Also, if your tensor has more dimensions and you only want to sum over the last two, you could index the dimensions from the end instead:
def mult_by_channel_sum(img):
    # Negative dims count from the end, so this works for any number of leading dims.
    sum_per_channel = img.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
    return img * sum_per_channel
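A possibly shorter variant, offered as a sketch rather than as the answer's own method: `torch.sum` accepts a tuple of dimensions and a `keepdim` flag, which should make the two `unsqueeze` calls unnecessary. The function name `mult_by_channel_sum_keepdim` and the 5-D test tensor are made up for this illustration.

```python
import torch

def mult_by_channel_sum(img):
    # Version from the thread: chained sums over the last two dims.
    sum_per_channel = img.sum(dim=-1).sum(dim=-1).unsqueeze(-1).unsqueeze(-1)
    return img * sum_per_channel

def mult_by_channel_sum_keepdim(img):
    # Hypothetical alternative: reduce both trailing dims at once and keep them
    # as size-1 dims, so the result broadcasts against img directly.
    return img * img.sum(dim=(-2, -1), keepdim=True)

# Hypothetical 5-D input, e.g. (batch, time, channel, H, W).
x = torch.rand(2, 5, 3, 8, 8)
a = mult_by_channel_sum(x)
b = mult_by_channel_sum_keepdim(x)

print(a.shape == b.shape)
print(torch.allclose(a, b, rtol=1e-4))
```

Both versions rely on broadcasting a tensor whose last two dimensions have size 1, so which one to use is mostly a matter of readability.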