Simple working example of how to use packing for variable-length sequence inputs to an RNN

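Agreed.
The only reason the code runs without error is that batch_size is set equal to max_length; change either one and it fails. Also, the first parameter of nn.RNN should be input_size (the number of features per time step) rather than the maximum sequence length. Finally, I would prefer writing vec_1 = torch.FloatTensor([[1], [2], [3]]) over vec_1 = torch.FloatTensor([[1, 2, 3]]), because the column form matches the (max_length, feature_dim) = (3, 1) shape of each slot in batch_in. The row form only works on older PyTorch versions, which fall back to a flat copy when the shapes differ but the number of elements matches.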
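To make the input_size point concrete, here is a minimal standalone sketch of the shape contract of nn.RNN with batch_first=True, reusing the sizes from the modified version (assumed here, on a recent PyTorch): the last dimension of the input must equal input_size, while the sequence length is free.

```python
import torch
import torch.nn as nn

# Sizes borrowed from the example in this post (assumed for the sketch).
batch_size, max_length, feature_dim, hidden_size = 4, 3, 1, 2

# First argument = input_size, i.e. features per time step, not sequence length.
rnn = nn.RNN(feature_dim, hidden_size, batch_first=True)

# With batch_first=True the input is (batch, seq_len, input_size).
x = torch.zeros(batch_size, max_length, feature_dim)
out, h_n = rnn(x)
print(out.shape)  # torch.Size([4, 3, 2])
print(h_n.shape)  # torch.Size([1, 4, 2]): the hidden state is never batch-first

# Passing the sequence length as input_size fails once it differs from feature_dim.
wrong = nn.RNN(max_length, hidden_size, batch_first=True)
try:
    wrong(x)
except RuntimeError as err:
    print("shape error:", err)
```

Note that h0 and h_n keep the (num_layers, batch, hidden_size) layout regardless of batch_first, which is why h0 below is built that way.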

Here is a modified version:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable  # since PyTorch 0.4, Variable(t) simply returns a tensor

batch_size = 4
max_length = 3
hidden_size = 2
n_layers = 1
feature_dim = 1

# container: (batch_size, max_length, feature_dim), zero-padded
batch_in = torch.zeros((batch_size, max_length, feature_dim))

# data: each sequence is a (max_length, feature_dim) column, padded with 0
vec_1 = torch.FloatTensor([[1], [2], [3]])
vec_2 = torch.FloatTensor([[1], [2], [0]])
vec_3 = torch.FloatTensor([[1], [0], [0]])
vec_4 = torch.FloatTensor([[2], [0], [0]])

batch_in[0] = vec_1
batch_in[1] = vec_2
batch_in[2] = vec_3
batch_in[3] = vec_4

batch_in = Variable(batch_in)
print(batch_in.size())
seq_lengths = [3, 2, 1, 1]  # true length of each sequence, sorted in decreasing order

# pack it
pack = torch.nn.utils.rnn.pack_padded_sequence(batch_in, seq_lengths, batch_first=True)
print(pack)

rnn = nn.RNN(feature_dim, hidden_size, n_layers, batch_first=True)  # input_size = feature_dim
h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))  # (num_layers, batch, hidden) even with batch_first=True

# forward
out, _ = rnn(pack, h0)

# unpack
unpacked, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
print(unpacked)
```
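On newer PyTorch releases (1.1 and later, as far as I know), both the manual container and the sort-by-length requirement can be avoided. A sketch, assuming a recent version: pad_sequence builds the zero-padded batch from a list of variable-length tensors, and enforce_sorted=False lets pack_padded_sequence accept the sequences in any order (it sorts internally and restores the original order on unpacking).

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

hidden_size, n_layers, feature_dim = 2, 1, 1

# Variable-length sequences of shape (length, feature_dim), deliberately unsorted.
seqs = [
    torch.tensor([[1.0]]),
    torch.tensor([[1.0], [2.0], [3.0]]),
    torch.tensor([[2.0]]),
    torch.tensor([[1.0], [2.0]]),
]
lengths = torch.tensor([len(s) for s in seqs])  # int64 on CPU, as packing expects

# Zero-pad into a (batch, max_length, feature_dim) tensor.
batch_in = pad_sequence(seqs, batch_first=True)

# enforce_sorted=False: no need to sort the batch by length beforehand.
packed = pack_padded_sequence(batch_in, lengths, batch_first=True, enforce_sorted=False)

rnn = nn.RNN(feature_dim, hidden_size, n_layers, batch_first=True)
out, h_n = rnn(packed)  # h0 defaults to zeros when omitted

# Padded output and lengths come back in the original (unsorted) order.
unpacked, unpacked_len = pad_packed_sequence(out, batch_first=True)
print(unpacked.shape)  # torch.Size([4, 3, 2])
print(unpacked_len)    # tensor([1, 3, 1, 2])
```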

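And the output (printed by an old, pre-0.4 PyTorch release, hence the `Variable containing:` headers; the RNN values depend on the random weights and h0):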

```
torch.Size([4, 3, 1])

PackedSequence(data=Variable containing:
    1
    1
    1
    2
    2
    2
    3
[torch.FloatTensor of size 7x1]
, batch_sizes=[4, 2, 1])

Variable containing:
(0 ,.,.) = 
 -0.8313 -0.7238
 -0.9355  0.3213
 -0.9907 -0.0606

(1 ,.,.) = 
 -0.8365 -0.0670
 -0.9559 -0.0762
  0.0000  0.0000

(2 ,.,.) = 
 -0.2423 -0.1630
  0.0000  0.0000
  0.0000  0.0000

(3 ,.,.) = 
 -0.9419  0.0727
  0.0000  0.0000
  0.0000  0.0000
[torch.FloatTensor of size 4x3x2]
```
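The PackedSequence printed above stores the batch time step by time step: batch_sizes[t] is the number of sequences longer than t, and data interleaves the still-active sequences at each step. The plain-Python sketch below (pack_by_hand is a hypothetical helper, not a PyTorch function) reproduces that layout for the four sequences in the example; like pack_padded_sequence with its default enforce_sorted=True, it assumes the sequences are sorted by decreasing length.

```python
def pack_by_hand(seqs):
    """Hypothetical helper: mimic the PackedSequence layout for length-sorted sequences."""
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    # batch_sizes[t] = number of sequences that still have an element at step t
    batch_sizes = [sum(1 for n in lengths if n > t) for t in range(max_len)]
    # Because seqs are sorted by decreasing length, the active ones at step t
    # are exactly the first batch_sizes[t] sequences.
    data = [seqs[b][t] for t in range(max_len) for b in range(batch_sizes[t])]
    return data, batch_sizes


seqs = [[1, 2, 3], [1, 2], [1], [2]]  # vec_1..vec_4 without their padding
data, batch_sizes = pack_by_hand(seqs)
print(data)         # [1, 1, 1, 2, 2, 2, 3]
print(batch_sizes)  # [4, 2, 1]
```

The 7 entries in data and batch_sizes=[4, 2, 1] match the PackedSequence printed above; no padding value ever appears in data.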

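We can see that the last row of the second sequence's output is a zero vector (as are the trailing rows of the third and fourth sequences). This is expected: the padded positions are never fed into the RNN, and pad_packed_sequence fills them with zeros.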
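One way to check that the padding really never reaches the RNN: for a single-layer, unidirectional RNN, the final hidden state h_n returned for a packed input should equal the output at each sequence's last real time step (index length - 1), not at the padded end. A minimal sketch, assuming a recent PyTorch and rebuilding the same batch without Variable:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)  # only for reproducible values; the check holds for any weights

# Same zero-padded batch as above: (batch=4, max_length=3, feature_dim=1).
batch_in = torch.tensor([[1., 2., 3.], [1., 2., 0.], [1., 0., 0.], [2., 0., 0.]]).unsqueeze(-1)
seq_lengths = [3, 2, 1, 1]

rnn = nn.RNN(1, 2, 1, batch_first=True)
out, h_n = rnn(pack_padded_sequence(batch_in, seq_lengths, batch_first=True))
unpacked, _ = pad_packed_sequence(out, batch_first=True)

# Output at each sequence's last real step.
last_steps = torch.stack([unpacked[i, n - 1] for i, n in enumerate(seq_lengths)])

# h_n[-1] holds the final state of each sequence, taken at its true length.
print(torch.allclose(last_steps, h_n[-1]))  # True

# Padded positions in the unpacked output are all zeros.
print(unpacked[1, 2], unpacked[2, 1:])
```

Without packing, running the padded batch directly would make h_n reflect the padded steps for the shorter sequences, which is usually not what you want.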
