Hi,
I would like to know the most efficient way of extracting certain indices from multi-dimensional tensors. Consider a batch of padded sequences of size batchsize*maxlength*inputsize; after passing it through an RNN, I end up with a tensor of size batchsize*maxlength*outputsize. If I have the true length of each sequence in a tensor seq_len, what would be the fastest way of extracting the corresponding elements, so as to get a batchsize*outputsize tensor?
(input is of size batchsize*maxlength*inputsize, while output is of size batchsize*maxlength*outputsize, and seq_len is of size batchsize. Finally, output_extracted is of size batchsize*outputsize.)
So far I have tried the following methods:
- Using gather:
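# Broadcast seq_len to shape (batchsize, 1, outputsize) so it can index dim 1
seq_len = seq_len.view(-1, 1, 1).expand(output.size(0), 1, output.size(2))
# gather returns (batchsize, 1, outputsize); squeeze dim 1 to get (batchsize, outputsize)
output_extracted = torch.gather(output, 1, seq_len).squeeze(1)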
- Using index_select:
output_extracted = torch.cat([torch.index_select(output[i], 0, seq_len[i]) for i in range(output.size(0))])
The first is faster than the second, since the second relies on a Python list comprehension, but it requires some reshaping and expanding of seq_len. Is there any other way to do this? Am I using index_select incorrectly?
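For concreteness, here is a minimal sketch that reproduces both approaches on dummy data and checks that they agree. The sizes, the random output tensor, and the seq_len values are made up for illustration. I am also assuming that the element wanted is the last valid time step, i.e. index seq_len - 1, since indexing with a length equal to maxlength would otherwise go out of bounds.

```python
import torch

# Hypothetical sizes, chosen only for illustration
batchsize, maxlength, outputsize = 4, 6, 3

# Stand-in for the RNN output: (batchsize, maxlength, outputsize)
output = torch.randn(batchsize, maxlength, outputsize)

# Hypothetical true lengths, each in [1, maxlength]
seq_len = torch.tensor([6, 2, 4, 1])

# Assumption: we want the last valid time step, which sits at index seq_len - 1
last = seq_len - 1

# Method 1: gather along the time dimension
idx = last.view(-1, 1, 1).expand(batchsize, 1, outputsize)
extracted_gather = torch.gather(output, 1, idx).squeeze(1)

# Method 2: index_select per sequence, then concatenate
# last[i:i + 1] keeps the index as a 1-D tensor of length 1
extracted_loop = torch.cat(
    [torch.index_select(output[i], 0, last[i:i + 1]) for i in range(batchsize)]
)

# Both results have shape (batchsize, outputsize) and hold the same values
assert extracted_gather.shape == (batchsize, outputsize)
assert torch.equal(extracted_gather, extracted_loop)
```

Both variants return the same tensor here, so any difference between them should come down to speed rather than results.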