Select tensor in a batch of sequences


(Felix Kreuk) #1

Hi all,

Goal: select the last item in a PackedSequence for each sequence in the batch.

I have a tensor of shape [16,80,300] (batch_size X sequence_len X features), which is the output of an LSTM. The sequences are zero-padded, so I also have a list of 16 lengths, one per sequence in the batch.

I basically want to select the last item in the sequence for each sequence in the batch (e.g. the first seq may be of length 40, the next one of length 60, etc.), so I would end up with a [16,300] tensor. But since the sequences are of variable length, I can't just do input[:, -1, :], because every sequence has its own length.

I saw this thread and tried to replicate it, but I'm having some trouble:

x = torch.FloatTensor(16,80,300)  # rnn output
idx = torch.LongTensor([random.randint(10,79) for i in xrange(16)])  # list of lengths
x.gather(1, idx.view(-1, 1, 1))

but I get: RuntimeError: Expected tensor [16 x 1 x 1], src [16 x 80 x 300] and index [16 x 1 x 1] to have the same size in dimension 1 at /Users/soumith/miniconda2/conda-bld/pytorch_1501999754274/work/torch/lib/TH/generic/THTensorMath.c:445.


(Marcin Elantkowski) #2

Note that if you can feed a PackedSequence into your LSTM (of hidden size 300), you can get the last hidden states directly by unpacking the second return value:

_, (last_states, _) = lstm(my_packed_tensor)

As for the error, you could do

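idx = idx.view(-1, 1, 1).expand(-1, 1, hidden_dim) and then gather should work, since the index then spans the feature dimension as well. If idx holds lengths rather than positions, subtract 1 first so it points at the last valid step.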
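
To make the shape requirement concrete, here is a minimal sketch on a tiny tensor with known values. The sizes and the name last_idx are made up for illustration, and last_idx is assumed to already hold length - 1 for each sequence.

```python
import torch

# Tiny stand-in for the RNN output: batch=2, seq_len=3, hidden_dim=4.
batch_size, seq_len, hidden_dim = 2, 3, 4
x = torch.arange(batch_size * seq_len * hidden_dim, dtype=torch.float32)
x = x.view(batch_size, seq_len, hidden_dim)

# Hypothetical positions of the last valid step per sequence (length - 1).
last_idx = torch.tensor([2, 0])

# gather returns a tensor shaped like the index, so the index has to cover
# the feature dimension too: [batch, 1, 1] -> [batch, 1, hidden_dim].
index = last_idx.view(-1, 1, 1).expand(-1, 1, hidden_dim)

# out[i, 0, k] = x[i, index[i, 0, k], k]; squeeze drops the length-1 time axis.
last = x.gather(1, index).squeeze(1)  # shape [batch, hidden_dim]
print(last)
```

Each row of last is the feature vector at that sequence's own last step, which is the [16,300]-style result asked for in the question.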

You could also use advanced indexing I think:

row_indices = th.arange(0, batch_size).long()
col_indices  = seq_lengths - 1
last_states_indexed = full_output[row_indices, col_indices, :]
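
A minimal sketch of the advanced-indexing route on a toy tensor with known values. The sizes are made up for illustration, and seq_lengths is assumed to be a LongTensor of true (unpadded) lengths.

```python
import torch

# Toy stand-in for the padded RNN output: batch=2, seq_len=3, hidden_dim=4.
batch_size, seq_len, hidden_dim = 2, 3, 4
full_output = torch.arange(
    batch_size * seq_len * hidden_dim, dtype=torch.float32
).view(batch_size, seq_len, hidden_dim)

# Assumed true lengths of the two sequences.
seq_lengths = torch.tensor([3, 1])

# Pair every batch row with its own last valid time step.
row_indices = torch.arange(batch_size)
col_indices = seq_lengths - 1

# Indexing with two index tensors picks one (row, col) pair per batch element
# and keeps the trailing feature dimension: result is [batch, hidden_dim].
last_states_indexed = full_output[row_indices, col_indices]
print(last_states_indexed)
```

Compared with gather, this avoids building an expanded index and the extra squeeze, at the cost of needing an explicit row index.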

Here’s a quick test to check that what I’m saying makes sense:

# ------- Import -------
import torch as th
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack, pad_packed_sequence as pad
from torch.autograd import Variable as V

import numpy as np

# ------- Setup -------
th.manual_seed(42)

batch_size  = 4
max_seq_len = 5
feature_dim = 5
hidden_dim  = 3

seq_lengths  = [4, 3, 2, 1]

# model and data
lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True)
x    = V(th.randn(batch_size, max_seq_len, feature_dim))

# "padding"
for i, idx in enumerate(seq_lengths):
    x[i, idx:, :] = 0

    
# ------- Using PackedSequence -------     
x_packed = pack(x, seq_lengths, batch_first=True)
_, (last_states_packed, _) = lstm(x_packed)

print(last_states_packed.squeeze()[0, ...])  # This is the last state for the first sequence in the batch,
                                             # i.e. the state after 4 LSTM steps, since this sequence has length 4.

# Variable containing:
#  0.1615
# -0.0783
# -0.3117
# [torch.FloatTensor of size 3]


# ------- Using a raw, padded Tensor -------
full_output, (last_states, _) = lstm(x)

print(last_states.squeeze()[0, ...])  # This time this is not what we want! This is after 5 LSTM steps (max sequence length)

# Variable containing:
#  0.0894
# -0.0987
# -0.3530
# [torch.FloatTensor of size 3]

print(full_output[0, seq_lengths[0]-1, :])  # This is what we want. What we got above was full_output[0, -1, :].

# Variable containing:
#  0.1615
# -0.0783
# -0.3117
# [torch.FloatTensor of size 3]
   

# ------- Extract data using gather -------   
seq_end_idx    = V(th.LongTensor(seq_lengths) - 1, requires_grad=False)
seq_end_idx_ex = seq_end_idx.view(-1, 1, 1).expand(-1, 1, hidden_dim)

last_states_sliced = full_output.gather(1, seq_end_idx_ex)

assert np.allclose(last_states_sliced.data.squeeze().numpy(), last_states_packed.data.squeeze().numpy())
   

# ------- Extract data using advanced indexing -------   

row_indices = th.arange(0, batch_size).long()
last_states_indexed = full_output[row_indices, seq_end_idx, :]

assert np.allclose(last_states_indexed.data.squeeze().numpy(), last_states_packed.data.squeeze().numpy())

(Felix Kreuk) #3

Hi Marcin!

Thanks so much for the detailed answer. I am using PackedSequence and actually wasn’t aware that the last states are already provided. So far I just packed the zero-padded sequences, fed them to the LSTM, unpacked the result, and got an output sequence with zero-padding ([16,80,300]), which led me to believe I needed to select the last items manually to get to [16,300].
I will definitely give your suggestion a try and report back!

Thanks again!


(Felix Kreuk) #4

Hey, tried the “easy way”. This is what I get:

_, (last_states_packed, _) = self.rnn(packed_seq, hidden)  # get utterance rep
ValueError: need more than 1 value to unpack

Where self.rnn is:

self.rnn = nn.GRU(input_size, hidden_size, n_layers, batch_first=True) 

Am I doing something wrong?


(Marcin Elantkowski) #5

Oh, sorry, my bad.

LSTM returns (all_outputs, (last_hidden, last_cell)) (as in my example),
but GRU returns just (all_outputs, last_hidden), since it has no cell state.

In that case using
_, last_states_packed = self.rnn...
should work
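
For reference, a small sketch of the two return signatures side by side, assuming a recent PyTorch where plain tensors can be fed to the modules directly. The sizes and lengths below are made up. One thing to watch with n_layers > 1: the last hidden state has shape [n_layers, batch, hidden_size], so the per-sequence representation from the top layer is its last slice.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Made-up sizes for illustration only.
input_size, hidden_size, n_layers = 5, 3, 2
x = torch.randn(4, 6, input_size)      # [batch, max_seq_len, features]
lengths = torch.tensor([6, 4, 3, 1])   # sorted in descending order
packed = pack_padded_sequence(x, lengths, batch_first=True)

gru = nn.GRU(input_size, hidden_size, n_layers, batch_first=True)
lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)

# GRU: second return value is the hidden state itself.
out_packed, h_gru = gru(packed)
# LSTM: second return value is a (hidden, cell) tuple.
_, (h_lstm, c_lstm) = lstm(packed)

# Hidden states are [n_layers, batch, hidden_size], even with batch_first=True.
top_gru = h_gru[-1]                    # top layer: [batch, hidden_size]

# Cross-check against manual selection from the unpacked, padded output.
padded, out_lengths = pad_packed_sequence(out_packed, batch_first=True)
manual = padded[torch.arange(padded.size(0)), out_lengths - 1]
print(torch.allclose(manual, top_gru))
```

The cross-check only holds for a unidirectional RNN; it is included here just to show that the packed route and the manual selection agree.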