Inconsistent behavior of RNN module with packed input

with batch_first=True. the packed input versioon still outputs result with shape (T * B * dim)

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack

def sort_batch(data, seq_len):
    batch_size = data.size(0)
    sorted_seq_len, sorted_idx  = torch.sort(seq_len, dim=0, descending=True)
    sorted_data = data[sorted_idx]
    _, reverse_idx  = torch.sort(sorted_idx, dim=0, descending=False)
    return sorted_data, sorted_seq_len, reverse_idx

data = torch.rand(4,7,10)
lens = torch.LongTensor([3,7,2,1])

s_data, s_len, reverse_idx = sort_batch(data, lens)

emb = pack(Variable(s_data), list(s_len),batch_first=True)

rnn = nn.LSTM(10, 20, 2, batch_first =True, bidirectional=False)
input = emb
h0 = Variable(torch.randn(2, 6, 20))
c0 = Variable(torch.randn(2, 6, 20))
packed_output, hn = rnn(input)
result, lens = unpack(packed_output)
print(result.size())

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack

data = Variable(torch.rand(4,7,10))
lens = torch.LongTensor([3,7,2,1])

#s_data, s_len, reverse_idx = sort_batch(data, lens)

rnn = nn.LSTM(10, 20, 2, batch_first =True, bidirectional=False)

h0 = Variable(torch.randn(2, 6, 20))
c0 = Variable(torch.randn(2, 6, 20))
output, hn = rnn(data)

#result, lens = unpack(packed_output)
print(output.size())

You need to provide batch_first = True to unpack as well.

1 Like

Doesn’t pack_padded_sequence store the representation as T x B x dim?

So calling the lstm should be with batch_first=False?

Similarly with the pad_packed_sequence. No?