RNN with different sequence lengths

Hello,
I am working on a time series dataset using an LSTM. Each sequence has dimension S_i x 6, i.e. the sequences have different lengths. I first created a network (network 1) that pads each sequence in the forward function so they all have the same length. Unfortunately, that network could not really learn the structure in the data. So I decided not to pad the sequences and rewrote the network (network 2) so that the forward pass contains a for-loop over each sequence in a batch, where, as mentioned before, the sequences have different lengths. And lo and behold, the network converges much better!

Questions:

  • What exactly is the effect of padding on the network?
  • Why does padding the sequences lead to worse convergence?

Network 1: With padding

import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepIO(nn.Module):
    def __init__(self):
        super(DeepIO, self).__init__()
        self.hidden_size = 512
        self.num_dir = 2  # bidirectional LSTM -> two directions
        self.rnn = nn.LSTM(input_size=6, hidden_size=self.hidden_size,
                           num_layers=2, bidirectional=True)
        self.drop_out = nn.Dropout(0.25)
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc_out = nn.Linear(256, 7)

    def forward(self, x):
        """
        args:
        x: a list of B inputs, each of dimension [S_i x 6]
        """
        lengths = [x_.size(0) for x_ in x]   # get the length of each sequence in the batch
        x_padded = nn.utils.rnn.pad_sequence(x, batch_first=True)  # pad all sequences
        b, s, n = x_padded.shape

        # pack the padded sequence
        x_padded = nn.utils.rnn.pack_padded_sequence(x_padded, lengths=lengths,
                                                     batch_first=True, enforce_sorted=False)

        # calc the feature vector from the latent space
        out, hidden = self.rnn(x_padded)

        # unpack the feature vector
        out, lens_unpacked = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        out = out.view(b, s, self.num_dir, self.hidden_size)

        # many-to-one rnn, get the last result
        y = out[:, -1, 0]

        y = F.relu(self.fc1(y), inplace=True)
        y = self.bn1(y)
        y = self.drop_out(y)

        y = self.fc_out(y)
        return y

Network 2: Without padding

class DeepIO(nn.Module):
    def __init__(self):
        super(DeepIO, self).__init__()
        self.rnn = nn.LSTM(input_size=6, hidden_size=512,
                           num_layers=2, bidirectional=True)
        self.drop_out = nn.Dropout(0.25)
        self.fc1 = nn.Linear(512, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc_out = nn.Linear(256, 7)

    def forward(self, x):
        """
        args:
        x: a list of B inputs, each of dimension [S_i x 6]
        """
        outputs = []
        # iterate through all sequences in the batch
        for xx in x:
            s, n = xx.shape
            out, hidden = self.rnn(xx.unsqueeze(1))  # [S_i x 1 x 2*512]
            out = out.view(s, 1, 2, 512)             # [S_i x 1 x num_dir x hidden]
            out = out[-1, :, 0]                      # last time step, forward direction
            outputs.append(out.squeeze())
        outputs = torch.stack(outputs)

        y = F.relu(self.fc1(outputs), inplace=True)
        y = self.bn1(y)
        y = self.drop_out(y)
        y = self.fc_out(y)
        return y

Thanks
Arash

Hi,

Sorry for not answering your question, but how did you manage to design your neural network with variable sequence lengths?

I am working on audio data and currently handle the variable lengths by padding.
Thanks.
BR,
Shweta.

@shwe87 By just iterating through each sequence in the batch; see the for-loop in the forward function. I think this works because I do not pass the hidden state from one sequence to the next, i.e. there is no state sharing, the model is stateless.

You could consider generating batches with sequences of the same length. I use it all the time for sequence classification but also for seq2seq models. You may want to have a look here, here and here.
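For illustration, bucketing by length could look roughly like this (a minimal sketch; the helper make_length_batches and its details are my own illustration, not taken from the links above):

import random
from collections import defaultdict

import torch

def make_length_batches(sequences, batch_size):
    """Group variable-length [S_i x 6] tensors into batches of equal length."""
    buckets = defaultdict(list)
    for seq in sequences:
        buckets[seq.size(0)].append(seq)  # one bucket per sequence length

    batches = []
    for same_len in buckets.values():
        random.shuffle(same_len)
        for i in range(0, len(same_len), batch_size):
            # all sequences in a bucket share their length, so they stack cleanly
            batches.append(torch.stack(same_len[i:i + batch_size]))
    random.shuffle(batches)
    return batches

Each batch is then a regular [B x S x 6] tensor, so no padding or packing is needed inside the model.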

@vdw Thanks for the hint and the links. Yes, reordering the sequences so that they have the same length in each batch is a nice idea. But I am actually wondering what effect padding, packing and unpacking have on the performance of the network, i.e. not the computational performance but the loss. Why does padding lead to worse results? By the way, the project I am talking about can be found here (deeplio). I would be thankful for any link or paper on this topic!

Hello,

From my understanding of the implementation of your first network, taking the last time step with y = out[:, -1, 0] to decode is incorrect, because the output of pad_packed_sequence contains a lot of zero padding so that all sequences in the batch have the same shape. For every sequence that is shorter than the longest one, position -1 is therefore padding, not the real last step. You have to rely on lens_unpacked to select the correct last time step to decode.
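A rough sketch of what that selection could look like, assuming out has the shape [B x S x num_dir x hidden] from the view in your code and lens_unpacked comes from pad_packed_sequence (untested, just to illustrate the idea):

# select the last *valid* time step of each sequence instead of out[:, -1, 0]
idx = (lens_unpacked - 1).view(-1, 1, 1, 1)          # [B x 1 x 1 x 1]
idx = idx.expand(-1, 1, out.size(2), out.size(3))    # [B x 1 x num_dir x hidden]
idx = idx.to(out.device)
y = out.gather(1, idx).squeeze(1)                    # [B x num_dir x hidden]
y = y[:, 0]                                          # keep the forward direction -> [B x hidden]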

Hey, how can I give an LSTM inputs of variable length, like [[1,2,3,4,5,6,7,8]] and [[1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8]]? And how can I test the model on any size, like 1x8 or 1x16?
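With a model like network 2 above, which takes a list of per-sequence tensors, a call with two different lengths could look like this (just a sketch; it uses random data and the 6-feature input of the code above, not the 1-feature sequences from the question):

model = DeepIO()   # the no-padding variant from network 2
model.eval()

# two sequences with different lengths: 8 and 16 time steps, 6 features each
seq_a = torch.randn(8, 6)
seq_b = torch.randn(16, 6)

with torch.no_grad():
    y = model([seq_a, seq_b])   # forward() loops over the list, so no padding is needed
print(y.shape)                  # torch.Size([2, 7])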
