Index a 3D tensor over multiple rows for each batch element

Hello all,

I have a 3D tensor called x of size (batch_size, sequence_len, features).
For each batch element I want to take 3 rows, whose indices can differ from one batch element to the next.
With a plain for loop over the batch dimension it looks like this:

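```python
import torch
# x: (batch_size, sequence_len, features); rows: LongTensor of shape (batch_size, 3)
res = x.new_empty(x.size(0), 3, x.size(2))  # 3 rows per batch element, same dtype/device as x
for i in range(x.size(0)):
    res[i] = x[i, rows[i]]  # select the requested rows of batch element i
```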

where rows is a LongTensor of shape (batch_size, 3) whose row i holds the indices of the 3 rows to take from batch element i.

Is there a way to get rid of the for loop? I have tried many approaches but cannot make it work for a 3D tensor.

Thank you!

For anyone interested, I think I found a solution after a lot of trial and error:

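the trick is to index with one LongTensor per indexed dimension (one for the batch dimension, one for the row dimension) and give both index tensors the same size batch_size along dimension 0, so that they broadcast against each other. Entry (b, j) of the broadcast indices then picks row rows[b, j] of batch element b, and the untouched features dimension is kept as the last dimension of the result.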
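A small check of that rule, as a sketch with made-up sizes (batch_size=2, sequence_len=5, features=3) and hand-picked rows that differ per batch element: the (2, 1) batch index and the (2, 2) row index broadcast to (2, 2), and each result entry matches the row picked by hand.

```python
import torch

batch_size, seq_len, features = 2, 5, 3
x = torch.arange(batch_size * seq_len * features).reshape(batch_size, seq_len, features)

# One index tensor per indexed dimension: dim 0 (batch) and dim 1 (rows).
batch_idx = torch.arange(batch_size).unsqueeze(1)  # shape (2, 1)
rows = torch.tensor([[0, 4], [3, 1]])              # shape (2, 2), different rows per element

# The two index tensors broadcast to (2, 2); the unindexed features dim is appended.
res = x[batch_idx, rows]
print(res.shape)  # torch.Size([2, 2, 3])

# Element-wise meaning: res[b, j] is row rows[b, j] of batch element b.
for b in range(batch_size):
    for j in range(rows.size(1)):
        assert torch.equal(res[b, j], x[b, rows[b, j]])
```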

The code is very simple:

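```python
import torch

batch_size = x.size(0)  # x: (batch_size, sequence_len, features)
# select the first and second row for each batch element -> shape (batch_size, 2)
rows = torch.LongTensor([[0, 1]] * batch_size)
# batch index as a column vector -> shape (batch_size, 1)
idx = (torch.arange(batch_size).unsqueeze(1), rows)
# index the tensor; both index tensors broadcast to (batch_size, 2)
res = x[idx]  # shape (batch_size, 2, features)
```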
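To convince yourself the indexing trick matches the loop from the question, here is a sketch comparing both on random data; the sizes, the seed and k=3 are arbitrary choices for the check, and rows is drawn with torch.randint so every batch element gets its own indices.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, features, k = 4, 10, 6, 3
x = torch.randn(batch_size, seq_len, features)
# Illustrative per-element row indices, shape (batch_size, k)
rows = torch.randint(0, seq_len, (batch_size, k))

# Loop version from the question
res_loop = x.new_empty(batch_size, k, features)
for i in range(batch_size):
    res_loop[i] = x[i, rows[i]]

# Vectorized version: batch index as a column so it broadcasts against rows
res_vec = x[torch.arange(batch_size).unsqueeze(1), rows]

print(torch.equal(res_loop, res_vec))  # True
```

Both paths only copy values, so an exact comparison with torch.equal is appropriate here rather than a tolerance-based one.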

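The unsqueeze(1) turns torch.arange(batch_size) into a column tensor like [[0],[1],[2],...,[batch_size-1]] of shape (batch_size, 1), which broadcasts against rows so that every index in rows[b] is paired with batch index b.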
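The pairing can be made visible with torch.broadcast_tensors, which expands both index tensors to their common shape the same way indexing does; batch_size=3 is just an example value.

```python
import torch

batch_size = 3
batch_idx = torch.arange(batch_size).unsqueeze(1)  # [[0], [1], [2]], shape (3, 1)
rows = torch.tensor([[0, 1]] * batch_size)          # [[0, 1], [0, 1], [0, 1]], shape (3, 2)

# Expand both to the common shape (3, 2); (b[i, j], r[i, j]) is the (batch, row) pair picked.
b, r = torch.broadcast_tensors(batch_idx, rows)
print(b.tolist())  # [[0, 0], [1, 1], [2, 2]]
print(r.tolist())  # [[0, 1], [0, 1], [0, 1]]
```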
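Because row b of rows is only used for batch element b, filling rows with different indices in each row (instead of repeating [0, 1]) selects different rows for each batch element; rows just needs shape (batch_size, k) to take k rows per element.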
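A different loop-free route, not the one used in this thread, is torch.gather along dimension 1. The sketch below wraps it in a hypothetical helper select_rows (the name and signature are illustrative only); gather needs an index with the same number of dimensions as x, so each row index is repeated across the features dimension with expand.

```python
import torch


def select_rows(x: torch.Tensor, rows: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick rows[b, j] from x[b]; x: (B, L, F), rows: (B, K) -> (B, K, F)."""
    # Repeat each row index across the feature dimension: (B, K) -> (B, K, F).
    index = rows.unsqueeze(-1).expand(-1, -1, x.size(2))
    # out[b, j, f] = x[b, index[b, j, f], f] = x[b, rows[b, j], f]
    return torch.gather(x, 1, index)


x = torch.randn(2, 5, 4)
rows = torch.tensor([[4, 0, 2], [1, 1, 3]])
out = select_rows(x, rows)
print(out.shape)  # torch.Size([2, 3, 4])
print(torch.equal(out, x[torch.arange(2).unsqueeze(1), rows]))  # True
```

The result is the same as the advanced-indexing version; which one reads better is mostly a matter of taste.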