I’m looking for a faster way to perform a specific operation (taking the mean over specified indices) on a given tensor. To be more specific, let’s say I have a tensor A with shape [B, N, D], where B is the batch size (1 in my case), N is the number of nodes, and D is the feature dimension. I also have a list I, which is a list of lists, each containing index values. For instance, A could be of shape [1, 2048, 768] and I could be I = [[0, 1], [2], [3, 4, 5], ...].
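For concreteness, a minimal setup matching the shapes above (the values are random, and only the first few index groups are shown):

```python
import torch

# Example tensor and index groups matching the description above.
A = torch.randn(1, 2048, 768)      # [B, N, D] with B = 1
I = [[0, 1], [2], [3, 4, 5]]       # each inner list is one group of node indices
```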
My goal is to create a resulting tensor Z that aggregates the mean values from A based on the indices provided in I. For now, I’m doing it as below, but it is very time-consuming to produce Z:
import torch

idx_mapping = {"-".join(str(i) for i in idxs): pos for pos, idxs in enumerate(I)}
Z = torch.zeros(A.size(0), len(I), A.size(2))
for k, v in idx_mapping.items():
    # indices (list) over which we want to take the mean in tensor A
    mean_indices = torch.tensor([int(i) for i in k.split("-")])
    # mapped index in the resulting tensor Z
    embedded_into = v
    # mean over mean_indices along dim 1 of A; result has shape [B, 1, D]
    token_embed = torch.index_select(A, 1, mean_indices).mean(dim=1, keepdim=True)
    # place the calculated mean at index embedded_into of tensor Z
    Z = torch.cat((Z[:, :embedded_into, :], token_embed, Z[:, embedded_into + 1:, :]), dim=1)
For reference, idx_mapping is a dictionary that maps groups of indices of tensor A to positions in the resulting tensor Z. Its keys are built from the inner lists of I, and its values are incremented one by one. For the example I above, idx_mapping would be: {'0-1': 0, '2': 1, '3-4-5': 2, ...}
Now the question is: what is the best (and possibly quickest) way to do this operation in PyTorch?
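For comparison, here is a sketch of a vectorized alternative I would measure against the loop above: flatten all groups into one index vector, gather the corresponding rows of A in a single call, scatter-sum them with torch.index_add_, and divide by the group sizes. (This is only a sketch; grouped_mean is my own helper name, and it assumes every group is non-empty.)

```python
import torch

def grouped_mean(A, I):
    # Flatten all group indices into one vector, and record which
    # group (i.e., which row of Z) each flattened index belongs to.
    flat_idx = torch.tensor([i for group in I for i in group])
    group_id = torch.tensor([g for g, group in enumerate(I) for _ in group])
    # Scatter-sum the gathered rows of A into Z along dim 1.
    Z = torch.zeros(A.size(0), len(I), A.size(2), dtype=A.dtype)
    Z.index_add_(1, group_id, A[:, flat_idx, :])
    # Divide each row by its group size to turn the sums into means.
    counts = torch.tensor([len(group) for group in I], dtype=A.dtype)
    return Z / counts.view(1, -1, 1)
```

This replaces the Python loop and the repeated torch.cat calls with one gather and one scatter-add, which should scale much better for a [1, 2048, 768] input.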