I have a tensor of size (n1, n2, n3) and a list lst of length n2 whose integer values in [0, 1, ..., n2-1] are partition (or cluster) assignments along the second dimension. What I want is an output of size (n1, len(set(lst)), n3) in which each entry along the second dimension is the mean of the slices that belong to the same partition (cluster).

For instance, suppose we have a tensor a of size (50, 5, 16) and lst = [0, 1, 1, 1, 2]. I want an output b of size (50, 3, 16) where:
b[:, 0, :] = a[:, 0, :]
b[:, 1, :] = mean of (a[:, 1, :], a[:, 2, :], a[:, 3, :])
b[:, 2, :] = a[:, 4, :]
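To pin down the intended semantics, here is a minimal, non-optimized reference sketch that loops over cluster labels. The function name cluster_mean_loop is hypothetical, and it assumes the labels are exactly 0..K-1, so that output index k corresponds to label k, as in the example.

```python
import torch

def cluster_mean_loop(a: torch.Tensor, lst) -> torch.Tensor:
    # Hypothetical reference helper: one boolean-mask slice and mean per cluster.
    # Assumes labels are contiguous 0..K-1, so output index k == label k.
    labels = torch.as_tensor(lst, device=a.device)
    num_clusters = len(set(lst))
    # a[:, mask, :] has shape (n1, cluster_size, n3); mean over dim 1 gives (n1, n3)
    parts = [a[:, labels == k, :].mean(dim=1) for k in range(num_clusters)]
    return torch.stack(parts, dim=1)  # shape (n1, K, n3)

a = torch.randn(50, 5, 16)
b = cluster_mean_loop(a, [0, 1, 1, 1, 2])
print(b.shape)  # torch.Size([50, 3, 16])
```

This issues one indexing and mean operation per cluster, so it gets slow when the number of clusters is large.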
What would be an efficient way of implementing this in PyTorch? I appreciate your time.
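One possible vectorized approach, sketched under the assumption that a is a floating-point tensor: sum the slices of each cluster with a single Tensor.index_add_ call along dim 1, then divide by the cluster sizes from torch.bincount. The helper name cluster_mean is hypothetical. torch.unique(..., return_inverse=True) relabels arbitrary integer labels to 0..K-1 in ascending label order, so the labels do not need to be contiguous.

```python
import torch

def cluster_mean(a: torch.Tensor, lst) -> torch.Tensor:
    # Hypothetical helper: per-cluster mean along dim 1 without a Python loop.
    labels = torch.as_tensor(lst, device=a.device)
    # inv[i] is the 0-based rank of lst[i] among the sorted unique labels
    uniq, inv = torch.unique(labels, return_inverse=True)
    k = uniq.numel()
    n1, _, n3 = a.shape
    # Scatter-add every slice a[:, i, :] into sums[:, inv[i], :]
    sums = torch.zeros(n1, k, n3, dtype=a.dtype, device=a.device)
    sums.index_add_(1, inv, a)
    # Number of slices per cluster (all > 0, since every label occurs in lst)
    counts = torch.bincount(inv, minlength=k).to(a.dtype)
    # Broadcast the division over dims 0 and 2
    return sums / counts.view(1, k, 1)

a = torch.randn(50, 5, 16)
b = cluster_mean(a, [0, 1, 1, 1, 2])
print(b.shape)  # torch.Size([50, 3, 16])
```

The work is a fixed number of tensor operations regardless of how many clusters there are. Note that index_add_ on CUDA may be nondeterministic for floating-point inputs, so results can differ in the last bits between runs.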