How to compute per-channel sum fast?

Suppose the shape of x is [2, num_channels, 4, 5], and I compute the per-channel sum as follows:

s = torch.zeros(num_channels)
for i in range(num_channels):
    s[i] = x[:, i].sum()

It turns out this implementation is too slow. Is there a faster one?


Using built-in functions, it would be like this:

import torch

num_channels = 10
x = torch.randn(2, num_channels, 4, 5)
result = torch.sum(x, dim=[0, 2, 3])  # shape: [num_channels]
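As a sanity check (a minimal sketch using the shapes from the question), the vectorized sum should match the loop version up to floating-point rounding:

```python
import torch

num_channels = 10
x = torch.randn(2, num_channels, 4, 5)

# Slow loop version from the question
s = torch.zeros(num_channels)
for i in range(num_channels):
    s[i] = x[:, i].sum()

# Vectorized version: reduce over every dim except the channel dim (dim 1)
result = torch.sum(x, dim=[0, 2, 3])

print(torch.allclose(s, result))  # both compute the same per-channel sums
```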

torch.sum accepts a dim parameter that can be a list. If you pass a list, it reduces the tensor over all the dimensions in dim. For instance,

x = torch.randn(2, 4, 5)
x[0,:,:] = 1
x[1,:,:] = 2
result = torch.sum(x, dim=[1,2])  # tensor([20., 40.])

The explanation: each of the 2 channels (dim 0 here) has 4*5 = 20 cells. Since I set x[0] to all 1s and x[1] to all 2s, summing over every dimension except the channel dimension gives 1*20 = 20 and 2*20 = 40.


Thanks a lot, it cut the time in half! I also need to optimize per-channel torch.max(), and torch.max doesn't accept a list of dims like the feature you mentioned. Any suggestion for per-channel torch.max?

Glad it worked!
There is no torch.max variant that takes a list of dims the way torch.sum does, but you can flatten the non-channel dimensions and reduce once, similar to the trick in this post:
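A minimal sketch of that flattening trick, using the shapes from the question: move the channel dim to the front, collapse everything else into one dim, then take a single max over it.

```python
import torch

num_channels = 10
x = torch.randn(2, num_channels, 4, 5)

# Move the channel dim first -> [num_channels, 2, 4, 5],
# flatten the rest -> [num_channels, 40], then reduce once.
per_channel_max = x.transpose(0, 1).reshape(num_channels, -1).max(dim=1).values

# Slow reference version, analogous to the loop in the question
expected = torch.stack([x[:, i].max() for i in range(num_channels)])
print(torch.equal(per_channel_max, expected))  # True
```

Note: in recent PyTorch versions (1.7+), torch.amax also accepts multiple dims directly, e.g. torch.amax(x, dim=(0, 2, 3)), which avoids the reshape entirely.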