I have a tensor of size (n1, n2, n3) and a list lst of length n2 whose integer values in [0, 1, ..., n2-1] are partition (cluster) assignments. I want an output of size (n1, len(set(lst)), n3) in which each slice along the second dimension is the mean of the input slices that belong to the same partition (cluster). For instance, suppose tensor a has size (50, 5, 16) and lst = [0, 1, 1, 1, 2]. I want an output b of size (50, 3, 16) where b[:, 0, :] = a[:, 0, :], b[:, 1, :] = mean of (a[:, 1, :], a[:, 2, :], a[:, 3, :]) and b[:, 2, :] = a[:, 4, :].
What would be an efficient way of implementing this in PyTorch? I appreciate your time.
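A minimal sketch of one vectorized approach, assuming a is a floating-point tensor and lst is a plain Python list of non-negative integers (partition_mean is a hypothetical helper name, not a PyTorch API). It relabels the cluster ids with torch.unique so they are contiguous, sums the slices of each cluster with Tensor.index_add_ along dim 1, then divides by the per-cluster counts from torch.bincount. Output clusters come out in sorted id order, which matches the example above.

```python
import torch

def partition_mean(a: torch.Tensor, lst) -> torch.Tensor:
    """Average slices of a (n1, n2, n3) tensor along dim 1 by cluster id.

    Hypothetical helper; assumes a is floating point and len(lst) == a.shape[1].
    """
    idx = torch.as_tensor(lst, dtype=torch.long, device=a.device)
    # Map arbitrary cluster ids to contiguous 0..k-1 (sorted order of the ids).
    uniq, inv = torch.unique(idx, return_inverse=True)
    k = uniq.numel()
    n1, _, n3 = a.shape
    # Accumulate the sum of every slice a[:, j, :] into output slot inv[j].
    sums = a.new_zeros(n1, k, n3)
    sums.index_add_(1, inv, a)
    # Number of input slices in each cluster, broadcast over dims 0 and 2.
    counts = torch.bincount(inv, minlength=k).to(a.dtype)
    return sums / counts.view(1, k, 1)

# Usage with the shapes from the question:
a = torch.randn(50, 5, 16)
b = partition_mean(a, [0, 1, 1, 1, 2])
print(b.shape)  # torch.Size([50, 3, 16])
```

Because the work is a single scatter-add plus one division, it avoids a Python loop over clusters and runs on whatever device a lives on.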