How to efficiently slice a multi-dimensional tensor

For example, there is a 3D tensor with dimensions (10, 5, 100), i.e. (TimeStep, BS, HiddenDimension). I want to slice the tensor along the TimeStep dimension and get a tensor of shape (8, 5, 100) as output. The problem is that the selected indices are different for each sample in the batch; they are given by a mask matrix of shape (8, 5). One intuitive solution is as follows:

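import torch

# mask_matrix[:, i] holds the 8 time steps to keep for batch sample i
tmp = torch.zeros(8, 5, 100, dtype=input_tensor.dtype, device=input_tensor.device)
for i in range(5):
    tmp[:, i, :] = input_tensor[mask_matrix[:, i], i, :]
input_tensor = tmp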

The operation will be done in a Module's forward function, and I suspect the Python loop may slow down the computation. Any suggestions for a more efficient solution?
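One possible vectorized replacement for the loop, assuming mask_matrix is an int64 tensor of time-step indices with shape (8, 5), is advanced indexing with a second index tensor, torch.arange(5), for the batch dimension. The two index tensors broadcast to (8, 5) and the hidden dimension is carried along, so the result has shape (8, 5, 100). The random data below is made up only to check that the result matches the loop.

```python
import torch

T, B, H = 10, 5, 100  # TimeStep, batch size, hidden dimension
K = 8                 # number of time steps kept per sample

input_tensor = torch.randn(T, B, H)
# Assumption: mask_matrix stores time-step indices (int64), shape (K, B).
mask_matrix = torch.randint(0, T, (K, B))

# Loop version from the question, kept for comparison.
tmp = torch.zeros(K, B, H, dtype=input_tensor.dtype)
for i in range(B):
    tmp[:, i, :] = input_tensor[mask_matrix[:, i], i, :]

# Vectorized version: mask_matrix (K, B) and batch_idx (B,) broadcast to (K, B),
# so out[k, b, :] = input_tensor[mask_matrix[k, b], b, :].
batch_idx = torch.arange(B, device=input_tensor.device)
out = input_tensor[mask_matrix, batch_idx]

print(out.shape)              # torch.Size([8, 5, 100])
print(torch.equal(out, tmp))  # True
```

This is a single indexing call, so it avoids the per-sample Python loop, and gradients still flow back into input_tensor.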


Is mask_matrix really a mask, i.e. does it contain zeros and ones, or does it store indices?


@ptrblck Sorry for the ambiguity, mask_matrix stores indices.
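Since mask_matrix stores indices, torch.gather along dim 0 is another option. Here is a minimal sketch; the helper name select_timesteps is hypothetical. gather requires the index tensor to have the same number of dimensions as the input, so the (8, 5) indices are unsqueezed and expanded over the hidden dimension first.

```python
import torch


def select_timesteps(input_tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick different time steps for each batch sample.

    input_tensor: (T, B, H)
    index:        (K, B) int64 time-step indices
    returns:      (K, B, H) with out[k, b, :] = input_tensor[index[k, b], b, :]
    """
    H = input_tensor.size(2)
    # Broadcast the indices over the hidden dimension: (K, B) -> (K, B, H).
    expanded = index.unsqueeze(-1).expand(-1, -1, H)
    # Along dim 0: out[k][b][h] = input_tensor[expanded[k][b][h]][b][h]
    return torch.gather(input_tensor, 0, expanded)


# Usage with made-up data of the shapes from the question.
x = torch.randn(10, 5, 100)
idx = torch.randint(0, 10, (8, 5))
out = select_timesteps(x, idx)
print(out.shape)  # torch.Size([8, 5, 100])
```

gather is differentiable with respect to the input, so it can be used inside forward; functionally it gives the same result as indexing with input_tensor[idx, torch.arange(5)].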

Bump. Anyone got an answer?


Did you get an answer?