I have a tensor of size `(n1, n2, n3)` and a list `lst` of size `n2` with integer values in `[0, 1, ..., n2-1]` as partition (or cluster) assignments. What I want is an output of size `(n1, len(set(lst)), n3)` where each slice along the second dimension is the mean of the slices that belong to the same partition (cluster).

For instance, imagine we have a tensor `a` of size `(50, 5, 16)` and `lst = [0, 1, 1, 1, 2]`. I want an output `b` of size `(50, 3, 16)` where `b[:, 0, :] = a[:, 0, :]`, `b[:, 1, :] = mean of (a[:, 1, :], a[:, 2, :], a[:, 3, :])`, and `b[:, 2, :] = a[:, 4, :]`.
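To make the desired behaviour concrete, here is a naive reference sketch that loops over the partitions. The helper name `partition_mean_naive` is made up for illustration, and I am assuming the output partitions should be ordered by sorted label. It is only meant to pin down the expected result, not to be the efficient version.

```python
import torch

def partition_mean_naive(a, lst):
    # Hypothetical helper: averages slices of `a` along dim 1 per partition label.
    labels = torch.as_tensor(lst)
    groups = sorted(set(lst))  # assumption: output order follows sorted labels
    out = []
    for g in groups:
        mask = labels == g                      # boolean mask over the n2 axis
        out.append(a[:, mask, :].mean(dim=1))   # (n1, n3) mean of this partition
    return torch.stack(out, dim=1)              # (n1, len(set(lst)), n3)

a = torch.randn(50, 5, 16)
lst = [0, 1, 1, 1, 2]
b = partition_mean_naive(a, lst)
print(b.shape)
```

This does one masked mean per partition, so the Python loop grows with the number of distinct labels.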

What would be an efficient way of implementing this in PyTorch? I appreciate your time.