How to obtain memory states from pack padded sequence

Hello,
I would like to ask how I can obtain the memory state outputs (the cell states, not the hidden states) at each time step of an LSTM when using pack_padded_sequence.

For example, this code extracts the hidden states:

packed_embedded = pack_padded_sequence(embedded,
                                       src_len,
                                       batch_first=True,
                                       enforce_sorted=False)  # or sort first and keep the default enforce_sorted=True

packed_outputs, hidden = self.lstm_encoder(packed_embedded)  # hidden is a tuple (h_n, c_n), each of shape (num_layers * num_directions, batch_size, hidden_size)

I would like packed_outputs to contain the memory (cell) states rather than the hidden states. Is there any way to do this?
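For reference, the second value returned by nn.LSTM is not a single tensor but the tuple (h_n, c_n), so the final cell state is already available. What is missing are the cell states of the intermediate time steps. A minimal sketch, assuming made-up sizes and random inputs purely for illustration:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

# Illustrative sizes and data (assumptions, not from the thread).
batch_size, max_len, embed_dim, hidden_size = 4, 6, 8, 16
embedded = torch.randn(batch_size, max_len, embed_dim)
src_len = torch.tensor([2, 3, 2, 5])

lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
packed = pack_padded_sequence(embedded, src_len, batch_first=True, enforce_sorted=False)

# The second return value unpacks into (h_n, c_n).
packed_outputs, (h_n, c_n) = lstm(packed)

# h_n and c_n hold only the last valid step of each sequence:
# shape (num_layers * num_directions, batch_size, hidden_size).
print(h_n.shape, c_n.shape)  # torch.Size([1, 4, 16]) torch.Size([1, 4, 16])

# packed_outputs.data holds one hidden state per valid step (2 + 3 + 2 + 5 = 12 rows).
print(packed_outputs.data.shape)  # torch.Size([12, 16])
```

Per-step cell states are never exposed by nn.LSTM, which is why a step-by-step loop is needed to collect them.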

Thanks!

Hi @Lewis,
You can't get these directly, because nn.LSTM returns the hidden state for every time step as its output, but the cell (memory) state only for the last step, inside the (h_n, c_n) tuple. However, you can implement it with nn.LSTMCell and handle the packing manually by stepping through time yourself. I've used this code in one of my recent projects:

import torch
from torch import nn

sequences = torch.LongTensor([[1, 2, 0, 0, 0, 0],
                              [3, 4, 5, 0, 0, 0],
                              [5, 6, 0, 0, 0, 0],
                              [8, 9, 10, 11, 12, 0]])
seq_lengths = torch.LongTensor([2, 3, 2, 5])

class PackedMemory(nn.Module):
    def __init__(self, vocab_size, embed_dim, decoder_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTMCell(embed_dim, decoder_dim)
        self.decoder_dim = decoder_dim

    def forward(self, seq, seq_len):
        batch_size = seq.size(0)
        # Sort by length, longest first; stable=True keeps tied lengths in their original order.
        sorted_lengths, sort_indices = torch.sort(seq_len, descending=True, stable=True)
        sorted_sequences = seq[sort_indices]
        max_len = int(sorted_lengths[0])

        embeddings = self.embedding(sorted_sequences)  # (batch, seq_len, embed_dim)
        device = embeddings.device
        hidden_states = torch.zeros(batch_size, max_len, self.decoder_dim, device=device)
        memory_states = torch.zeros(batch_size, max_len, self.decoder_dim, device=device)
        final_hidden = torch.zeros(batch_size, self.decoder_dim, device=device)
        h = torch.zeros(batch_size, self.decoder_dim, device=device)
        c = torch.zeros(batch_size, self.decoder_dim, device=device)

        for t in range(max_len):
            # Sequences are sorted, so the ones still active at step t form a prefix of the batch.
            batch_size_t = int((sorted_lengths > t).sum())
            h, c = self.lstm(embeddings[:batch_size_t, t, :], (h[:batch_size_t], c[:batch_size_t]))
            hidden_states[:batch_size_t, t, :] = h
            memory_states[:batch_size_t, t, :] = c
            # Overwritten every step, so each row ends up holding its last valid hidden state.
            final_hidden[:batch_size_t] = h

        # hidden_states and memory_states are now padded tensors, like the output of pad_packed_sequence,
        # with rows in sorted order. The mask marks real (non-padded) steps and is built from the lengths.
        mask = (torch.arange(max_len, device=sorted_lengths.device).unsqueeze(0)
                < sorted_lengths.unsqueeze(1)).float()
        return hidden_states, memory_states, final_hidden, mask

packed_memory = PackedMemory(13, 512, 512)
hidden_states, memory_states, final_hidden, mask = packed_memory(sequences, seq_lengths)

Inside forward, the sequences are sorted by length (longest first), and every returned tensor follows this order:

tensor([[ 8,  9, 10, 11, 12,  0],
        [ 3,  4,  5,  0,  0,  0],
        [ 1,  2,  0,  0,  0,  0],
        [ 5,  6,  0,  0,  0,  0]])
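Because the outputs come back in sorted order, you may want to map them back to the original batch order. A minimal sketch: the inverse of the sort permutation is its argsort. Here outputs_sorted is a hypothetical stand-in for hidden_states or memory_states, with each row filled with its sequence length so the reordering is easy to check:

```python
import torch

seq_lengths = torch.LongTensor([2, 3, 2, 5])
# stable=True keeps tied lengths in their original relative order.
sorted_lengths, sort_indices = torch.sort(seq_lengths, descending=True, stable=True)
print(sort_indices)  # tensor([3, 1, 0, 2])

# Hypothetical stand-in for a (batch, time, dim) output in sorted order:
# every entry of row b equals that row's sequence length.
outputs_sorted = sorted_lengths.float().view(4, 1, 1).expand(4, 5, 3)

# argsort of a permutation gives its inverse permutation.
unsort_indices = torch.argsort(sort_indices)
outputs = outputs_sorted[unsort_indices]
print(outputs[:, 0, 0])  # tensor([2., 3., 2., 5.])
```

After unsorting, row i again corresponds to sequences[i].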

The mask:

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.]])
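A mask derived from memory_states.sum(2) != 0 assumes a valid step never has a cell state that sums exactly to zero; building it from the lengths avoids relying on that. A minimal sketch that reproduces the mask above and, as an illustrative use that is not from the thread, averages the memory states over valid steps only (the random memory_states are placeholders):

```python
import torch

# Lengths in the same sorted order as the output rows.
sorted_lengths = torch.LongTensor([5, 3, 2, 2])
max_len = int(sorted_lengths.max())

# mask[b, t] = 1.0 when t < length of sequence b, else 0.0
mask = (torch.arange(max_len).unsqueeze(0) < sorted_lengths.unsqueeze(1)).float()
print(mask)  # same values as the mask printed above

# Placeholder memory states, shape (batch, time, decoder_dim).
memory_states = torch.randn(4, max_len, 512)

# Zero out padded steps, sum over time, divide by each true length.
summed = (memory_states * mask.unsqueeze(2)).sum(dim=1)
mean_memory = summed / sorted_lengths.unsqueeze(1).float()
print(mean_memory.shape)  # torch.Size([4, 512])
```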

Both hidden_states.shape and memory_states.shape are torch.Size([4, 5, 512]), i.e. (batch_size, longest sequence length, decoder_dim).
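To sanity-check that a shrinking-batch nn.LSTMCell loop computes the same thing as nn.LSTM with packing, you can copy the cell's weights into a single-layer nn.LSTM (both use the same stacked input/forget/cell/output gate layout) and compare. This is a sketch; the sizes, seed, and random inputs are illustrative assumptions. It also shows that the memory state at each sequence's last valid step matches the c_n that nn.LSTM returns:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
embed_dim, hidden_size, batch_size, max_len = 8, 16, 4, 5

cell = nn.LSTMCell(embed_dim, hidden_size)
lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)

# Share parameters so both paths compute the same function.
with torch.no_grad():
    lstm.weight_ih_l0.copy_(cell.weight_ih)
    lstm.weight_hh_l0.copy_(cell.weight_hh)
    lstm.bias_ih_l0.copy_(cell.bias_ih)
    lstm.bias_hh_l0.copy_(cell.bias_hh)

# Random inputs, already sorted by length (longest first).
lengths = torch.tensor([5, 3, 2, 2])
x = torch.randn(batch_size, max_len, embed_dim)

hidden_states = torch.zeros(batch_size, max_len, hidden_size)
memory_states = torch.zeros(batch_size, max_len, hidden_size)
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)

with torch.no_grad():
    # Manual loop: only the first b sequences are still active at step t.
    for t in range(max_len):
        b = int((lengths > t).sum())
        h, c = cell(x[:b, t], (h[:b], c[:b]))
        hidden_states[:b, t] = h
        memory_states[:b, t] = c

    # Built-in path: pack, run nn.LSTM, unpack back to a padded tensor.
    packed = pack_padded_sequence(x, lengths, batch_first=True)
    packed_out, (h_n, c_n) = lstm(packed)
    padded_out, _ = pad_packed_sequence(packed_out, batch_first=True)

# Memory state at each sequence's last valid step, one row per sequence.
last_c = memory_states[torch.arange(batch_size), lengths - 1]
print(torch.allclose(hidden_states, padded_out, atol=1e-5))
print(torch.allclose(last_c, c_n[0], atol=1e-5))
```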

Hope this helps!

Thanks a ton for this beautiful code! Appreciate it.