Sum over various subsets of a tensor

This may have already been addressed, but I did some googling and couldn't find a solution. I'd like to compute various sums over unequal-sized subsets of a given tensor (or more precisely of a column vector), where the index boundaries of the subsets are defined by a list (or tensor), and have the operation return a tensor of these sums (without using a for loop).

torch.split does almost exactly what I want but it returns a list, for example

a=torch.tensor([[1,2,3,4,5,6,7,8]])

print(torch.split(a,[2,2,4],dim=1))
returns the following tuple of tensors
(tensor([[1, 2]]), tensor([[3, 4]]), tensor([[5, 6, 7, 8]]))

What I want is a tensor whose elements are the sums of those individual tensors,
i.e. tensor([[3, 7, 26]]).
Is it possible to do this in a vectorized way using torch operations? I'd also like to do the same thing but with a product over the subsets, returning a tensor of the products.

Much thanks in advance!


I’m still looking for a better solution but if someone is having a similar problem, I sort of figured out how to do it with the help of ideas in the following link:

This isn't the most efficient method, because many of the terms in cumsum (and cumprod) are unnecessary; if someone finds a more efficient way to do this, please post it here. The following seems to work, at least for the summation. (NOTE: the cumprod version has problems: if the array is long, the cumulative product can blow up to inf, and if there are any zeros in the array, all subsequent products are zero.)

PYTORCH SOLUTION

FOR THE SUM
a=torch.tensor([[1,2,3,4,5,6,7,8]])
b=a.cumsum(1) #cumulative sum over row
c=b.gather(1, torch.tensor([[1,3,7]])) #select relevant terms
d=torch.cat( (torch.tensor([[0]]), b.gather(1, torch.tensor([[1,3]]))),1) #select relevant terms
print(c,d,c-d)

returns

tensor([[ 3, 10, 36]]) tensor([[ 0, 3, 10]]) tensor([[ 3, 7, 26]])
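
As a side note, the hard-coded index tensors above can be derived from the list of subset sizes (here [2,2,4]) instead of being typed by hand; a minimal sketch (the names sizes, ends and starts are just illustrative):

sizes = torch.tensor([[2,2,4]])
ends = sizes.cumsum(1) - 1   #tensor([[1, 3, 7]]): last index of each subset
starts = ends[:, :-1]        #tensor([[1, 3]]): the terms subtracted in d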

FOR THE PRODUCT
a=torch.tensor([[1,2,3,4,5,6,7,8]])
b=a.cumprod(1) #cumulative product over row
c=b.gather(1, torch.tensor([[1,3,7]])) #select relevant terms
d=torch.cat( (torch.tensor([[1]]), b.gather(1, torch.tensor([[1,3]]))),1) #select relevant terms
print(c,d,c/d)

returns
tensor([[ 2, 24, 40320]]) tensor([[ 1, 2, 24]]) tensor([[ 2, 12, 1680]])
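
To mitigate the cumprod problems mentioned above, one possible workaround is to do the same thing in log space; a minimal sketch, assuming all entries are strictly positive (it does not handle zeros or negative values):

a = torch.tensor([[1,2,3,4,5,6,7,8]], dtype=torch.float)
b = a.log().cumsum(1)   #cumulative sum of logs over row
c = b.gather(1, torch.tensor([[1,3,7]]))   #select relevant terms
d = torch.cat((torch.zeros(1, 1), b.gather(1, torch.tensor([[1,3]]))), 1)   #select relevant terms
print((c - d).exp())   #approximately tensor([[2., 12., 1680.]]) up to floating-point error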

NUMPY SOLUTION
a = np.array([[1,2,3,4,5,6,7,8]])
b=np.add.reduceat(a, [0,2,4], axis=1)
print(b,type(b),b.shape)

returns

[[ 3 7 26]] <class 'numpy.ndarray'> (1, 3)
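
For the product, the same reduceat pattern should work with np.multiply:

b = np.multiply.reduceat(a, [0,2,4], axis=1)
print(b)   #[[   2   12 1680]]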

Thanks for the answer @James_Hickman!

I am also facing this problem right now, and am looking for a more efficient solution since I am dealing with large tensors.

Bumping this post in case anyone knows of a better solution now, given the improvements in PyTorch over the last two years.

@Remorax Did you find a more efficient solution? I have the same problem.

Thank you.

Perhaps this is helpful.

Based on the solution of @James_Hickman :

a = torch.tensor([[1,2,3,4,5,6,7,8]])
b = a.cumsum(1)
c = b.gather(1, torch.tensor([[1,3,7]])) #select relevant terms
c = torch.cat([torch.zeros(1, 1), c], dim=-1) #start the sum with zeros
res = c.unsqueeze(2) - c.unsqueeze(1)
torch.diagonal(res, offset=-1, dim1=1, dim2=2)

returns

tensor([[ 3.,  7., 26.]])
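
For what it's worth, the same result can also be obtained without building the full pairwise-difference matrix, by taking consecutive differences of the padded cumulative sum c from the snippet above:

res = c[:, 1:] - c[:, :-1]   #tensor([[ 3.,  7., 26.]]); torch.diff(c, dim=1) should do the same in recent PyTorch versions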

The scatter_add function in the pytorch-scatter package seems to do exactly what you want. See Scatter Add — pytorch_scatter 1.3.0 documentation
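
For reference, plain PyTorch's built-in Tensor.scatter_add_ can be used in the same way; a minimal sketch for the original example (the seg tensor of per-element subset ids is constructed here just for illustration):

import torch

a = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
sizes = torch.tensor([2, 2, 4])
seg = torch.repeat_interleave(torch.arange(len(sizes)), sizes)   #tensor([0, 0, 1, 1, 2, 2, 2, 2])
sums = torch.zeros(len(sizes)).scatter_add_(0, seg, a)           #tensor([ 3.,  7., 26.])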


this is helpful to me. thank you very much!

Finally found a nice solution to this problem, which basically consists of multiplying the input tensor by a one-hot matrix corresponding to the subset indices of the elements.

Assume you have a tensor inputs of size N x d (i.e. N d-dimensional elements) and a tensor encoding_indices of size N such that encoding_indices[i] is the id of the subset to which inputs[i] belongs. Let num_codes be the number of subsets.

Then, the following script computes the sum of all vectors aggregated per subset:

one_hot = F.one_hot(encoding_indices, num_classes=num_codes).float()
sum_per_subset = torch.mm(one_hot.T, inputs)

Since this is a matrix product of a very sparse num_codes x N one-hot matrix with an N x d matrix, the script can be optimized by encoding the one-hot matrix as a sparse tensor:

num_inputs = inputs.size(0)
indices = torch.stack((
    encoding_indices,
    torch.arange(num_inputs, device=encoding_indices.device)
))
values = torch.ones_like(encoding_indices, dtype=torch.float)
one_hot = torch.sparse_coo_tensor(indices, values, size=(num_codes, num_inputs))
sums_per_subset = torch.mm(one_hot, inputs)
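
For completeness, the same per-subset sum can also be written with the built-in Tensor.index_add_ (a minimal sketch, not included in the benchmark below):

sums_per_subset = torch.zeros(num_codes, inputs.size(1), device=inputs.device)
sums_per_subset.index_add_(0, encoding_indices, inputs)   #accumulates inputs[i] into row encoding_indices[i]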

Here is a performance comparison on CPU and GPU:

cpu
f1: 1.087107 s  # naive for loop
f2: 0.087390 s  # dense matrix multiplication
f3: 0.016435 s  # sparse matrix multiplication

cuda:0
f1: 0.117819 s
f2: 0.010617 s
f3: 0.002403 s

Results obtained running the following script:

import timeit

import torch
import torch.nn.functional as F


def f1(inputs, encoding_indices, num_codes):
    sum_inputs = torch.zeros(num_codes, inputs.size(1), device=inputs.device)
    for i in range(num_codes):
        sum_inputs[i] = inputs[encoding_indices == i].sum(dim=0)
    return sum_inputs


def f2(inputs, encoding_indices, num_codes):
    one_hot = F.one_hot(encoding_indices, num_classes=num_codes).float()
    return torch.mm(one_hot.T, inputs)


def f3(inputs, encoding_indices, num_codes):
    num_inputs = inputs.size(0)
    indices = torch.stack((
        encoding_indices,
        torch.arange(num_inputs, device=encoding_indices.device)
    ))
    values = torch.ones_like(encoding_indices, dtype=torch.float)
    one_hot = torch.sparse_coo_tensor(indices, values, size=(num_codes, num_inputs))
    return torch.mm(one_hot, inputs)


num_inputs = 32456
code_dim = 67
num_codes = 511

inputs = torch.rand(num_inputs, code_dim)
encoding_indices = torch.randint(0, num_codes, (num_inputs,))

for d in ["cpu", "cuda:0"]:
    device = torch.device(d)
    print(d)

    inputs = inputs.to(device)
    encoding_indices = encoding_indices.to(device)

    a1 = f1(inputs, encoding_indices, num_codes)
    a2 = f3(inputs, encoding_indices, num_codes)

    assert torch.allclose(a1, a2)

    for i in range(1, 4):
        print("f%d: %.6f s" % (
            i,
            timeit.timeit(
                f"f{i:d}(inputs, encoding_indices, num_codes)",
                number=100,
                globals=globals()
            ) / 100
        ))

I especially use this method in my implementation of EMA in VQ-VAE, since the centroids are updated following the average of the inputs that are closest to them.