Batched addition over contiguous indices

I am trying to multiply a matrix and a vector together, but only over a certain range of indices that depends on the matrix row. I can write it with a for loop, but this is not efficient. Is there a way to vectorize/broadcast this operation in an efficient, GPU-friendly manner?

Here is the example and what I am trying to achieve:

import torch

N = 5
# Each row of D is nonzero only on a contiguous range of indices
D = torch.stack([torch.cat((torch.arange(1, N), torch.zeros(1))),
                 torch.cat((torch.zeros(2), torch.arange(2, N)))])
offset = torch.tensor([-1, 2])  # per-row shift between input and output ranges
i0 = torch.tensor([0, 2])       # start of each row's active range
i1 = torch.tensor([N-1, N])     # end (exclusive) of each row's active range
X = torch.arange(N).to(torch.float)

# For each row k, multiply D over its active range by X and accumulate
# the result into a slice of Y shifted by offset[k]
Y = torch.zeros_like(X)
for k in range(2):
    Y[(i0[k]-offset[k]):(i1[k]-offset[k])] += D[k, i0[k]:i1[k]] * X[i0[k]:i1[k]]

# Expected results
print(D)
print(X)
print(Y)
# D = tensor([[1., 2., 3., 4., 0.],
#             [0., 0., 2., 3., 4.]])
# X  = tensor([0., 1., 2., 3., 4.])
# Y  = tensor([4., 9., 18., 6., 12.])

Essentially, I just want to get rid of the for loop.

[Edit: added the offset tensor, otherwise the problem is trivial]

I have managed to simplify the expression with the following:

Y = torch.zeros_like(X)
index = torch.tensor([[1,2,3,4,0],[3,4,0,1,2]])
for k in range(2):
    Y.index_add_(0, index[k, :], D[k, :] * X)
print(Y)
# Y  = tensor([4., 9., 18., 6., 12.])

I'm not sure how to get rid of the for loop, however.
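
As an aside, I believe the hardcoded index tensor generalizes to arbitrary N and offset, since the wrapped-around positions only ever pick up zero entries of D. A sketch, assuming that zero padding holds in general:

# index[k, j] = (j - offset[k]) mod N; wrapped positions land on zeros of D
index = (torch.arange(N) - offset[:, None]) % N
print(index)
# tensor([[1, 2, 3, 4, 0],
#         [3, 4, 0, 1, 2]])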

You can avoid the loop by flattening the index and value tensors; since index_add_ accumulates every contribution that lands on the same output index, a single call is enough:

out = torch.zeros_like(X)
value = D * X
out.index_add_(0, index.view(-1), value.view(-1))
print(out)
# tensor([ 4.,  9., 18.,  6., 12.])

Thanks for your answer! This is indeed an elegant way of doing it.

However, when benchmarking with larger tensors, this is slower on the GPU than my initial for loop without index_add_. Typically I have D.size() = (M, N) and X.size() = (N,) with N = 100_000_000 and M = 10. I'm guessing this is because index_add_ does not exploit the fact that I'm summing over contiguous indices. Is there an alternative batched version that would take advantage of this contiguity?
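
For reference, the slice-based loop I am benchmarking against is essentially the following (a sketch for general M, reusing the i0, i1, and offset tensors from above):

M = D.size(0)
V = D * X  # one fused elementwise kernel over the whole (M, N) block
Y = torch.zeros_like(X)
for k in range(M):  # M is small (~10), so only a few contiguous slice additions
    Y[i0[k]-offset[k] : i1[k]-offset[k]] += V[k, i0[k]:i1[k]]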

Thank you for your help.