How to obtain memory states from a packed padded sequence

Hi @Lewis,
You can’t really do this with nn.LSTM. For every time step it only returns the hidden state h_t (of the last layer) as its output, and the cell (memory) state is only returned for the final step, as c_n. However, you can implement it using nn.LSTMCell and emulate the packed-sequence behaviour manually. I’ve used this code in one of my very recent projects:

import torch
from torch import nn

sequences = torch.LongTensor([[1, 2, 0, 0, 0, 0], 
                              [3, 4, 5, 0, 0, 0],  
                              [5, 6, 0, 0, 0, 0],  
                              [8, 9, 10, 11, 12, 0]])  
seq_lengths = torch.LongTensor([2, 3, 2, 5])

class PackedMemory(nn.Module):
    def __init__(self, vocab_size, embed_dim, decoder_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTMCell(embed_dim, decoder_dim)
        self.decoder_dim = decoder_dim

    def forward(self, seq, seq_len):
        batch_size = seq.size(0)
        device = seq.device
        # Sort by length (longest first) so that at every time step the
        # still-active sequences form a contiguous block at the top of the batch
        sorted_lengths, sort_indices = torch.sort(seq_len, descending=True)
        sorted_sequences = seq[sort_indices]
        max_len = int(sorted_lengths[0])

        hidden_states = torch.zeros(batch_size, max_len, self.decoder_dim, device=device)
        memory_states = torch.zeros(batch_size, max_len, self.decoder_dim, device=device)
        final_hidden = torch.zeros(batch_size, self.decoder_dim, device=device)
        embeddings = self.embedding(sorted_sequences)
        h = torch.zeros(batch_size, self.decoder_dim, device=device)
        c = torch.zeros(batch_size, self.decoder_dim, device=device)

        for t in range(max_len):
            # Number of sequences that still have a real token at step t
            batch_size_t = int((sorted_lengths > t).sum())
            h, c = self.lstm(embeddings[:batch_size_t, t, :], (h[:batch_size_t], c[:batch_size_t]))
            hidden_states[:batch_size_t, t, :] = h
            memory_states[:batch_size_t, t, :] = c
            # Overwritten each step, so each row ends up with h from its last real step
            final_hidden[:batch_size_t] = h

        # Now you have padded tensors just like pad_packed_sequence(..., batch_first=True)
        # would give you, but in sorted order (longest sequence first).
        # As an additional step, build a mask of real (non-padded) positions from the lengths
        mask = (torch.arange(max_len, device=device)[None, :] < sorted_lengths[:, None]).float()
        return hidden_states, memory_states, final_hidden, mask

packed_memory = PackedMemory(13, 512, 512)  # vocab_size, embed_dim, decoder_dim
hidden_states, memory_states, final_hidden, mask = packed_memory(sequences, seq_lengths)
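As a sanity check, here is a minimal sketch that confirms the sorted LSTMCell loop gives the same numbers as nn.LSTM on a packed batch: copy the cell's weights into a single-layer nn.LSTM, then compare the per-step hidden states with what pad_packed_sequence returns. Since nn.LSTM only hands back c_n (the cell state at each sequence's last real step), that is all the memory states can be checked against. The helper name unrolled_states is hypothetical, and the random x is just a stand-in for the embeddings.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def unrolled_states(cell, x, lengths):
    """Run an nn.LSTMCell over a padded batch already sorted by descending
    length, keeping h and c for every step (hypothetical helper mirroring
    the loop in PackedMemory.forward)."""
    batch, max_len, _ = x.shape
    hs = torch.zeros(batch, max_len, cell.hidden_size)
    cs = torch.zeros(batch, max_len, cell.hidden_size)
    h = torch.zeros(batch, cell.hidden_size)
    c = torch.zeros(batch, cell.hidden_size)
    for t in range(max_len):
        b = int((lengths > t).sum())  # sequences still active at step t
        h, c = cell(x[:b, t], (h[:b], c[:b]))
        hs[:b, t] = h
        cs[:b, t] = c
    return hs, cs


torch.manual_seed(0)
embed_dim, hidden_dim = 8, 16
lengths = torch.tensor([5, 3, 2, 2])  # sorted, longest first
x = torch.randn(4, 5, embed_dim)      # stand-in for the embeddings

cell = nn.LSTMCell(embed_dim, hidden_dim)
lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
with torch.no_grad():  # give both modules identical weights
    lstm.weight_ih_l0.copy_(cell.weight_ih)
    lstm.weight_hh_l0.copy_(cell.weight_hh)
    lstm.bias_ih_l0.copy_(cell.bias_ih)
    lstm.bias_hh_l0.copy_(cell.bias_hh)

hs, cs = unrolled_states(cell, x, lengths)
packed_out, (h_n, c_n) = lstm(pack_padded_sequence(x, lengths, batch_first=True))
ref, _ = pad_packed_sequence(packed_out, batch_first=True)

# Per-step hidden states should match everywhere (padding is 0 in both)
print(torch.allclose(hs, ref, atol=1e-5))
# nn.LSTM only exposes c at each sequence's last real step
last_c = cs[torch.arange(4), lengths - 1]
print(torch.allclose(last_c, c_n[0], atol=1e-5))
```

Both checks printing True is what you would hope to see; cs additionally holds the intermediate cell states that nn.LSTM never returns.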

Now if you print sorted_sequences inside forward, you get:

tensor([[ 8,  9, 10, 11, 12,  0],
        [ 3,  4,  5,  0,  0,  0],
        [ 1,  2,  0,  0,  0,  0],
        [ 5,  6,  0,  0,  0,  0]])
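Note that everything forward returns is in this sorted order, not the order you passed in. If you need the original order back, one minimal sketch is to invert the sort permutation with torch.argsort and index with it; the same index (called unsort_indices here, a name chosen for this example) can be applied to hidden_states, memory_states, final_hidden and mask. Rows with equal lengths may come out of torch.sort in either order, but the inversion is exact either way.

```python
import torch

sequences = torch.tensor([[1, 2, 0, 0, 0, 0],
                          [3, 4, 5, 0, 0, 0],
                          [5, 6, 0, 0, 0, 0],
                          [8, 9, 10, 11, 12, 0]])
seq_lengths = torch.tensor([2, 3, 2, 5])

sorted_lengths, sort_indices = torch.sort(seq_lengths, descending=True)
sorted_sequences = sequences[sort_indices]

# argsort of a permutation is its inverse: original row i sits at
# position unsort_indices[i] of the sorted batch
unsort_indices = torch.argsort(sort_indices)
restored = sorted_sequences[unsort_indices]
print(torch.equal(restored, sequences))  # True
```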

The mask:

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 0., 0., 0.]])
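A mask like this is mainly useful for ignoring padded steps later on. The sketch below, assuming you have the sorted lengths at hand, builds the mask from the lengths alone and uses it for a masked mean over time; the random states tensor is only a stand-in for memory_states.

```python
import torch

torch.manual_seed(0)
sorted_lengths = torch.tensor([5, 3, 2, 2])
max_len = int(sorted_lengths.max())

# 1 where step t is a real token of that sequence, 0 where it is padding
mask = (torch.arange(max_len)[None, :] < sorted_lengths[:, None]).float()
print(mask)

states = torch.randn(4, max_len, 512)  # stand-in for memory_states
# Zero out padded steps, then divide each row by its own length, not max_len
summed = (states * mask.unsqueeze(-1)).sum(dim=1)
mean_states = summed / mask.sum(dim=1, keepdim=True)
print(mean_states.shape)  # torch.Size([4, 512])
```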

And if you print the shapes of hidden_states and memory_states, you’ll get the following for both (batch size 4, longest sequence length 5, decoder_dim 512):
torch.Size([4, 5, 512])
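If what you are after is the final memory state of every sequence (the c counterpart of final_hidden), you can pull it out of memory_states by indexing each row at position length - 1. A minimal sketch, with a random tensor of the shape above standing in for the real memory_states and sorted_lengths assumed to be in the same sorted order:

```python
import torch

torch.manual_seed(0)
sorted_lengths = torch.tensor([5, 3, 2, 2])
memory_states = torch.randn(4, 5, 512)  # stand-in with the shape printed above

# Row i, time step sorted_lengths[i] - 1: the last real step of each sequence
batch_idx = torch.arange(memory_states.size(0))
final_memory = memory_states[batch_idx, sorted_lengths - 1]
print(final_memory.shape)  # torch.Size([4, 512])
```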

Hope this helps!