Variable Length Sequences for Many-to-One RNN

Hi all,
I am trying to train a model to do audio classification on variable-length sequences. I pad the sequences with zeros at the end, use pack_padded_sequence before feeding them to nn.GRU, and pad_packed_sequence before feeding the output to a fully connected layer:


x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
x, h = rnn(x)
x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

x = x.contiguous()
x = x.view(-1, rnn_code_size)           # flatten (batch, seq) so the classifier sees one time step per row
x = classifier(x)
x = x.view(-1, seq_length, classes)     # back to (batch, seq, classes)
return x

When training I use CrossEntropyLoss and do:
loss = cross_entr_loss(y[:, -1, :], labels),
where y[:, -1, :] is the GRU output at the last time step of each sequence.

This is sort of working (with poor accuracy, but still better than random), but I noticed that the performance is better when I run on data where all sequences have the same length (from the same dataset).

Is there something that I am doing wrong w.r.t. the handling of the variable-length sequences, or is there some explanation for why I would perform better on the fixed-length sequences?
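For reference, here is a toy version of the pack/pad round trip described above (the sizes are made up), which shows what y[:, -1, :] ends up containing for the shorter sequences:

import torch
import torch.nn as nn

gru = nn.GRU(2, 4, batch_first=True)
x = torch.randn(3, 5, 2)                  # batch of 3, padded to length 5
lengths = [5, 3, 1]
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
out, _ = gru(packed)
y, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
print(y[:, -1, :])                        # the rows for the length-3 and length-1 sequences are all zeros (padding)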


I think you want to take the output at the last valid time step of each sequence in the batch and then put just that through the fully connected layer.

rnn_output, hidden = gru(x)
rnn_output = unpack(rnn_output)
rnn_output_last = rnn_output.index_select(sequence_dim, Variable(LongTensor(lengths)))
final_output = classifier(rnn_output_last)
return final_output

That’s the gist. But instead of returning (batch_size, seq_length, classes), you’ll get (batch_size, classes) or (batch_size, 1, classes). You might also not need the index_select part and could just use the last output in the sequence, but the idea is that even though your sequences are of variable length, your fully connected layer always takes an input of the same size, which is just the number of features in the RNN layer at the last output.
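A minimal sketch of that idea, assuming batch_first=True and placeholder names (packed_out, lengths, and classifier are not from the code above):

out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)   # (batch, seq, hidden)
last_step = torch.as_tensor(lengths) - 1      # last valid time step of each sequence
rows = torch.arange(out.size(0))              # one row index per batch element
last_out = out[rows, last_step]               # (batch, hidden)
logits = classifier(last_out)                 # (batch, classes)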


Thank you for the prompt response. What does the index_select function do? I can’t seem to find any documentation for it.

I actually think I did this incorrectly. What index_select does is select slices of a tensor along a given dimension at the given indices.

But you want to select one output for each sequence in the batch, and I think this would select all the lengths for all the batches. You could do something similar by combining the batch dim and the sequence dim and then selecting from that…

lengths = [5, 3, 1]
#batch_size = 3
#max_sequence_length = 5
#classes = 2
a = torch.arange(30).view(3, 5, 2)
a = a.view(3*5, 2)
a.index_select(0, (torch.LongTensor(lengths)-1)*3)

The link to the docs is here

… my math with the selecting tensor is not right, but perhaps that makes a little more sense of what I’m trying to do. The naive way is:

# here `a` is the padded output with shape (batch, seq, features)
tmp = []
for b_i, l in enumerate(lengths):
    tmp.append(a[b_i, l-1, :])    # last valid time step of sequence b_i
torch.stack(tmp)                  # (batch, features)

I think I get what you are saying. Basically I should make a tensor whose elements are the last output of each sequence in the batch. That way the tensor I get is Batch_Size x LSTM output Features. I’ll give that a go and let you know if I get better performance.

Thank you, it worked like a charm! Here is how I did it (I use batch_first in my RNNs):

lengths = [5, 3, 1]
#batch_size = 3
#max_sequence_length = 5
#features = 2
a = torch.arange(30).view(3, 5, 2)
a = a.view(3*5, 2)
# lengths[0] is the max sequence length because the batch is sorted by decreasing length
adjusted_lengths = [i * lengths[0] + l for i, l in enumerate(lengths)]
a.index_select(0, (torch.LongTensor(adjusted_lengths)-1))

Note: If you are planning on using this to do backprop, then you will need (torch.LongTensor(adjusted_lengths)-1) to be wrapped in a Variable. If you are running on GPU, don’t forget to move it to CUDA.
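For what it’s worth, in current PyTorch the Variable wrapper is no longer needed; a sketch of the same selection with explicit device handling, reusing a and adjusted_lengths from the snippet above:

idx = torch.tensor(adjusted_lengths, dtype=torch.long, device=a.device) - 1   # build the index on the same device as a
last_outputs = a.index_select(0, idx)                                          # (batch_size, features)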

Replying on an old thread, but here’s what worked for me:

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, layer_dim, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        self.layer_dim = layer_dim

        self.rnn = nn.RNN(input_size, hidden_size, layer_dim, batch_first=True)
        # in-place Xavier init; weight_ih_l1 only exists when layer_dim >= 2
        torch.nn.init.xavier_uniform_(self.rnn.weight_ih_l0)
        torch.nn.init.xavier_uniform_(self.rnn.weight_ih_l1)
        self.dropout = nn.Dropout(p=0.1)
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
#         self.softmax = nn.Sigmoid()

    def forward(self, input, lengths):
        batch_size, seq_len, feature_len = input.size()
        input = torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=True)
        output, _ = self.rnn(input)
        # unpack with batch_first=False so the (seq, batch, hidden) output is already contiguous
        output, lengths = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=False)
        # assumes seq_len == max(lengths), i.e. the batch is padded to its own longest sequence
        output = output.view(batch_size*seq_len, self.hidden_size)
        # flat index of the last valid time step of each sequence: (length-1)*batch_size + batch_index
        adjusted_lengths = [(l-1)*batch_size + i for i, l in enumerate(lengths)]
        lengthTensor = torch.tensor(adjusted_lengths, dtype=torch.int64)
        if useCuda:  # useCuda and cuda are assumed to be defined elsewhere
            lengthTensor = lengthTensor.to(cuda)
        output = output.index_select(0, lengthTensor)
        output = output.view(batch_size, self.hidden_size)
        output = self.fc1(output)
        output = self.fc2(output)
        return output

I’ve put batch_first=False in the unpacking because <seq, batch, features> is the contiguous view; if it is set to True, it just returns a non-contiguous view of the tensor with dimensions <batch, seq, features>, which then needs to be converted to contiguous, taking up extra memory and time. This worked perfectly for me.
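A quick toy check of that contiguity claim (just a sketch; whether the batch_first=True result is contiguous may depend on the PyTorch version):

import torch
import torch.nn as nn

packed = nn.utils.rnn.pack_padded_sequence(torch.randn(3, 5, 2), [5, 3, 1], batch_first=True)
out_tbf, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=False)   # (seq, batch, features)
out_btf, _ = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)    # (batch, seq, features)
print(out_tbf.is_contiguous(), out_btf.is_contiguous())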

And I’m replying to a reply to an old thread :wink:

Can you answer a simple question for me @dashu2410? Is the output you’re selecting with output.index_select(0, lengthTensor) essentially the output for the last item in the sequence? As in, if your data was a time series, are you pulling out the RNN’s output for the most recent item in the sequence after it has seen everything else in the sequence?

Thanks in advance!

Yes, that’s exactly what I’m doing here @karmus89 :slight_smile:
Although to be fair, this has led to a 5x increase in my training time compared to just taking the output at the last time step of the tensor (regardless of the input length), because there’s a lot of back and forth between the CPU and GPU; but that model would be suboptimal in performance in any case. Let me know if you happen to find a more computationally efficient way of doing this!
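One possible way to cut that CPU/GPU back and forth is to build the flat index with tensor ops instead of a Python list (a sketch only, assuming output has already been flattened to (seq_len*batch_size, hidden) as in the forward above and lengths is the tensor returned by pad_packed_sequence):

batch_size = lengths.size(0)
flat_idx = (lengths.to(output.device) - 1) * batch_size + torch.arange(batch_size, device=output.device)
last_out = output.index_select(0, flat_idx)   # (batch_size, hidden)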

This is how I did it:

def forward(self, x, lens):
    # h0 should live on the same device as x when running on GPU
    h0 = torch.randn(self.num_layers, x.size(0), self.hidden_size, device=x.device)

    emb = self.embedding(x)
    packed = nn.utils.rnn.pack_padded_sequence(
        emb, lens, batch_first=True, enforce_sorted=False)

    packed, h0 = self.gru(packed, h0)
    out = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)[0]

    # pick the output at the last valid time step of each sequence
    # (lens is assumed to be a LongTensor; move it to out.device if running on GPU)
    index = (lens - 1).unsqueeze(-1).repeat(1, out.size(-1)).unsqueeze(1)  # (batch, 1, hidden)
    out = torch.gather(out, 1, index).squeeze(1)                           # (batch, hidden)
    out = self.lin(out)
    return out
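For reference, a toy illustration of what that gather index selects (shapes assumed: batch 3, max length 5, hidden size 2):

out = torch.arange(30.).view(3, 5, 2)
lens = torch.tensor([5, 3, 1])
index = (lens - 1).unsqueeze(-1).repeat(1, out.size(-1)).unsqueeze(1)   # (3, 1, 2)
last = torch.gather(out, 1, index).squeeze(1)                           # time steps 4, 2 and 0 of each sequence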