Constructing a tensor by taking mean over index list of another tensor

I’m looking for a faster way to do a specific operation (taking the mean over specified indices) on a given tensor. To be more specific, let’s say I have a tensor A of shape [B, N, D], with B being the batch size (1 in my case), N being the number of nodes, and D being the feature dimension. I also have a list I, which is a list of lists, each containing index values along the node dimension. For instance, A could be of shape [1, 2048, 768] and I could be I = [ [0, 1], [2], [3, 4, 5], ... ].

My goal is to create a tensor Z whose i-th entry along dim 1 is the mean of the rows of A selected by the i-th index list in I. For now, I’m doing it as below, but it takes a long time to produce Z:

# lists are not hashable, so each index list is joined into a string key such as '3-4-5'
idx_mapping = {"-".join(map(str, idx)): pos for pos, idx in enumerate(I)}

Z = torch.zeros(A.size(0), len(I), A.size(2))

for k, v in idx_mapping.items():
        
    # indices (List) which we want to take mean from tensor A
    mean_indices = k
    
    # mapped idx in the resulting vector Z
    embedded_into = v 
    
    # do mean on mean_indices of tensor A
    token_embed = torch.index_select(A, 1, torch.tensor([int(s) for s in mean_indices.split("-")])).mean(dim=1, keepdim=True)
        
    # replace the calculated mean in embedded_into's index of tensor Z
    Z = torch.cat((Z[:, :embedded_into, :], token_embed, Z[:, embedded_into+1: , :]), dim=1)

For reference, idx_mapping is a dictionary that maps each group of indices in tensor A to its position in the resulting tensor Z. Its keys are built from the lists in I (indices joined with "-") and its values are incremented one by one. For instance, idx_mapping can be: {'0-1': 0, '2': 1, '3-4-5': 2, ...}

Now the question is: what is the best (and ideally the quickest) way to do this operation in PyTorch?

scatter_add_ should work after creating the necessary indices:

x = torch.randn(1, 6, 768)
I = [[0, 1], [2], [3, 4, 5]]

tmp = torch.arange(len(I))
idx = torch.repeat_interleave(tmp, torch.tensor([len(a) for a in I]))
idx = idx.unsqueeze(0).unsqueeze(2).expand_as(x)
res = torch.zeros(x.size(0), idx.max()+1, x.size(2))
res.scatter_add_(1, idx, x)
res = res / torch.tensor([len(a) for a in I])[None, :, None]

reference = torch.stack([x[:, i].mean(1) for i in I], dim=1)

print((res - reference).abs().max())
# > tensor(0.)
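
One thing to note about the snippet above: it assigns position j of x to a group purely by order, so it assumes the lists in I cover the nodes 0..N-1 contiguously and exactly once. If that might not hold (indices out of order, repeated, or some nodes left out), a minimal sketch along the following lines could be used instead. Here group_mean is a hypothetical helper name, not a PyTorch function, and it assumes no list in I is empty (an empty list would divide by zero).

```python
import torch

def group_mean(x, groups):
    """Mean of x[:, g] over each index list g in groups; x has shape [B, N, D]."""
    # Flattened source positions in x, e.g. [[0, 1], [2]] -> [0, 1, 2]
    src = torch.tensor([i for g in groups for i in g], dtype=torch.long, device=x.device)
    # Number of indices per group (assumed non-zero for every group)
    counts = torch.tensor([len(g) for g in groups], device=x.device)
    # Output slot for each source position, e.g. [0, 0, 1]
    dst = torch.repeat_interleave(torch.arange(len(groups), device=x.device), counts)
    out = torch.zeros(x.size(0), len(groups), x.size(2), dtype=x.dtype, device=x.device)
    # Gather the selected rows, then sum them into their output slots along dim 1
    out.index_add_(1, dst, x.index_select(1, src))
    # Turn the sums into means
    return out / counts.to(x.dtype)[None, :, None]

x = torch.randn(1, 6, 768)
I = [[3, 4, 5], [0], [1, 2, 0]]  # out of order, index 0 used twice
Z = group_mean(x, I)
print(Z.shape)
```

Compared with scatter_add_, index_add_ takes a 1-D index, so there is no need to expand the index to the full shape of x; the gather via index_select is what makes arbitrary index lists possible.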