import torch

num_channels = 10
x = torch.randn(2, num_channels, 4, 5)  # (batch, channels, height, width)
result = torch.sum(x, dim=[0, 2, 3])  # per-channel sum, shape (num_channels,)
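As a sanity check, here is a minimal sketch comparing the single list-`dim` call with an explicit per-channel loop. The loop (the `slow` tensor) is only an illustration of what the one call computes; it is not part of the original answer.

```python
import torch

num_channels = 10
x = torch.randn(2, num_channels, 4, 5)  # (batch, channels, height, width)

# One call: reduce over batch (0) and spatial dims (2, 3), keeping channel dim 1
fast = torch.sum(x, dim=[0, 2, 3])

# Equivalent explicit loop: sum every element that belongs to channel c
slow = torch.stack([x[:, c, :, :].sum() for c in range(num_channels)])

print(fast.shape)  # torch.Size([10])
# Summation order differs, so compare with a small tolerance rather than exactly
print(torch.allclose(fast, slow, atol=1e-5))
```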
torch.sum accepts a dim parameter, which can be a list (or tuple) of dimensions. If you pass a list, it reduces the tensor over all the dimensions in it. For instance:
x = torch.randn(2, 4, 5)
x[0,:,:] = 1
x[1,:,:] = 2
result = torch.sum(x, dim=[1,2]) # tensor([20., 40.])
Explanation: along dim 0 (the "channel" in this example) there are 2 entries, each holding 4*5 = 20 cells. I set every cell of x[0] to 1 and every cell of x[1] to 2, so summing over all dims except dim 0 gives 20*1 = 20 and 20*2 = 40.
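Another way to see the same reduction: flattening every dim after the first into a single axis of 20 cells and summing that axis gives the same result. A small sketch (the `flatten` route is just an equivalent formulation for illustration):

```python
import torch

x = torch.empty(2, 4, 5)
x[0] = 1.0  # 4*5 = 20 cells of 1
x[1] = 2.0  # 4*5 = 20 cells of 2

by_list = torch.sum(x, dim=[1, 2])
# Same reduction: merge dims 1 and 2 into one axis of 20 cells, then sum it
by_flatten = x.flatten(start_dim=1).sum(dim=1)

print(by_list)     # tensor([20., 40.])
print(by_flatten)  # tensor([20., 40.])
```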
Thanks a lot, it cut the time in half! I also need to optimize a per-channel torch.max(), but torch.max doesn't accept a list for dim the way torch.sum does. Any suggestions for a per-channel torch.max?
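One possible approach, sketched here as an assumption rather than taken from the thread: torch.amax (available since PyTorch 1.7) accepts a tuple of dims like torch.sum, but returns only values, not indices. On older versions, moving the channel dim to the front and flattening the rest allows a single torch.max call.

```python
import torch

num_channels = 10
x = torch.randn(2, num_channels, 4, 5)  # (batch, channels, height, width)

# Option 1 (assumes PyTorch >= 1.7): reduce over several dims at once
per_channel_max = torch.amax(x, dim=(0, 2, 3))

# Option 2: put channels first, flatten everything else, then one max over dim 1
# (flatten copies here because the transposed tensor is not contiguous)
per_channel_max_alt = x.transpose(0, 1).flatten(start_dim=1).max(dim=1).values

print(per_channel_max.shape)  # torch.Size([10])
```

Option 2 also exposes `.indices` from the same `max` call if the argmax positions are needed, though they index into the flattened (batch, height, width) axis.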