Hi @Lewis,
You can’t really get this out of nn.LSTM directly, mainly because its output only contains the hidden state at every timestep (plus the final (h_n, c_n)); the per-timestep cell/memory states are never returned. However, you can implement it yourself using nn.LSTMCell and reproduce the pack-padded-sequence behaviour manually, as in the code I’ve used in one of my very recent projects (shown after the short aside below).
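Just to make the limitation concrete, here’s a minimal sketch (toy sizes of my own, and it uses enforce_sorted=False, which needs a reasonably recent PyTorch) of what nn.LSTM gives you for a packed batch: hidden states for every timestep, but only the final cell state:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 6, 8)                  # (batch, padded_len, features)
lengths = torch.LongTensor([2, 3, 2, 5])

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, _ = pad_packed_sequence(packed_out, batch_first=True)

print(out.shape)   # torch.Size([4, 5, 16]) -- hidden state at every timestep
print(h_n.shape)   # torch.Size([1, 4, 16]) -- final hidden state only
print(c_n.shape)   # torch.Size([1, 4, 16]) -- final cell state only, nothing per timestep

And here is the nn.LSTMCell-based workaround itself: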
import torch
from torch import nn

sequences = torch.LongTensor([[1, 2, 0, 0, 0, 0],
                              [3, 4, 5, 0, 0, 0],
                              [5, 6, 0, 0, 0, 0],
                              [8, 9, 10, 11, 12, 0]])
seq_lengths = torch.LongTensor([2, 3, 2, 5])
class PackedMemory(nn.Module):
    def __init__(self, vocab_size, embed_dim, decoder_dim):
        super(PackedMemory, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTMCell(embed_dim, decoder_dim)
        self.decoder_dim = decoder_dim

    def forward(self, seq, seq_len):
        batch_size = seq.size(0)

        # Sort the batch by decreasing length, the same way pack_padded_sequence expects
        sorted_lengths, sort_indices = torch.sort(seq_len, descending=True)
        sorted_sequences = seq[sort_indices]
        max_len = sorted_lengths.max().item()

        hidden_states = torch.zeros(batch_size, max_len, self.decoder_dim)
        memory_states = torch.zeros(batch_size, max_len, self.decoder_dim)
        final_hidden = torch.zeros(batch_size, self.decoder_dim)

        embeddings = self.embedding(sorted_sequences)
        h = torch.zeros(batch_size, self.decoder_dim)
        c = torch.zeros(batch_size, self.decoder_dim)

        for t in range(max_len):
            # Only step the cell for sequences that are still running at timestep t
            batch_size_t = (sorted_lengths > t).sum().item()
            h, c = self.lstm(embeddings[:batch_size_t, t, :],
                             (h[:batch_size_t], c[:batch_size_t]))
            hidden_states[:batch_size_t, t, :] = h
            memory_states[:batch_size_t, t, :] = c
            final_hidden[:batch_size_t] = h.clone()

        # Now you have tensors just like you would get from pad_packed_sequence.
        # As an additional step, you can build a mask that marks the real (non-padded)
        # timesteps of sequences shorter than max_len
        mask = ((memory_states.sum(2)) != 0).float()
        return hidden_states, memory_states, final_hidden, mask

packed_memory = PackedMemory(13, 512, 512)
hidden_states, memory_states, final_hidden, mask = packed_memory(sequences, seq_lengths)
Now, if you look at the sorted sequences, you'll see:
tensor([[ 8,  9, 10, 11, 12,  0],
        [ 3,  4,  5,  0,  0,  0],
        [ 1,  2,  0,  0,  0,  0],
        [ 5,  6,  0,  0,  0,  0]])
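One thing to keep in mind: everything the module returns stays in this length-sorted order. If you need the outputs back in the original batch order, a small sketch (not part of the code above, and it assumes re-running the same torch.sort call reproduces the same permutation) is to invert the sort indices:

# Re-derive the permutation used inside forward(), then invert it
_, sort_indices = torch.sort(seq_lengths, descending=True)
_, unsort_indices = torch.sort(sort_indices)

# Index the outputs back into the original batch order
hidden_states_orig = hidden_states[unsort_indices]
memory_states_orig = memory_states[unsort_indices]
final_hidden_orig = final_hidden[unsort_indices]
mask_orig = mask[unsort_indices]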
The mask:
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 0., 0., 0.]])
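As one purely illustrative use of the mask (not something from my project), you could mean-pool the hidden states over the valid timesteps only, ignoring the padding:

# Zero out the padded positions, then divide by the true sequence lengths
masked_hidden = hidden_states * mask.unsqueeze(2)                        # (4, 5, 512)
mean_hidden = masked_hidden.sum(dim=1) / mask.sum(dim=1, keepdim=True)   # (4, 512)

Here mask.sum(dim=1) equals each sequence's length (in the sorted order), so the padded timesteps don't drag the average down.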
And if you print the shape of hidden_states and memory_states, you'll get:
torch.Size([4, 5, 512])
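As a quick sanity check (again just a sketch on the toy batch above), final_hidden should match the hidden state at each sequence's last valid timestep, in the same sorted order:

sorted_lengths, _ = torch.sort(seq_lengths, descending=True)
last_steps = hidden_states[torch.arange(hidden_states.size(0)), sorted_lengths - 1]
print(torch.allclose(last_steps, final_hidden))  # True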
Hope this helps!