How to efficiently perform averaging across predefined groups in a tensor

You could use scatter_add and a small hack to get the unique counts of your indices:

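import torch

x = torch.arange(1, 7, dtype=torch.float)   # values: [1, 2, 3, 4, 5, 6]
idx = torch.tensor([0, 1, 0, 1, 2, 2])      # group index of each value
idx_unique = idx.unique(sorted=True)
# The hack: count each group's members by comparing idx against every unique index
idx_unique_count = torch.stack([(idx == idx_u).sum() for idx_u in idx_unique])
# Sum the values of each group (assumes the indices are 0..K-1 with no gaps)
res = torch.zeros(len(idx_unique)).scatter_add(0, idx, x)
res /= idx_unique_count.float()             # sums -> means: [2.0, 3.0, 5.5]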
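
If the Python loop over the unique indices becomes a bottleneck, the counts could probably be computed in one call with torch.bincount instead. This is only a sketch, assuming idx holds non-negative integer group ids; group_mean_bincount is a made-up helper name, not part of PyTorch:

```python
import torch

def group_mean_bincount(x, idx):
    # Hypothetical helper: assumes idx contains non-negative integer group ids.
    num_groups = int(idx.max()) + 1
    # Per-group sums, accumulated into a zero tensor with x's dtype and device.
    sums = x.new_zeros(num_groups).scatter_add(0, idx, x)
    # Per-group counts without a Python loop.
    counts = torch.bincount(idx, minlength=num_groups)
    # clamp avoids a division by zero for group ids that never occur.
    return sums / counts.clamp(min=1).to(x.dtype)

x = torch.arange(1, 7, dtype=torch.float)
idx = torch.tensor([0, 1, 0, 1, 2, 2])
res = group_mean_bincount(x, idx)
```

Group ids that never occur in idx come out as 0 here rather than NaN, which may or may not be what you want.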

I think it’s time to add a return_counts option to torch.unique.
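
For what it’s worth, on a PyTorch version where torch.unique does accept return_counts (current releases do), the hack should no longer be needed. A rough sketch, also using return_inverse so that the group labels do not have to be 0..K-1; the label values below are made up for illustration:

```python
import torch

x = torch.arange(1, 7, dtype=torch.float)
idx = torch.tensor([10, 42, 10, 42, 7, 7])  # arbitrary, non-contiguous group labels

# groups:  sorted unique labels
# inverse: position of each element's label within groups
# counts:  number of elements per label
groups, inverse, counts = torch.unique(
    idx, sorted=True, return_inverse=True, return_counts=True
)

# Sum per group using the dense inverse indices, then divide by the group sizes.
sums = torch.zeros(len(groups)).scatter_add(0, inverse, x)
means = sums / counts.float()
```

Here means[i] is the average of the values whose label is groups[i].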
