How to implement this (essentially calculating the average of vectors)

import torch

batch_size, d = 8, 128
n, N = 100, 10
feature = torch.rand(batch_size, n, d)
belong = torch.randint(0, N, (batch_size, n))
output = torch.zeros(batch_size, N, d)
for i in range(batch_size):
    for j in range(N):
        output[i][j] = feature[i][belong[i] == j].mean(0)

with a high-performance, CUDA-friendly, and autograd-friendly PyTorch function?
Grateful for your help.
P.S. I tried torch.index_add and it performs badly on CUDA.

The core issue is that feature[i][belong[i] == j] has a different
shape for different values of i and j. Therefore, if you try to build a
"tensor" whose "slices" are given by the above expression, you will
end up with a "ragged tensor" (that is, a tensor whose slices have
differing shapes), and PyTorch doesn't support ragged tensors.
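
To see the raggedness concretely, here is a toy example (the values are
illustrative):

import torch

belong = torch.tensor([0, 1, 1, 0, 1])
feature = torch.rand(5, 3)
# the two boolean-indexed slices have different numbers of rows
print(feature[belong == 0].shape)   # torch.Size([2, 3])
print(feature[belong == 1].shape)   # torch.Size([3, 3])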

Once you use such boolean "indexing" (the cause of the differing
shapes), there is no way for you to complete your calculation using
pure tensor operations.

Instead, you should convert the boolean expression belong[i] == j
to a 0 / 1 numerical mask and use element-wise multiplication to mask
the desired elements of feature[i]. (The boolean tensor will be
automatically cast to a numerical tensor in the context we use it in, so
we don't have to do it explicitly.)
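
Concretely, the mask for all groups at once can be built by broadcasting
belong against torch.arange(N), just as in the full script further down:

import torch

batch_size, n, N = 8, 100, 10
belong = torch.randint(0, N, (batch_size, n))
# compare each label with all N group indices at once via broadcasting
mask = belong.unsqueeze(-1) == torch.arange(N)
print(mask.shape, mask.dtype)   # torch.Size([8, 100, 10]) torch.bool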

Also, because you are calculating mean() rather than sum(), it proves
convenient to use a mask whose entries are 0 or 1 / n_true (where n_true
is the number of True entries for that group), rather than a 0 / 1 mask.
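
That is, folding the division by the group size into the mask turns the
subsequent sum into a mean. A minimal self-contained sketch of just this
step:

import torch

batch_size, n, N = 8, 100, 10
belong = torch.randint(0, N, (batch_size, n))
mask = (belong.unsqueeze(-1) == torch.arange(N)).float()
# mask.sum(1) counts how many items of each batch element fall in each
# group; dividing by it makes the later sum over n compute a mean
# (an empty group gives 0 / 0 = nan, matching mean() of an empty slice)
mask = mask / mask.sum(1, keepdim=True)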

Lastly, there is no complication in using tensor operations to perform
the calculation on all batch elements (all slices along the batch_size
dimension) at the same time.

Here is an illustrative script:

import torch

print(torch.__version__)

_ = torch.manual_seed(2021)

batch_size, d = 8, 128
n, N = 100, 10
feature = torch.rand(batch_size, n, d)
belong = torch.randint(0, N, (batch_size, n))

# boolean "indexing" with loop
output = torch.zeros(batch_size, N, d)
for i in range(batch_size):
    for j in range(N):
        output[i][j] = feature[i][belong[i] == j].mean(0)

# loop-free tensor method
mask = belong.unsqueeze(-1) == torch.arange(N)
mask = mask / mask.sum(1).unsqueeze(1)   # to compute mean() rather than sum()
# contract over the n dimension: (b, n, N) x (b, n, d) -> (b, N, d)
outputB = torch.einsum('ijk, ijl -> ikl', mask, feature)

print('outputB.allclose(output) =', outputB.allclose(output))
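
As an aside, regarding the P.S. about torch.index_add: the same grouped
mean can also be written as a single scatter by folding the batch index
into the group index. This is only a sketch for comparison (the tensor
names are mine), and, as noted in the question, such scatter-based
approaches may well be slower on CUDA than the einsum version:

import torch

batch_size, d = 8, 128
n, N = 100, 10
feature = torch.rand(batch_size, n, d)
belong = torch.randint(0, N, (batch_size, n))

# fold the batch index into the group index so one index_add_ covers all batches
flat_idx = (belong + N * torch.arange(batch_size).unsqueeze(1)).flatten()
sums = torch.zeros(batch_size * N, d).index_add_(0, flat_idx, feature.reshape(-1, d))
counts = torch.zeros(batch_size * N).index_add_(0, flat_idx, torch.ones(batch_size * n))
outputC = (sums / counts.unsqueeze(-1)).reshape(batch_size, N, d)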