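Thanks for posting this. I think I’ve found an easier way. For a unidirectional GRU, I believe the final hidden state of the last layer is the same as the output at the last time step. With padded, variable-length batches this only lines up with each sequence's last valid step if the input is packed.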
So you should be able to do:
outputs, hn = gru(inputs, h0)
print(hn[-1])
For the LSTM the equivalent code would be:
outputs, (hn, cn) = lstm(inputs, (h0, c0))
print(hn[-1])
I used your code to verify this.
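A minimal sketch of such a check, assuming a unidirectional, sequence-first GRU and a batch packed with pack_padded_sequence. The sizes, the seed, and the names packed, last_outputs and matches are made up for illustration; vlens holds the valid lengths as in the code further down.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Toy sizes, chosen only for illustration.
max_seq_len, batch_size, input_size, hidden_size = 5, 3, 4, 6
num_layers = 2
gru = nn.GRU(input_size, hidden_size, num_layers=num_layers)  # unidirectional, seq-first

inputs = torch.randn(max_seq_len, batch_size, input_size)
vlens = torch.tensor([5, 3, 2])  # valid lengths, sorted in decreasing order
h0 = torch.zeros(num_layers, batch_size, hidden_size)

# Pack so the GRU stops at each sequence's last valid step.
packed = pack_padded_sequence(inputs, vlens)
packed_outputs, hn = gru(packed, h0)
outputs, _ = pad_packed_sequence(packed_outputs, total_length=max_seq_len)

# Top-layer output at the last valid time step of every sequence.
last_outputs = outputs[vlens - 1, torch.arange(batch_size)]

# hn[-1] is the final hidden state of the top layer.
matches = torch.allclose(hn[-1], last_outputs, atol=1e-6)
print(matches)
```

Without packing, hn would instead be the state after the final padded step, so the comparison would only be expected to hold for sequences of full length.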
You can also express your code more compactly by using .view()
to add the singleton axes (outputs here has shape (max_seq_len, batch, hidden)):
masks = (vlens-1).view(1, -1, 1).expand(max_seq_len, outputs.size(1), outputs.size(2))
output = outputs.gather(0, masks)[0]
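To see what the gather call selects, here is a small sketch on a hand-built tensor instead of real RNN outputs. The toy sizes, the fill pattern 10*t + b, and the names expected and same are assumptions for illustration; the two lines computing masks and output are the ones above.

```python
import torch

max_seq_len, batch_size, hidden_size = 4, 3, 2

# outputs[t, b, :] is filled with 10*t + b so each entry is easy to recognize.
t = torch.arange(max_seq_len).view(-1, 1, 1)
b = torch.arange(batch_size).view(1, -1, 1)
outputs = (10 * t + b).expand(max_seq_len, batch_size, hidden_size).float()

vlens = torch.tensor([4, 2, 1])  # valid length of each sequence in the batch

# Index of the last valid step, broadcast to the shape of outputs.
masks = (vlens - 1).view(1, -1, 1).expand(max_seq_len, outputs.size(1), outputs.size(2))
output = outputs.gather(0, masks)[0]  # shape (batch_size, hidden_size)

# Same selection written with advanced indexing, for comparison.
expected = outputs[vlens - 1, torch.arange(batch_size)]
same = torch.equal(output, expected)
print(output, same)
```

Each row of output comes from time step vlens[b] - 1 of sequence b, so here the rows hold the values 30, 11 and 2. Since only the first slice of the gathered result is used, expanding the index to (1, batch, hidden) instead of (max_seq_len, batch, hidden) should give the same output.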