I have a 3D tensor called x of size (batch_size, sequence_len, features).
For each batch element I want to take 3 rows, whose indices may differ from one batch element to the next.
With a classic for loop over the batch dimension it would look like:
res = torch.empty(x.size(0),3,x.size(2)) # take 3 rows for each batch
for i in range(x.size(0)):
    res[i] = x[i][rows[i]]  # select the appropriate rows
rows is a LongTensor containing the indices of the 3 rows to be taken for each batch element.
Is there a way to get rid of the for loop? I tried many ways but I am not able to make it work for a 3D tensor.
For anyone interested, I think I found a solution after much trial and error.
The trick is to index with two LongTensors at once: one holding the batch index of each element and one holding the rows to take, shaped so that they broadcast against each other along the batch dimension.
The code is very simple:
# select first and second row for each batch element
rows = torch.LongTensor([[0, 1]] * batch_size)  # shape (batch_size, 2)
# create an index for each batch element
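idx = (torch.arange(batch_size).unsqueeze(1), rows)  # batch indices of shape (batch_size, 1), paired with rows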
# index the tensor
x[idx] # shape (batch_size, 2, features)
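unsqueeze(1) turns a 1-D tensor of batch indices into a column of shape (batch_size, 1), for example [[0], [1], [2]] when batch_size is 3, so that it broadcasts against rows.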
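As a small illustration of that broadcasting step, here is a minimal sketch. It assumes the batch indices are built with torch.arange(batch_size); the name batch_col and the value batch_size = 3 are illustrative choices, not part of the original code.

```python
import torch

batch_size = 3  # illustrative value

# 1-D batch indices: tensor([0, 1, 2]), shape (3,)
batch_idx = torch.arange(batch_size)

# unsqueeze(1) adds a trailing axis, giving a column of shape (3, 1)
batch_col = batch_idx.unsqueeze(1)

# rows has shape (batch_size, 2): two row indices per batch element
rows = torch.tensor([[0, 1]] * batch_size)

# Advanced indexing broadcasts both index tensors to a common shape (3, 2),
# so every row index gets paired with the batch element it belongs to
b, r = torch.broadcast_tensors(batch_col, rows)
print(batch_col.tolist())  # [[0], [1], [2]]
print(b.tolist())          # [[0, 0], [1, 1], [2, 2]]
```

Each (b[i, j], r[i, j]) pair then addresses one row of x, which is why the result keeps the batch dimension first.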
By simply changing the [0,1] entries inside the rows tensor it is possible to get different rows for each batch element.
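To check that this really matches the loop from the question when the rows differ per batch element, here is a hedged, self-contained sketch. The helper name select_rows, the tensor sizes, and the particular row indices are made up for illustration; the indexing itself is standard PyTorch advanced indexing.

```python
import torch


def select_rows(x: torch.Tensor, rows: torch.Tensor) -> torch.Tensor:
    """Pick rows[i] from x[i] for every batch element i, without a Python loop.

    x:    (batch_size, sequence_len, features)
    rows: (batch_size, k) integer tensor of row indices
    """
    # Column of batch indices, shape (batch_size, 1); broadcasts against rows
    batch_col = torch.arange(x.size(0), device=x.device).unsqueeze(1)
    return x[batch_col, rows]  # shape (batch_size, k, features)


torch.manual_seed(0)
x = torch.randn(4, 5, 6)  # batch_size=4, sequence_len=5, features=6
rows = torch.tensor([[0, 1, 2],
                     [4, 0, 3],
                     [2, 2, 1],
                     [3, 4, 0]])  # 3 different rows per batch element

res = select_rows(x, rows)

# Reference result: the explicit loop over the batch dimension
expected = torch.stack([x[i][rows[i]] for i in range(x.size(0))])

assert torch.equal(res, expected)
print(res.shape)  # torch.Size([4, 3, 6])
```

Passing the two index tensors directly as x[batch_col, rows] avoids indexing with a Python list of tensors, which is the safer form across PyTorch versions.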