Sum / Mul over multiple axes

Maybe this is a silly question, but how can we sum over multiple dimensions in PyTorch?

In NumPy, np.sum() takes an axis argument that can be an int or a tuple of ints, while in PyTorch, torch.sum() takes a dim argument that can only be a single int.
Say I have a tensor of size 16 x 256 x 14 x 14, and I want to sum over the third and fourth dimensions to get a tensor of size 16 x 256. In NumPy, one can do np.sum(t, axis=(2, 3)). What is the PyTorch equivalent?


AFAIK there isn’t such a method, so you can only do t.sum(2).sum(2)
or t.view(16,256,-1).sum(2)
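A quick sanity check that the two workarounds agree. This is a sketch assuming a recent PyTorch and random data with the shape from the question. Note that view() only works on contiguous tensors; reshape() is the more forgiving choice if t might be a transposed or sliced view.

```python
import torch

t = torch.randn(16, 256, 14, 14)

# Workaround 1: sum dim 2 twice. After the first sum, the old dim 3 becomes dim 2.
a = t.sum(2).sum(2)

# Workaround 2: merge the last two dims into one (14 * 14 = 196), then sum it once.
b = t.view(16, 256, -1).sum(2)

print(a.shape)  # torch.Size([16, 256])
# The summation order differs, so compare with a small float32 tolerance.
print(torch.allclose(a, b, atol=1e-4))
```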


Indeed, I’ve always used the t.sum(2).sum(2), but the view() solution looks cleaner, thanks!

Is that the same as TensorFlow’s reduce_sum?

NumPy’s sum() is like TensorFlow’s reduce_sum. Unfortunately, right now PyTorch’s sum doesn’t support summing over multiple axes at once (it’s being worked on, though!), but there are ways to be creative with this.


What do you mean by that?

One way, depending on how many dimensions your tensor has, is to do it by hand with for loops, after transposing the tensor so that the dimension you want to reduce over is last. For example:

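import torch
x = torch.randn(2, 2, 2)
# Let's say I want to sum over the 1st dimension (0-indexed):
# move it to the end so the inner loop sees 1-D slices along it
x = x.transpose(1, 2)
# .item() turns each 0-dim sum into a Python float before rebuilding the tensor
torch.tensor([[z.sum().item() for z in y] for y in x])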
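To convince yourself the loop reduces the right dimension, you can compare it with the built-in single-dim reduction x.sum(1). This sketch uses a non-square shape so a wrong transpose would show up as a shape mismatch, and uses torch.stack to keep everything as tensors instead of round-tripping through Python floats.

```python
import torch

x = torch.randn(2, 3, 4)

# Move the dimension to reduce (dim 1, size 3) to the end.
xt = x.transpose(1, 2)  # shape (2, 4, 3)

# Sum each innermost 1-D slice by hand, then stack the 0-dim results back up.
looped = torch.stack([torch.stack([z.sum() for z in y]) for y in xt])  # shape (2, 4)

# The built-in reduction over a single dim gives the same values.
builtin = x.sum(1)
print(looped.shape, torch.allclose(looped, builtin))
```

The loop is only practical for small tensors; it runs one Python-level sum per output element.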

Hope no one minds, but I thought I’d share my “creative solution”.

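import torch

def multi_max(input, axes, keepdim=False):
    """Performs `torch.max` over multiple dimensions of `input`."""
    # Reduce from the highest axis down so that, with keepdim=False,
    # removing a dimension does not shift the axes still to be reduced.
    axes = sorted(axes)
    maxed = input
    for axis in reversed(axes):
        maxed, _ = maxed.max(axis, keepdim)  # discard the indices
    return maxed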

This style should work for sum as well.
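For example, a sum version of the same loop might look like the sketch below. multi_sum is a hypothetical name, not a PyTorch function; the only addition over multi_max is that negative axes are normalized first, since sorting raw negative indices would reduce them in the wrong order.

```python
import torch

def multi_sum(input, axes, keepdim=False):
    """Sum `input` over every dim in `axes` (hypothetical helper mirroring multi_max)."""
    # Map negative axes to their positive equivalents, then reduce from the
    # highest index down so earlier reductions never shift later axes.
    axes = sorted(a % input.dim() for a in axes)
    summed = input
    for axis in reversed(axes):
        summed = summed.sum(axis, keepdim=keepdim)
    return summed

t = torch.randn(16, 256, 14, 14)
out = multi_sum(t, (2, 3))
print(out.shape)  # torch.Size([16, 256])
print(multi_sum(t, (-1, -2), keepdim=True).shape)  # torch.Size([16, 256, 1, 1])
```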

For sum, you can just pass multiple dims.
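In current PyTorch releases (not the 0.4.0 discussed below), dim accepts a tuple, which is the direct counterpart of NumPy's axis=(2, 3). For max there is also torch.amax in newer versions (1.7 onward, as far as I know), which takes a tuple of dims and returns values only, without indices. This sketch assumes a recent PyTorch with NumPy installed, used only for the comparison.

```python
import numpy as np
import torch

t = torch.randn(16, 256, 14, 14)

# Sum over dims 2 and 3 in a single call by passing a tuple of dims.
s = t.sum(dim=(2, 3))
print(s.shape)  # torch.Size([16, 256])

# The same reduction in NumPy, for comparison.
s_np = np.sum(t.numpy(), axis=(2, 3))
print(np.allclose(s.numpy(), s_np, atol=1e-4))

# Max over several dims at once (values only, no indices).
m = torch.amax(t, dim=(2, 3))
print(m.shape)  # torch.Size([16, 256])
```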


Really? What version of PyTorch? I have v0.4.0 and neither of the following worked:

import torch
x = torch.rand(5,5,5)
torch.sum(x, (1,2))
y = x.sum((1,2))

Both complained about wrong types.

I’d like a built in way to do this (that’s why I ended up on this page, after all).


Ah, sorry, this was only merged a week after 0.4, so you would either need to build PyTorch from source or wait for 0.5.

Best regards