with batch_first=True. the packed input versioon still outputs result with shape (T * B * dim)
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
def sort_batch(data, seq_len):
batch_size = data.size(0)
sorted_seq_len, sorted_idx = torch.sort(seq_len, dim=0, descending=True)
sorted_data = data[sorted_idx]
_, reverse_idx = torch.sort(sorted_idx, dim=0, descending=False)
return sorted_data, sorted_seq_len, reverse_idx
data = torch.rand(4,7,10)
lens = torch.LongTensor([3,7,2,1])
s_data, s_len, reverse_idx = sort_batch(data, lens)
emb = pack(Variable(s_data), list(s_len),batch_first=True)
rnn = nn.LSTM(10, 20, 2, batch_first =True, bidirectional=False)
input = emb
h0 = Variable(torch.randn(2, 6, 20))
c0 = Variable(torch.randn(2, 6, 20))
packed_output, hn = rnn(input)
result, lens = unpack(packed_output)
print(result.size())
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
data = Variable(torch.rand(4,7,10))
lens = torch.LongTensor([3,7,2,1])
#s_data, s_len, reverse_idx = sort_batch(data, lens)
rnn = nn.LSTM(10, 20, 2, batch_first =True, bidirectional=False)
h0 = Variable(torch.randn(2, 6, 20))
c0 = Variable(torch.randn(2, 6, 20))
output, hn = rnn(data)
#result, lens = unpack(packed_output)
print(output.size())