I’m working on a fun problem where I’m looking at nearest neighbors between two 3-dimensional point sets. For every point in point cloud `p1`, I’m finding the mean (centroid) of the k nearest points in point cloud `p2`. Here is working code that does just that:
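```python
import torch

def calc_k_nearest(p1, p2, k=10):
    # Pairwise squared distances between p1 (N1, 3) and p2 (N2, 3): shape (N1, N2)
    adj = p1[:, None, ...] - p2[None, ...]
    adj = adj.pow(2).sum(-1)
    # Indices of the k nearest points in p2 for each point in p1: shape (N1, k)
    top_k = adj.topk(k=k, dim=1, largest=False)[1]
    ls = []
    for i in range(top_k.shape[0]):
        # Gather the k neighbors of point i, shape (k, 3), and average them
        tmp = torch.index_select(p2, dim=0, index=top_k[i])
        ls.append(tmp.mean(0))
    near_points = torch.stack(ls)  # (N1, 3)
    mag = near_points - p1         # offset from each p1 point to its neighbor mean
    return near_points, mag
```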
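To make the broadcasting in the distance step concrete, here is a small shape check on random data. The sizes `n1`, `n2` and `k` are arbitrary illustrative values, not part of the original problem; the point is that `top_k` is already an (N1, k) matrix of row indices into `p2`, which is exactly what the loop walks over one row at a time.

```python
import torch

# Illustrative sizes only (assumed, not from the original problem).
n1, n2, k = 5, 8, 3
p1 = torch.randn(n1, 3)
p2 = torch.randn(n2, 3)

# p1[:, None, :] is (n1, 1, 3) and p2[None, :, :] is (1, n2, 3);
# broadcasting the subtraction yields every pairwise difference, (n1, n2, 3).
diff = p1[:, None, :] - p2[None, :, :]
adj = diff.pow(2).sum(-1)                       # squared distances, (n1, n2)
top_k = adj.topk(k=k, dim=1, largest=False)[1]  # int64 indices into p2, (n1, k)

print(diff.shape, adj.shape, top_k.shape)
# torch.Size([5, 8, 3]) torch.Size([5, 8]) torch.Size([5, 3])
```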
What I don’t like, though, is the for loop. It would be nice to have index_select take the whole (N1, k) index matrix at once and return all the neighbor sets in one call, instead of iterating through every row. I’m not sure what sort of broadcasting logic would apply to something like this. Is there a way to do this with index_select or with a different function?
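One way to drop the loop, sketched below under the assumption that `p1` is (N1, 3) and `p2` is (N2, 3) as in the code above: PyTorch's advanced indexing accepts an integer index tensor of any shape, so `p2[top_k]` with an (N1, k) index returns an (N1, k, 3) tensor, and a mean over dim 1 gives all centroids at once. `torch.index_select` itself only takes a 1-D index, so the second variant flattens the index, selects, and reshapes. The function names `calc_k_nearest_vectorized` and `calc_k_nearest_index_select` are hypothetical, chosen for this sketch.

```python
import torch

def calc_k_nearest_vectorized(p1, p2, k=10):
    # Squared Euclidean distances, shape (N1, N2).
    adj = (p1[:, None, :] - p2[None, :, :]).pow(2).sum(-1)
    # Indices of the k closest points in p2 for each point in p1, shape (N1, k).
    top_k = adj.topk(k=k, dim=1, largest=False)[1]
    # Advanced indexing with a 2-D index gathers rows of p2: shape (N1, k, 3).
    neighbors = p2[top_k]
    near_points = neighbors.mean(dim=1)  # (N1, 3)
    mag = near_points - p1
    return near_points, mag

def calc_k_nearest_index_select(p1, p2, k=10):
    adj = (p1[:, None, :] - p2[None, :, :]).pow(2).sum(-1)
    top_k = adj.topk(k=k, dim=1, largest=False)[1]
    # index_select needs a 1-D index: flatten, select (N1*k, 3), then reshape.
    flat = torch.index_select(p2, dim=0, index=top_k.reshape(-1))
    neighbors = flat.view(top_k.shape[0], k, p2.shape[-1])  # (N1, k, 3)
    near_points = neighbors.mean(dim=1)                      # (N1, 3)
    mag = near_points - p1
    return near_points, mag
```

Both variants should match the loop version up to floating-point rounding. For large clouds the (N1, N2, 3) broadcast intermediate can dominate memory; `torch.cdist(p1, p2)` computes the (N1, N2) Euclidean distance matrix directly, and since squaring is monotonic, `topk` on it selects the same neighbors.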