I want to process groups of rows in a 2D tensor and compute simple statistics (e.g. mean or max) for each group. I know how to do it one group at a time in a Python loop, but that is extremely slow for big tensors.

Is there a way to use fancy indexing or something similar to achieve this goal?

Example of what I am trying to do:

import torch

X = torch.randn(100, 10)
list_of_indexes = [
    [1, 4, 7, 10],
    [10, 60, 40, 10, 5, 3, 0],
    [6, 4],
]
# Slow baseline: one fancy-indexing call and one reduction per group
for indexes in list_of_indexes:
    print(X[indexes].mean(dim=0))

# Same row selection with index_select (rows only, no reduction applied)
for indexes in list_of_indexes:
    print(torch.index_select(X, 0, torch.LongTensor(indexes)))
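One possible way to avoid the per-group loop for the mean is to gather the rows of all groups in a single indexing call and then sum them per group with `index_add_`. The sketch below is only an illustration: the helper name `grouped_mean` is made up here, and it assumes every group is non-empty (an empty group would divide by zero).

```python
import torch

def grouped_mean(X, list_of_indexes):
    # Hypothetical helper, not a PyTorch API.
    device = X.device
    num_groups = len(list_of_indexes)
    # Size of each group, e.g. [4, 7, 2]
    lengths = torch.tensor([len(ix) for ix in list_of_indexes], device=device)
    # All row indices concatenated into one 1-D tensor
    flat = torch.tensor(
        [i for ix in list_of_indexes for i in ix],
        dtype=torch.long, device=device,
    )
    # Group id of every gathered row, e.g. [0, 0, 0, 0, 1, 1, ...]
    group_ids = torch.repeat_interleave(
        torch.arange(num_groups, device=device), lengths
    )
    # One gather for all groups, then a per-group sum
    sums = torch.zeros(num_groups, X.shape[1], dtype=X.dtype, device=device)
    sums.index_add_(0, group_ids, X[flat])
    # Assumes no group is empty
    return sums / lengths.unsqueeze(1).to(X.dtype)

X = torch.randn(100, 10)
list_of_indexes = [
    [1, 4, 7, 10],
    [10, 60, 40, 10, 5, 3, 0],
    [6, 4],
]
means = grouped_mean(X, list_of_indexes)  # shape (3, 10), one row per group
```

Repeated indices within a group (such as the two 10s in the second group) are counted twice, which matches what `X[indexes].mean(dim=0)` does in the loop version.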

Here is my toy example:

In [1]: import torch
In [2]: import numpy as np
In [3]: x = np.random.normal(size=(10, 4)).astype(np.float32)
In [4]: x = torch.from_numpy(x)
In [5]: torch.index_select(x, 0, torch.LongTensor([1, 3]))
Out[5]:
tensor([[-0.5863, -1.2383, -0.8615,  0.1059],
        [-0.2208,  0.7738,  1.1371, -0.9917]])
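The `index_select` call above handles a single group. As a rough sketch, the same call can gather the rows of every group at once, and a per-group max can then be taken with `scatter_reduce_`. This assumes a PyTorch version that provides `Tensor.scatter_reduce_` with `reduce="amax"` (1.12 or later, as far as I know) and a floating-point tensor; `grouped_max` is a hypothetical helper name, not part of PyTorch.

```python
import torch

def grouped_max(x, list_of_indexes):
    # Hypothetical helper, not a PyTorch API; assumes x has a floating dtype.
    device = x.device
    num_groups = len(list_of_indexes)
    lengths = torch.tensor([len(ix) for ix in list_of_indexes], device=device)
    flat = torch.tensor(
        [i for ix in list_of_indexes for i in ix],
        dtype=torch.long, device=device,
    )
    # Single index_select for the rows of all groups
    rows = torch.index_select(x, 0, flat)
    # Group id of every gathered row, broadcast across the columns
    group_ids = torch.repeat_interleave(
        torch.arange(num_groups, device=device), lengths
    )
    index = group_ids.unsqueeze(1).expand_as(rows)
    # Start from -inf so any real value wins the max
    out = torch.full(
        (num_groups, x.shape[1]), float("-inf"), dtype=x.dtype, device=device
    )
    out.scatter_reduce_(0, index, rows, reduce="amax", include_self=True)
    return out

x = torch.randn(10, 4)
maxes = grouped_max(x, [[1, 3], [0, 2, 5, 9], [7]])  # shape (3, 4)
```

An empty group would keep its `-inf` row in this sketch, so it may be worth filtering those out beforehand.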