Efficient summation over distance tensor blocks

I have an N x N tensor D (a symmetric distance matrix) whose rows and columns are arranged in contiguous ‘groups’. The group index of each row/column of D is given as a 1D tensor B, where B[i] is the (integer) group assigned to row i of D. I would like to efficiently compute the sum of each block of D induced by the grouping B, i.e. one sum for every pair of groups. For example, with a distance matrix between 5 objects in two groups:

```python
import torch

D = torch.tensor([[0., 1., 2., 3., 4.],
                  [1., 0., 1., 2., 3.],
                  [2., 1., 0., 1., 2.],
                  [3., 2., 1., 0., 1.],
                  [4., 3., 2., 1., 0.]])

# two groups, one of size 2 and one of size 3
B = torch.tensor([0, 0, 1, 1, 1])

# along each dim, the first group contains indices 0 and 1
# and the second group has indices 2, 3, 4.
# Since D is symmetric, this makes 3 unique blocks:
# D[0:2, 0:2], D[0:2, 2:5], D[2:5, 2:5]

# desired result (one sum per unique block):
# tensor([2., 15., 8.])
```

The most efficient approach would seem to be scatter_add_, but I would appreciate an example of how to apply it here, or any other guidance.

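So scatter_add (and similar operations), when implemented with atomicAdd, gets slow when there are many “collisions”, i.e. many GPU threads trying to write to the same address. Here every element of a block adds into the same output entry, so collisions are the norm.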
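To make the collision point concrete for this problem, here is a small sketch (using the question's B and assuming labels 0..G-1) that counts how many elements of D would be scattered into each flattened output slot B[i] * G + B[j]. Each count is the number of atomic adds aimed at that one address. The names G and slot are only illustrative.

```python
import torch

# The question's grouping: two contiguous groups of sizes 2 and 3.
B = torch.tensor([0, 0, 1, 1, 1])
G = int(B.max()) + 1  # number of groups, assuming labels 0..G-1

# Flat output slot that D[i, j] would be scattered into: B[i] * G + B[j].
slot = (B[:, None] * G + B[None, :]).flatten()

# Number of writes (atomic adds) aimed at each of the G * G output slots.
writes_per_slot = torch.bincount(slot, minlength=G * G)
print(writes_per_slot.tolist())  # [4, 6, 6, 9]
```

Slot (a, b) receives |group a| * |group b| writes, so with large groups thousands of threads contend for only a handful of addresses.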
With contiguous blocks like these, it is feasible to write a custom kernel that exploits the contiguity to be much more efficient, but the effort is likely quite high.
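One way to avoid both atomics and a custom kernel, not mentioned above, is to express all block sums as dense matrix products with a one-hot membership matrix M of shape (N, G): the entry (M^T D M)[a, b] is the sum of D[i, j] over B[i] == a and B[j] == b. A minimal sketch, assuming labels 0..G-1; block_sums_matmul is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def block_sums_matmul(D: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """G x G matrix of block sums as M^T @ D @ M (assumes labels 0..G-1)."""
    G = int(B.max()) + 1
    # One-hot membership matrix: M[i, a] = 1 if row i belongs to group a.
    M = F.one_hot(B, num_classes=G).to(D.dtype)  # shape (N, G)
    # (M^T D M)[a, b] = sum of D[i, j] over B[i] == a and B[j] == b.
    return M.T @ D @ M

idx = torch.arange(5)
D = (idx[:, None] - idx[None, :]).abs().float()  # the question's D, D[i, j] == |i - j|
B = torch.tensor([0, 0, 1, 1, 1])

S = block_sums_matmul(D, B)
# D is symmetric, so the unique blocks are the upper triangle (a <= b).
rows, cols = torch.triu_indices(S.shape[0], S.shape[1])
print(S[rows, cols].tolist())  # [2.0, 15.0, 8.0]
```

This costs roughly N² · G multiply-adds plus an N x G dense matrix, so it is most attractive when the number of groups is small relative to N. It also does not use the contiguity at all, so it works for any labeling.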
Personally, I’d first benchmark scatter_add against simply calling .sum() on each block.
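A baseline for that comparison could look like the sketch below. It assumes B is sorted so that each group is one contiguous run: torch.unique_consecutive gives the run lengths, and each block with a <= b is sliced out and reduced with .sum(). block_sums_loop is a hypothetical helper name.

```python
import torch

def block_sums_loop(D: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Sum of each block D[group a, group b] for a <= b, via slicing and .sum().

    Assumes B is sorted, so each group occupies one contiguous range.
    """
    # Lengths of the contiguous runs of equal labels, e.g. [2, 3] here.
    _, counts = torch.unique_consecutive(B, return_counts=True)
    # Run boundaries, e.g. [0, 2, 5]: group a spans bounds[a]:bounds[a + 1].
    bounds = [0] + counts.cumsum(0).tolist()
    G = len(bounds) - 1
    sums = [
        D[bounds[a]:bounds[a + 1], bounds[b]:bounds[b + 1]].sum()
        for a in range(G)
        for b in range(a, G)
    ]
    return torch.stack(sums)

idx = torch.arange(5)
D = (idx[:, None] - idx[None, :]).abs().float()  # the question's D, D[i, j] == |i - j|
B = torch.tensor([0, 0, 1, 1, 1])
print(block_sums_loop(D, B).tolist())  # [2.0, 15.0, 8.0]
```

This launches one small reduction per block, G(G+1)/2 in total, so Python and launch overhead will dominate when there are many small groups; with a few large groups it may well be competitive.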

Best regards


Thanks for this, I’ll definitely do some benchmarking. I would also appreciate an example of how to use scatter_add in this situation, since I am quite new to this function…
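A minimal scatter_add_ sketch, assuming integer labels 0..G-1 and D and B on the same device: every element D[i, j] is sent to the flat slot B[i] * G + B[j] of a zero-initialized buffer of length G * G, which is then viewed as the G x G matrix of block sums. Because D is symmetric, the unique blocks are read off the upper triangle with torch.triu_indices. block_sums_scatter is a hypothetical helper name.

```python
import torch

def block_sums_scatter(D: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """G x G matrix of block sums via scatter_add_ (assumes labels 0..G-1)."""
    G = int(B.max()) + 1
    # Flat output slot for every element D[i, j]: B[i] * G + B[j].
    index = (B[:, None] * G + B[None, :]).flatten()
    out = torch.zeros(G * G, dtype=D.dtype, device=D.device)
    # out[index[k]] += D.flatten()[k] for every k
    out.scatter_add_(0, index, D.flatten())
    return out.view(G, G)

idx = torch.arange(5)
D = (idx[:, None] - idx[None, :]).abs().float()  # the question's D, D[i, j] == |i - j|
B = torch.tensor([0, 0, 1, 1, 1])

S = block_sums_scatter(D, B)
# D is symmetric, so the unique blocks are the upper triangle (a <= b).
rows, cols = torch.triu_indices(S.shape[0], S.shape[1])
result = S[rows, cols]
print(result.tolist())  # [2.0, 15.0, 8.0]
```

When timing this against the .sum() approach on a GPU, torch.utils.benchmark.Timer is a convenient choice because it takes care of CUDA synchronization and warm-up.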