I want to know the most efficient way to extract certain indices from multi-dimensional tensors. Suppose I have a batch of padded sequences of size batchsize*maxlength*inputsize, and after passing it through an RNN I get an output of size batchsize*maxlength*outputsize. If I have the true length of each sequence in a tensor seq_len, what is the fastest way to extract the corresponding elements (the output at each sequence's last valid time step), so as to get a batchsize*outputsize tensor?

(input is of size batchsize*maxlength*inputsize, while output is of size batchsize*maxlength*outputsize, and seq_len is of size batchsize. Finally, output_extracted is of size batchsize*outputsize)

So far, I have tried the following methods:

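  • Using gather:
    # seq_len holds lengths, so the last valid step is seq_len - 1 (seq_len must be a LongTensor)
    index = (seq_len - 1).view(-1, 1, 1).expand(output.size(0), 1, output.size(2))
    output_extracted = torch.gather(output, 1, index).squeeze(1)
  • Using index_select (the index must be a 1-D tensor, so each length is viewed as size 1):
    output_extracted = torch.cat([torch.index_select(output[i], 0, (seq_len[i] - 1).view(1)) for i in range(output.size(0))])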
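A minimal self-contained sketch of the two approaches above, useful for checking that they agree. The sizes and lengths are made-up values for illustration, random data stands in for the RNN output, and it assumes lengths are 1-based (so the last valid step is length - 1):

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batchsize, maxlength, outputsize = 4, 6, 3

# Stand-in for the RNN output: batchsize x maxlength x outputsize.
output = torch.randn(batchsize, maxlength, outputsize)
# True lengths (1 <= length <= maxlength), int64 so they can be used as indices.
seq_len = torch.tensor([6, 2, 4, 1])

# gather: a batchsize x 1 x outputsize index that points at step (length - 1)
# for every feature, then drop the singleton time dimension.
index = (seq_len - 1).view(-1, 1, 1).expand(batchsize, 1, outputsize)
out_gather = torch.gather(output, 1, index).squeeze(1)

# index_select: one call per sequence; each call needs a 1-D index tensor
# and returns a 1 x outputsize slice, so the slices are concatenated on dim 0.
out_select = torch.cat(
    [torch.index_select(output[i], 0, (seq_len[i] - 1).view(1)) for i in range(batchsize)]
)

print(out_gather.shape)                      # torch.Size([4, 3])
print(torch.equal(out_gather, out_select))   # True
```

The per-sequence loop is what makes the index_select version slower; the gather version does the whole batch in one call at the cost of building the expanded index.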

The first is faster than the second, which is slowed down by the Python list comprehension, but it requires some reshaping and expanding of the index. Is there another way to do this? Am I using index_select incorrectly?


IMO, gather and index_select are the main torch functions that can do this.
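index_select can also be used without the list comprehension. A sketch of one way to do it, assuming output is laid out batch first (batchsize x maxlength x outputsize) and lengths are 1-based: flatten the first two dimensions and compute a single row index per sequence. Sizes and lengths are made up for illustration.

```python
import torch

# Hypothetical sizes and lengths, for illustration only.
batchsize, maxlength, outputsize = 4, 6, 3
output = torch.randn(batchsize, maxlength, outputsize)
seq_len = torch.tensor([6, 2, 4, 1])

# Row b * maxlength + t of the flattened tensor holds step t of sequence b.
# reshape returns a view when output is contiguous, otherwise it copies.
flat = output.reshape(-1, outputsize)

# Row of the last valid step (length - 1) of each sequence.
rows = torch.arange(batchsize) * maxlength + (seq_len - 1)

# A single index_select call for the whole batch.
out_flat = torch.index_select(flat, 0, rows)

# Reference result built with an explicit Python loop, for checking.
expected = torch.stack([output[b, seq_len[b] - 1] for b in range(batchsize)])

print(out_flat.shape)                     # torch.Size([4, 3])
print(torch.equal(out_flat, expected))    # True
```

This trades the expanded gather index for a 1-D index of length batchsize, which is cheap to build.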

