I have a 3D tensor x of size (batch_size, sequence_len, features).
For each batch element I want to take 3 rows, whose indices may differ from one batch element to the next.
With a classic for loop over the batch dimension it would look like:
res = torch.empty(x.size(0), 3, x.size(2))  # take 3 rows for each batch element
for i in range(x.size(0)):
    res[i] = x[i][rows[i]]  # select the appropriate rows
rows is a LongTensor containing the indices of the 3 rows to be taken for each batch element.
Is there a way to get rid of the for loop? I tried many ways but I am not able to make it work for a 3D tensor.
For all of you who are interested, I think I found a solution after much trial and error:
the trick is to use one torch.LongTensor index per dimension of the original tensor, all sharing the same 0-th dimension size, so that advanced indexing can broadcast them together.
The code is very simple:
# select first and second row for each batch element
rows = torch.LongTensor([[0, 1]] * batch_size)  # shape (batch_size, 2)
# create an index for each batch element
idx = torch.arange(batch_size).unsqueeze(1)     # shape (batch_size, 1)
# index the tensor
x[idx, rows]  # shape (batch_size, 2, features)
unsqueeze(1) turns the 1-D arange into a column vector [[0], [1], ..., [batch_size-1]] of shape (batch_size, 1), which broadcasting expands to match the shape of rows.
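A quick check of that shape (a minimal sketch, assuming batch_size = 3):

```python
import torch

batch_size = 3  # illustrative value
idx = torch.arange(batch_size).unsqueeze(1)  # column vector of batch indices
print(idx)
# tensor([[0],
#         [1],
#         [2]])
```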
By simply changing the [0,1] entries inside the rows tensor it is possible to get different rows for each batch element.