K-nearest neighbors using index_select

I’m working on a fun problem where I am looking at nearest neighbors between two 3-dimensional point sets. For every point in point cloud p1, I find the mean (centroid) of the k nearest points in point cloud p2. Here is working code that does just that:

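import torch

def calc_k_nearest(p1, p2, k=10):
    # Pairwise squared distances: adj[i, j] = ||p1[i] - p2[j]||^2, shape (N1, N2)
    adj = p1[:, None, ...] - p2[None, ...]
    adj = adj.pow(2).sum(-1)
    # Indices of the k closest p2 points for each p1 point, shape (N1, k)
    top_k = adj.topk(k=k, dim=1, largest=False)[1]

    # For each p1 point, gather its k neighbors from p2 and average them
    ls = []
    for i in range(top_k.shape[0]):
        tmp = torch.index_select(p2, dim=0, index=top_k[i])
        ls.append(tmp.mean(0))

    near_points = torch.stack(ls)  # shape (N1, 3)
    mag = near_points - p1  # offset from each p1 point to its neighbor centroid
    return near_points, mag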

What I don’t like, though, is the for loop. It would be nice to have index_select take the whole two-dimensional top_k index at once instead of iterating over every row. I’m not sure what sort of broadcasting logic would exist for something like this. Is there a way to do this with index_select or a different function?
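A possible loop-free sketch follows, assuming p1 and p2 are float tensors of shape (N1, 3) and (N2, 3). torch.index_select only accepts a 1-D index, so one way around the loop is to flatten top_k, select all N1 * k rows in one call, and reshape the result to (N1, k, 3) before averaging over the neighbor dimension. Plain advanced indexing, p2[top_k], should give the same (N1, k, 3) tensor directly. The name calc_k_nearest_vectorized is hypothetical, chosen for illustration.

```python
import torch

def calc_k_nearest_vectorized(p1, p2, k=10):
    # Pairwise squared distances, shape (N1, N2)
    dist = (p1[:, None, :] - p2[None, :, :]).pow(2).sum(-1)
    # Indices of the k closest p2 points for each p1 point, shape (N1, k)
    top_k = dist.topk(k=k, dim=1, largest=False).indices

    # index_select needs a 1-D index: flatten, select once, then restore (N1, k, 3)
    flat = torch.index_select(p2, dim=0, index=top_k.reshape(-1))
    neighbors = flat.reshape(top_k.shape[0], top_k.shape[1], p2.shape[-1])
    # Equivalent alternative using advanced indexing: neighbors = p2[top_k]

    near_points = neighbors.mean(dim=1)  # centroid of the k neighbors, shape (N1, 3)
    mag = near_points - p1
    return near_points, mag
```

The broadcast in the first line still materializes an (N1, N2, 3) tensor; torch.cdist(p1, p2) computes the (N1, N2) distance matrix without that intermediate. It returns plain rather than squared Euclidean distances, but the neighbor ordering, and therefore top_k, is the same.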