How to get the output at the last timestep for batched sequences?

@rk2900

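Thanks for posting this. I think I’ve found an easier way. For a unidirectional GRU, the final hidden state of the last layer is the same as the output at the last timestep. If the batch is passed in as a PackedSequence, this holds per sequence: hn[-1] is each sequence's output at its own last valid step, not at the padded end.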

So you should be able to do:

outputs, hn = gru(inputs, h0)
print(hn[-1])
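Here is a small check of that claim, as a sketch. The layer sizes, the seed, and the lengths in `vlens` are made up for illustration, and it assumes the default seq-first layout (`batch_first=False`). It compares `hn[-1]` against the output rows, first for a full-length batch and then for a packed variable-length batch built with `pack_padded_sequence`.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

max_seq_len, batch, input_size, hidden_size = 5, 3, 4, 6
gru = nn.GRU(input_size, hidden_size, num_layers=2)  # seq-first layout (batch_first=False)

inputs = torch.randn(max_seq_len, batch, input_size)
h0 = torch.zeros(2, batch, hidden_size)  # (num_layers, batch, hidden)

# Full-length batch: the top layer's final hidden state is the last output row.
outputs, hn = gru(inputs, h0)
assert torch.allclose(hn[-1], outputs[-1])

# Variable-length batch: pack so the GRU stops at each sequence's true length.
vlens = torch.tensor([5, 3, 2])  # must be sorted descending unless enforce_sorted=False
packed = pack_padded_sequence(inputs, vlens)
packed_out, hn = gru(packed, h0)
outputs, _ = pad_packed_sequence(packed_out)  # (max_seq_len, batch, hidden), zero-padded

# Output at each sequence's last valid step, picked with advanced indexing.
last = outputs[vlens - 1, torch.arange(batch)]
assert torch.allclose(hn[-1], last)
```

Without packing, `hn` would instead be the state after the final padded step, so it only matches `outputs[-1]`, not the last valid output of the shorter sequences.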

For the LSTM the equivalent code would be:

outputs, (hn, cn) = lstm(inputs, (h0, c0))
print(hn[-1])
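A minimal sketch of the LSTM case, with made-up sizes and a fixed seed. The point it shows is that an LSTM takes its initial state as a `(h0, c0)` tuple and returns `(hn, cn)`; only `hn` corresponds to the outputs, while `cn` is the cell state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, num_layers = 7, 4, 3, 5, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)  # seq-first layout

inputs = torch.randn(seq_len, batch, input_size)
# The initial state is a (h0, c0) tuple, each of shape (num_layers, batch, hidden).
h0 = torch.zeros(num_layers, batch, hidden_size)
c0 = torch.zeros(num_layers, batch, hidden_size)

outputs, (hn, cn) = lstm(inputs, (h0, c0))
# hn[-1] is the top layer's final hidden state, i.e. the last row of outputs.
assert torch.allclose(hn[-1], outputs[-1])
# cn is the final cell state; it is not part of outputs.
```

If no initial state is passed, `lstm(inputs)` defaults both `h0` and `c0` to zeros.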

I used your code to verify this.

You can also express your code more compactly by using .view() to add the singleton dimensions:

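# vlens: LongTensor of true sequence lengths, shape (batch,); outputs: (max_seq_len, batch, hidden)
masks = (vlens - 1).view(1, -1, 1).expand(max_seq_len, outputs.size(1), outputs.size(2))
output = outputs.gather(0, masks)[0]  # (batch, hidden): output at each sequence's last valid step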
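A self-contained sketch of the gather approach, checked against the packed `hn[-1]`. The sizes, seed, and `vlens` values are made up. One variation from the code above: since `gather` along dim 0 only needs one row of indices, this version expands to size 1 instead of `max_seq_len` and uses `squeeze(0)` instead of `[0]`; the result is the same. Because a unidirectional GRU's output at step t depends only on inputs up to t, gathering from the unpacked padded run gives the same values as the packed run.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
max_seq_len, batch, input_size, hidden_size = 6, 3, 4, 5
gru = nn.GRU(input_size, hidden_size)  # one layer, seq-first

# Entries past each sequence's length act as padding.
inputs = torch.randn(max_seq_len, batch, input_size)
vlens = torch.tensor([6, 4, 1])  # true lengths (int64), one per batch element

# Run on the padded batch directly; outputs: (max_seq_len, batch, hidden).
outputs, _ = gru(inputs)

# Index of each sequence's last valid step, broadcast over the hidden dimension.
masks = (vlens - 1).view(1, -1, 1).expand(1, outputs.size(1), outputs.size(2))
last = outputs.gather(0, masks).squeeze(0)  # (batch, hidden)

# Same result via packing: hn[-1] stops at each sequence's own length.
_, hn = gru(pack_padded_sequence(inputs, vlens))
assert torch.allclose(last, hn[-1], atol=1e-6)
```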