Hey all,
I noticed that using that the output of my LSTM suffers when I use pack_padded_sequence and pad_packed_sequence respectively, which is weird to since it should not have any impact, or?
I attach a code example showing the dot product of a matrix of word embeddings with and without pack padded sequence.
This is what I expect:
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
documents = torch.tensor([[299, 300, 301, 302, 303, 301, 304, 305, 306, 301, 302, 307, 308, 303,
301, 304, 305, 306, 309, 307, 310, 311, 312, 313, 314, 309, 315, 316,
317, 318, 195, 299, 319, 310, 311, 320, 221, 321, 322, 310, 311, 323,
100, 310, 311, 324, 100, 310, 311, 325, 326, 100, 39, 327, 328, 321,
329, 330, 331, 332, 150, 333, 211, 334, 310, 311, 325, 310, 311, 312,
309, 333, 335, 131, 219, 326, 100, 324, 100, 331, 336, 337, 299, 338,
323, 100, 331, 339, 310, 311, 325, 334, 340, 341, 323, 100, 324, 100,
331, 310]])
lengths = torch.tensor([100])
lstm = nn.LSTM(300, 300, 2, bidirectional=True)
x = word_embedding(documents)
x, _ = lstm(x)
sns.heatmap(torch.matmul(x, x.transpose(-1, -2)).detach()[0])
And this is what I get when using a packed sequence:
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
documents = torch.tensor([[299, 300, 301, 302, 303, 301, 304, 305, 306, 301, 302, 307, 308, 303,
301, 304, 305, 306, 309, 307, 310, 311, 312, 313, 314, 309, 315, 316,
317, 318, 195, 299, 319, 310, 311, 320, 221, 321, 322, 310, 311, 323,
100, 310, 311, 324, 100, 310, 311, 325, 326, 100, 39, 327, 328, 321,
329, 330, 331, 332, 150, 333, 211, 334, 310, 311, 325, 310, 311, 312,
309, 333, 335, 131, 219, 326, 100, 324, 100, 331, 336, 337, 299, 338,
323, 100, 331, 339, 310, 311, 325, 334, 340, 341, 323, 100, 324, 100,
331, 310]])
lengths = torch.tensor([100])
lstm = nn.LSTM(300, 300, 2, bidirectional=True)
x = word_embedding(documents)
x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
x, _ = lstm(x)
x, _ = pad_packed_sequence(x, batch_first=True)
sns.heatmap(torch.matmul(x, x.transpose(-1, -2)).detach()[0])
Appreciate any help!