I have an N x N tensor D (a symmetric distance matrix) whose rows and columns are arranged in contiguous ‘groups’. The group of each row/column of D is given by a 1D tensor B, where B[i] is the (integer) group assigned to row i (and column i) of D. I would like to efficiently compute the sum of every block of D induced by a pair of groups in B. For example, with a distance matrix between 5 objects in two groups:
import torch

D = torch.tensor([[0., 1., 2., 3., 4.],
                  [1., 0., 1., 2., 3.],
                  [2., 1., 0., 1., 2.],
                  [3., 2., 1., 0., 1.],
                  [4., 3., 2., 1., 0.]])
# two groups, one of size 2 and one of size 3
B = torch.tensor([0, 0, 1, 1, 1])
# along each dim, the first group contains indices 0 and 1
# and the second group contains indices 2, 3, 4
# since D is symmetric, D[2:5, 0:2] is the transpose of D[0:2, 2:5],
# so this makes 3 unique blocks:
# D[0:2, 0:2], D[0:2, 2:5], D[2:5, 2:5]
>>> result  # sum of each block, in the order listed above
tensor([ 2., 15.,  8.])
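To check any vectorised version, a plain loop over the blocks may be a useful reference. This is a minimal sketch, assuming the groups in B are contiguous; the name `block_sums_loop` is hypothetical. It reads the group boundaries from `torch.unique_consecutive` and returns the blocks in the same order as listed above.

```python
import torch

def block_sums_loop(D, B):
    """Loop-based reference: sum of each (g, h) block with g <= h, row-major."""
    # Sizes of the contiguous runs in B, e.g. [0, 0, 1, 1, 1] -> [2, 3]
    _, counts = torch.unique_consecutive(B, return_counts=True)
    # Group boundaries along each dim, e.g. [0, 2, 5]
    bounds = [0] + torch.cumsum(counts, 0).tolist()
    G = len(bounds) - 1
    sums = []
    for g in range(G):
        for h in range(g, G):  # g <= h: D is symmetric, so skip the transposed blocks
            block = D[bounds[g]:bounds[g + 1], bounds[h]:bounds[h + 1]]
            sums.append(block.sum())
    return torch.stack(sums)

D = torch.tensor([[0., 1., 2., 3., 4.],
                  [1., 0., 1., 2., 3.],
                  [2., 1., 0., 1., 2.],
                  [3., 2., 1., 0., 1.],
                  [4., 3., 2., 1., 0.]])
B = torch.tensor([0, 0, 1, 1, 1])
print(block_sums_loop(D, B).tolist())  # [2.0, 15.0, 8.0]
```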
Using scatter_add_ seems like it would be the most efficient approach, but I would appreciate an example of how to apply it here, or any other guidance.
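A minimal sketch of the scatter_add_ idea, assuming the group ids in B are 0..G-1 and B is an int64 tensor on the same device as D; the names `block_sums_scatter` and `block_sums_onehot` are hypothetical. Each entry D[i, j] is mapped to the flat bin `B[i] * G + B[j]`, and `scatter_add_` accumulates all entries into G * G bins, giving the G x G matrix of block sums. A one-hot matrix product (S = Mᵀ D M) is included as an alternative that yields the same matrix. Neither version actually requires the groups to be contiguous.

```python
import torch
import torch.nn.functional as F

def block_sums_scatter(D, B, num_groups=None):
    """Sum of each (g, h) block of D with g <= h, via scatter_add_."""
    G = int(B.max()) + 1 if num_groups is None else num_groups
    # Flat bin for entry (i, j) is B[i] * G + B[j]  -> N x N int64 tensor
    pair_idx = B.unsqueeze(1) * G + B.unsqueeze(0)
    # One bin per (g, h) pair; add every entry of D into its bin
    totals = torch.zeros(G * G, dtype=D.dtype, device=D.device)
    totals.scatter_add_(0, pair_idx.reshape(-1), D.reshape(-1))
    S = totals.view(G, G)  # S[g, h] = sum of block (g, h)
    # D is symmetric, so S is too: keep the upper triangle (g <= h), row-major
    r, c = torch.triu_indices(G, G, device=D.device)
    return S[r, c]

def block_sums_onehot(D, B, num_groups=None):
    """Same result via a one-hot membership matrix: S = M^T @ D @ M."""
    G = int(B.max()) + 1 if num_groups is None else num_groups
    M = F.one_hot(B, num_classes=G).to(D.dtype)  # N x G, M[i, g] = 1 if B[i] == g
    S = M.T @ D @ M                               # G x G matrix of block sums
    r, c = torch.triu_indices(G, G, device=D.device)
    return S[r, c]

D = torch.tensor([[0., 1., 2., 3., 4.],
                  [1., 0., 1., 2., 3.],
                  [2., 1., 0., 1., 2.],
                  [3., 2., 1., 0., 1.],
                  [4., 3., 2., 1., 0.]])
B = torch.tensor([0, 0, 1, 1, 1])
print(block_sums_scatter(D, B).tolist())  # [2.0, 15.0, 8.0]
print(block_sums_onehot(D, B).tolist())   # [2.0, 15.0, 8.0]
```

On the design choice: the scatter version materialises an N x N int64 index tensor, while the one-hot version does two dense matrix products with an N x G matrix, so which one is faster likely depends on N, G and the device.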