I would like to efficiently compute a reduction over differently sized blocks within a larger matrix. I am working with point set data, so each data point is a 2D tensor with a varying number of rows, but all of them share the same feature dimension (i.e. all are `[n, d]` where `d` is constant).
Suppose I have 3 tensors:
```python
import torch

A = torch.tensor([1., 2., 3.]).reshape(3, 1).repeat(1, 3)      # 3x3
B = torch.tensor([4., 5.]).reshape(2, 1).repeat(1, 3)          # 2x3
C = torch.tensor([6., 7., 8., 9.]).reshape(4, 1).repeat(1, 3)  # 4x3
```
The number of items per point set can differ dramatically, so padding every set to the same number of rows (my current implementation) wastes a lot of memory and compute. Instead, since all sets share the same feature dimension, I stack them along the row dimension:
```python
stacked = torch.vstack([A, B, C])
```
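Because the sets are concatenated along dim 0, the only bookkeeping needed to find them again is the list of set sizes. A minimal sketch, assuming the sizes are available up front; `sizes`, `ptr` and `set_ids` are illustrative names, and `torch.repeat_interleave` is used here to tag every row with the index of the set it came from:

```python
import torch

A = torch.tensor([1., 2., 3.]).reshape(3, 1).repeat(1, 3)      # 3x3
B = torch.tensor([4., 5.]).reshape(2, 1).repeat(1, 3)          # 2x3
C = torch.tensor([6., 7., 8., 9.]).reshape(4, 1).repeat(1, 3)  # 4x3

stacked = torch.vstack([A, B, C])  # 9x3, no padding rows

# Number of rows contributed by each point set (hypothetical bookkeeping).
sizes = torch.tensor([A.shape[0], B.shape[0], C.shape[0]])  # [3, 2, 4]

# Row offsets: set k occupies rows ptr[k]:ptr[k + 1] of `stacked`.
ptr = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(sizes, dim=0)])  # [0, 3, 5, 9]

# Set id of every row: [0, 0, 0, 1, 1, 2, 2, 2, 2].
set_ids = torch.repeat_interleave(torch.arange(len(sizes)), sizes)

# Each original set can be recovered by slicing with the offsets.
assert torch.equal(stacked[int(ptr[1]):int(ptr[2])], B)
```

The `set_ids` vector is what later allows per-block operations to be expressed without a Python loop.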
Now here is my problem. I need to compute set-set pairwise distances as a function of item-item pairwise distances. The item-item distances are easy since I can just do:
```python
# N x N, where N is the total number of items (9 in this example)
item_dists = torch.cdist(stacked, stacked)
```
But the item-item distance matrix is now made up of blocks, one for every pair of point sets (A, B, C). For example, the A-A block is at positions `[0:3, 0:3]`, since A's items come first. The A-B block is at `[0:3, 3:5]`, since B contributes the next two items in the stacked tensor.
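In general the block for sets `(i, j)` is `item_dists[ptr[i]:ptr[i+1], ptr[j]:ptr[j+1]]`, and it equals the cross-distance matrix between the two sets. A small check of that claim, rebuilding A, B and C so the snippet runs on its own; `block` is a hypothetical helper, not a PyTorch API:

```python
import itertools

import torch

A = torch.tensor([1., 2., 3.]).reshape(3, 1).repeat(1, 3)      # 3x3
B = torch.tensor([4., 5.]).reshape(2, 1).repeat(1, 3)          # 2x3
C = torch.tensor([6., 7., 8., 9.]).reshape(4, 1).repeat(1, 3)  # 4x3
sets = [A, B, C]

stacked = torch.vstack(sets)
item_dists = torch.cdist(stacked, stacked)  # 9x9

# Row/column offsets of each set: [0, 3, 5, 9]
ptr = list(itertools.accumulate([s.shape[0] for s in sets], initial=0))


def block(i, j):
    """Hypothetical helper: the (i, j) block of item_dists."""
    return item_dists[ptr[i]:ptr[i + 1], ptr[j]:ptr[j + 1]]


ab = block(0, 1)  # same as item_dists[0:3, 3:5]
# The block is just the distance matrix between A and B.
assert torch.allclose(ab, torch.cdist(A, B))
```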
The problem I am facing is how to efficiently perform a reduction over these differently sized blocks. The simplest solution is a for loop that extracts each block and then reduces it:
```python
import itertools

import torch

sizes = [3, 2, 4]
# [0, 3, 5, 9]
ptr = list(itertools.accumulate(sizes, initial=0))
batch_size = len(sizes)

# all combinations: [(A, B), (A, C), (B, C)]
# i.e.: [(0, 1), (0, 2), (1, 2)]
indices = list(itertools.combinations(range(batch_size), r=2))

item_dists = torch.cdist(stacked, stacked)
for i, j in indices:
    x_start = ptr[i]
    x_end = ptr[i + 1]
    y_start = ptr[j]
    y_end = ptr[j + 1]
    block = item_dists[x_start:x_end, y_start:y_end]
    # block reduction
```
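For comparison, a loop-free sketch of the same block reduction. The post does not say which reduction is needed, so `amin` is only an assumed example. The idea: label every entry of `item_dists` with a flat block id `row_set * num_sets + col_set` and reduce all blocks in one `Tensor.scatter_reduce` call (available in recent PyTorch releases). `block_reduce` is a hypothetical helper name:

```python
import itertools

import torch


def block_reduce(item_dists, sizes, reduce="amin"):
    """Reduce every (i, j) block of item_dists with a single scatter call.

    Returns a [num_sets, num_sets] tensor whose (i, j) entry is the
    reduction over the rows of set i and the columns of set j.
    `reduce` is any mode accepted by Tensor.scatter_reduce
    ("sum", "prod", "mean", "amax", "amin").
    """
    num_sets = sizes.numel()
    sizes = sizes.to(item_dists.device)
    # Set id of every row/column, e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2].
    set_ids = torch.repeat_interleave(
        torch.arange(num_sets, device=item_dists.device), sizes
    )
    # Flat block id of every matrix entry: row_set * num_sets + col_set.
    block_ids = (set_ids[:, None] * num_sets + set_ids[None, :]).flatten()
    out = torch.zeros(num_sets * num_sets, dtype=item_dists.dtype, device=item_dists.device)
    # include_self=False: the zeros in `out` do not take part in the reduction.
    out = out.scatter_reduce(0, block_ids, item_dists.flatten(), reduce=reduce, include_self=False)
    return out.view(num_sets, num_sets)


A = torch.tensor([1., 2., 3.]).reshape(3, 1).repeat(1, 3)
B = torch.tensor([4., 5.]).reshape(2, 1).repeat(1, 3)
C = torch.tensor([6., 7., 8., 9.]).reshape(4, 1).repeat(1, 3)
stacked = torch.vstack([A, B, C])
item_dists = torch.cdist(stacked, stacked)
sizes = torch.tensor([3, 2, 4])

set_dists = block_reduce(item_dists, sizes, reduce="amin")

# Sanity check against the original loop.
ptr = [0, 3, 5, 9]
for i, j in itertools.combinations(range(3), r=2):
    ref = item_dists[ptr[i]:ptr[i + 1], ptr[j]:ptr[j + 1]].amin()
    assert torch.allclose(set_dists[i, j], ref)
```

This reduces all blocks (including the diagonal and both `(i, j)` and `(j, i)`) in one kernel launch. The extra cost is an `N x N` int64 index tensor, which is the same order of memory as `item_dists` itself.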
However, this loop is prohibitively slow. Any thoughts on a better implementation, or does the loop just need to come out of Python?
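If the reduction happens to be linear in the entries (sum or mean), the blocks can also be collapsed with two matrix multiplications. A minimal sketch, assuming a one-hot membership matrix `M` of shape `[N, num_sets]` with `M[r, k] = 1` when row `r` belongs to set `k`: `M.T @ item_dists @ M` gives every block sum, and dividing by the outer product of the set sizes gives block means. This does not cover min or max. `block_mean` is a hypothetical helper name:

```python
import torch
import torch.nn.functional as F


def block_mean(item_dists, sizes):
    """Mean of every (i, j) block of item_dists via M.T @ D @ M."""
    num_sets = sizes.numel()
    sizes = sizes.to(item_dists.device)
    set_ids = torch.repeat_interleave(
        torch.arange(num_sets, device=item_dists.device), sizes
    )
    # One-hot membership matrix, shape [N, num_sets].
    M = F.one_hot(set_ids, num_classes=num_sets).to(item_dists.dtype)
    block_sums = M.T @ item_dists @ M  # [num_sets, num_sets] block sums
    counts = sizes.to(item_dists.dtype)
    # Divide each block sum by the number of entries in that block.
    return block_sums / (counts[:, None] * counts[None, :])


A = torch.tensor([1., 2., 3.]).reshape(3, 1).repeat(1, 3)
B = torch.tensor([4., 5.]).reshape(2, 1).repeat(1, 3)
C = torch.tensor([6., 7., 8., 9.]).reshape(4, 1).repeat(1, 3)
stacked = torch.vstack([A, B, C])
item_dists = torch.cdist(stacked, stacked)
sizes = torch.tensor([3, 2, 4])

set_means = block_mean(item_dists, sizes)
# Same value as reducing the A-B block directly.
assert torch.allclose(set_means[0, 1], item_dists[0:3, 3:5].mean())
```

Since this is plain dense matmul, it runs well on GPU; with a very large number of sets, a sparse `M` may be preferable.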