Constructing a tensor by taking the mean over lists of indices of another tensor

`scatter_add_` should work once you build the necessary index tensor:

import torch


def group_mean(x, groups):
    """Average slices of ``x`` along dim 1 over each list of indices in ``groups``.

    Args:
        x: tensor of shape (batch, seq, features).
        groups: sequence of index lists; ``groups[g]`` holds the dim-1
            positions of ``x`` whose mean becomes output slot ``g``. The lists
            may be unordered, non-contiguous, overlapping, or leave some
            positions unused.

    Returns:
        Tensor of shape (batch, len(groups), features) with ``x``'s dtype and
        device. An empty group yields NaN (0 / 0), the same as calling
        ``mean`` over an empty slice.
    """
    # Flatten the groups: `positions` lists every source index along dim 1,
    # and `group_ids` says which output slot each of those positions feeds.
    # Gathering x[:, positions] before scattering (instead of scattering x
    # directly) is what lets groups be arbitrary rather than consecutive runs.
    positions = torch.tensor(
        [p for group in groups for p in group], dtype=torch.long, device=x.device
    )
    counts = torch.tensor(
        [len(group) for group in groups], dtype=torch.long, device=x.device
    )
    group_ids = torch.repeat_interleave(
        torch.arange(len(groups), device=x.device), counts
    )

    src = x[:, positions]  # (batch, total_indices, features)
    idx = group_ids.view(1, -1, 1).expand_as(src)

    # Size the output from len(groups), not idx.max() + 1, so that empty
    # groups (including an empty `groups`) still get an output slot.
    out = torch.zeros(
        x.size(0), len(groups), x.size(2), dtype=x.dtype, device=x.device
    )
    out.scatter_add_(1, idx, src)
    return out / counts.to(x.dtype)[None, :, None]


x = torch.randn(1, 6, 768)
I = [[0, 1], [2], [3, 4, 5]]

res = group_mean(x, I)

reference = torch.stack([x[:, i].mean(1) for i in I], dim=1)

print((res - reference).abs().max())
# > tensor(0.)  (at most float rounding error)
1 Like