How to get the output at the last timestep for batched sequences?

I feed a batch of shape max_seq_len * batch_size * embed_size to a GRU/LSTM, together with a list of the actual length of each sequence.
The outputs and the last hidden state then have sizes max_seq_len * batch_size * hidden_size and layer_num * batch_size * hidden_size.
How can I get the output vector at the last (actual) time step of each sequence? Its size should be batch_size * hidden_size.

The test code for my question:

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import numpy as np

max_seq_len = 3
batch_size = 4
layer_num = 2

input_size = 10 # 0 - 9
emb_size = 8
hidden_size = 10

# test data:
a = [[1,2,3], [4,5,0], [6,0,0], [7,8,9]]
lens = [3,2,1,3]

# sort the batch by actual length in decreasing order, as pack_padded_sequence requires (by default)
pairs = sorted( zip(a, lens), key=lambda p: p[1], reverse=True)
(a, lens) = zip(*pairs)

# actual length
lens = np.array(lens)
# lens = torch.LongTensor(lens)

va = Variable(torch.LongTensor(a))
vlens = Variable(torch.LongTensor(lens))

print(va, vlens)

embedding = nn.Embedding(input_size, emb_size)
gru = nn.GRU(emb_size, hidden_size, layer_num)

inputs = va.transpose(0, 1); print("inputs size: ", inputs.size()) # max_seq_len * batch_size
inputs = embedding(inputs); print("embedded size: ", inputs.size()) # max_seq_len * batch_size * emb_size

inputs = torch.nn.utils.rnn.pack_padded_sequence(inputs, lens)

h0 = Variable(torch.randn(layer_num, batch_size, hidden_size))
outputs, hn = gru(inputs, h0)

# print("after packed:")
# print("outputs.size: ", outputs.size())
# print("hn size: ", hn.size())

tmp = torch.nn.utils.rnn.pad_packed_sequence(outputs)

outputs, output_lengths = tmp
# outputs: (max_seq_len, batch_size, hidden_size)

print("after padded:")
print("outputs.size: ", outputs.size())
print("hn size: ", hn.size())

# outputs.index_select(torch.LongTensor(lens-1))
# print("lens: ", lens-1)

# idxs = Variable(torch.LongTensor(lens-1))
# outputs = outputs.index_select(0, idxs)

# print("selected outputs size: ", outputs.size())
# outputs.gather(0, idxs)

My direct question is

How do I implement [ outputs[lens[i] - 1][i] for i in range(len(lens)) ] ,
i.e. get the output vector at the last actual timestep of each sequence?

Thank you.
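For reference, the expression can be written literally as a Python loop over the batch. This is a minimal sketch with random stand-in data; the shapes and `lens` values are illustrative assumptions. Note the `- 1`: a sequence of length `lens[i]` ends at index `lens[i] - 1`. The loop is correct but issues one indexing op per sequence, which is what a vectorized version avoids.

```python
import torch

# Stand-in for the padded GRU output: (max_seq_len, batch_size, hidden_size)
max_seq_len, batch_size, hidden_size = 3, 4, 5
outputs = torch.randn(max_seq_len, batch_size, hidden_size)
lens = [3, 3, 2, 1]  # actual lengths, sorted in decreasing order

# Literal translation of the question: the last valid timestep is lens[i] - 1
last = torch.stack([outputs[lens[i] - 1, i] for i in range(len(lens))])
print(last.size())  # torch.Size([4, 5])
```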

My own solution is

masks = (vlens-1).unsqueeze(0).unsqueeze(2).expand(max_seq_len, outputs.size(1), outputs.size(2))
output = outputs.gather(0, masks)[0]
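Why this works: `gather(0, idx)` returns `out[t][i][k] = outputs[idx[t][i][k]][i][k]`, so if every entry in batch column `i` equals `lens[i] - 1`, every row of the result is the wanted last output, and `[0]` keeps one copy. A small self-contained check with random data (shapes and lengths are assumptions for illustration), using plain tensors instead of `Variable`:

```python
import torch

max_seq_len, batch_size, hidden_size = 3, 4, 5
outputs = torch.randn(max_seq_len, batch_size, hidden_size)  # padded GRU output
lens = torch.tensor([3, 3, 2, 1])  # actual lengths

# Index of shape (max_seq_len, batch_size, hidden_size); column i is filled with lens[i] - 1
idx = (lens - 1).view(1, -1, 1).expand(max_seq_len, batch_size, hidden_size)
gathered = outputs.gather(0, idx)  # all rows along dim 0 are identical
last = gathered[0]                 # (batch_size, hidden_size)

# Cross-check against the per-sequence loop
expected = torch.stack([outputs[lens[i] - 1, i] for i in range(batch_size)])
print(torch.equal(last, expected))  # True
```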

@rk2900

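Thanks for posting this. I think I’ve found an easier way. For GRUs, I believe the final hidden state is the same as the output at the last timestep: with packed input, the top layer's entry hn[-1] holds each sequence's output at its last actual timestep.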

So you should be able to do:

outputs, hn = gru(inputs, h0)
print(hn[-1])

For the LSTM, whose initial state is an (h0, c0) tuple, the equivalent code would be:

outputs, (hn, cn) = lstm(inputs, (h0, c0))
print(hn[-1])

I used your code to verify this.
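A self-contained version of that check, sketched under assumptions (random inputs, the sizes below, and the helper name `last_outputs` are illustrative). It relies on the input being packed: with a padded but unpacked input, hn would be the state after the padding steps. It also assumes a unidirectional RNN; for a bidirectional one, hn[-1] is the last layer's backward direction, which ends at timestep 0.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
max_seq_len, batch_size, emb_size, hidden_size, layer_num = 3, 4, 8, 10, 2
x = torch.randn(max_seq_len, batch_size, emb_size)  # padded input, time-major
lens = torch.tensor([3, 3, 2, 1])                    # sorted in decreasing order

def last_outputs(padded, lens):
    # Hypothetical helper: padded is (max_seq_len, batch_size, hidden_size)
    return padded[lens - 1, torch.arange(padded.size(1))]

gru = nn.GRU(emb_size, hidden_size, layer_num)
gru_out, gru_hn = gru(pack_padded_sequence(x, lens))
gru_out, _ = pad_packed_sequence(gru_out)
print(torch.allclose(gru_hn[-1], last_outputs(gru_out, lens)))  # True

lstm = nn.LSTM(emb_size, hidden_size, layer_num)
lstm_out, (lstm_hn, lstm_cn) = lstm(pack_padded_sequence(x, lens))
lstm_out, _ = pad_packed_sequence(lstm_out)
print(torch.allclose(lstm_hn[-1], last_outputs(lstm_out, lens)))  # True
```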

And you can more compactly express your code by using .view() to add the unit axes:

masks = (vlens-1).view(1, -1, 1).expand(max_seq_len, outputs.size(1), outputs.size(2))
output = outputs.gather(0, masks)[0]

Why do we also expand along dimension 0? Wouldn’t the following return the same thing? If so, is it faster, since it gathers fewer elements?

masks = (vlens-1).unsqueeze(0).unsqueeze(2).expand(1, outputs.size(1), outputs.size(2))
output = outputs.gather(0, masks)[0]
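A quick check suggests yes: gather's result has the shape of the index, and the index only has to be no larger than the input in the dimensions other than `dim`, so a size-1 index along dim 0 gives the same values with a leading unit dimension (removed with `[0]` or `.squeeze(0)`). Both variants are a single gather call; the size-1 version just reads max_seq_len times fewer elements. A sketch with random data and assumed shapes:

```python
import torch

max_seq_len, batch_size, hidden_size = 3, 4, 5
outputs = torch.randn(max_seq_len, batch_size, hidden_size)
lens = torch.tensor([3, 3, 2, 1])

# Full-size index: gathers max_seq_len identical rows, then keeps row 0
idx_full = (lens - 1).view(1, -1, 1).expand(max_seq_len, batch_size, hidden_size)
last_full = outputs.gather(0, idx_full)[0]

# Size-1 index along dim 0: result is (1, batch_size, hidden_size)
idx_one = (lens - 1).view(1, -1, 1).expand(1, batch_size, hidden_size)
last_one = outputs.gather(0, idx_one).squeeze(0)  # -> (batch_size, hidden_size)

print(torch.equal(last_full, last_one))  # True
```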

I have a question related to this post, so I thought I would follow up here rather than create a new post. I’ve been doing the following to extract the output at the last (non-padded) timestep of an LSTM (note: batch_first is True here):

def encoder(self, input: torch.Tensor, lengths: torch.Tensor):
    # pack our zero-padded sequences
    packed_seq = nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=True)

    # pass input through lstm
    packed_output, _ = self.lstm(packed_seq)

    # unpack the output which is (batch x seq_len x num_directions * hidden_size)
    total_length = input.size(1)
    output, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True, total_length=total_length)

    # grab the final output for each sequence in the batch
    seq_len_indices = [length - 1 for length in lengths]
    batch_indices = [i for i in range(len(lengths))]  # the last batch may be smaller than self.batch_size
    final_output = output[batch_indices, seq_len_indices, :]

    return final_output

My final_output has the right dimensions (batch_size x hidden_size) and seems to come from the right timesteps, but do I need to use the torch operations from the posts above (e.g. gather with masks) for autograd to backpropagate the loss correctly? Will indexing the output tensor with regular Python lists break backprop?


Slicing via lists doesn’t mess up any gradient computations. Explicitly:

import torch
from torch.nn.utils.rnn import pad_sequence

h1 = torch.tensor([1., 1, 2, 3], requires_grad=True)
h2 = torch.tensor([5., 5], requires_grad=True)
h = pad_sequence([h1, h2])
lengths = torch.tensor([len(h1), len(h2)]).long()
batch_size = len(lengths)

# list slicing
final_states_list = h[[l-1 for l in lengths], [i for i in range(batch_size)]]

# gather via mask
mask = (lengths-1).view(1, -1).expand(h.size(0), h.size(1))
final_states_mask = h.gather(0, mask)[0]

# check gradients both ways
final_states_list.sum().backward(retain_graph=True)
print(h1.grad, h2.grad)
# tensor([0., 0., 0., 1.]) tensor([0., 1.])
h1.grad.zero_()
h2.grad.zero_()
final_states_mask.sum().backward()
print(h1.grad, h2.grad)
# tensor([0., 0., 0., 1.]) tensor([0., 1.])

I’m not sure what the advantage of using gather() here is, since all rows but the first of the gather()'d tensor are thrown away. The list slicing seems cleaner.
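Going one step further, the Python list comprehensions can be replaced with index tensors, which keeps the indexing on the tensor's device and naturally handles a final batch smaller than `self.batch_size`. This is a sketch under assumptions: the `Encoder` class and its sizes are hypothetical, and `enforce_sorted=False` (available in recent PyTorch versions) replaces manual sorting. Tensor indexing is differentiable in the same way as the list slicing shown above.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
    """Hypothetical batch_first LSTM encoder returning each sequence's last valid output."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, x, lengths):
        # enforce_sorted=False accepts the batch in any order; lengths must live on the CPU
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # pad_packed_sequence restores the original batch order
        out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        batch_idx = torch.arange(out.size(0), device=out.device)
        time_idx = (lengths - 1).to(out.device)
        return out[batch_idx, time_idx]  # (batch, hidden_size)

enc = Encoder(input_size=6, hidden_size=4)
x = torch.randn(3, 5, 6, requires_grad=True)  # (batch, seq_len, features)
lengths = torch.tensor([2, 5, 3])              # deliberately unsorted
last = enc(x, lengths)
print(last.shape)  # torch.Size([3, 4])

last.sum().backward()
# Padded steps of sequence 0 (t >= 2) never enter the LSTM, so they get zero gradient
print(x.grad[0, 2:].abs().sum().item())  # 0.0
```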


This is a great response; thank you for the example, it was very informative!
