I feed a GRU/LSTM a batch input of size max_seq_len * batch_size * embed_size,
together with a list of the actual lengths of each sequence.
I then get the outputs, of size max_seq_len * batch_size * hidden_size,
and the last hidden vector, of size layer_num * batch_size * hidden_size.
How can I get the output vector at the last actual (non-padded) time step of each sequence? The result should have size batch_size * hidden_size.
The test code for my question:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import numpy as np
# Toy experiment: run a padded, variable-length batch through a GRU and pick,
# for every sequence, the output at its last actual (non-padded) time step.

max_seq_len = 3
batch_size = 4
layer_num = 2
input_size = 10  # number of embeddings: token ids 0 - 9
emb_size = 8
hidden_size = 10

# test data: token-id sequences padded with 0, plus each sequence's real length
a = [[1, 2, 3], [4, 5, 0], [6, 0, 0], [7, 8, 9]]
lens = [3, 2, 1, 3]

# pack_padded_sequence is given the batch ordered by decreasing length, so
# sort the sequences together with their lengths.
pairs = sorted(zip(a, lens), key=lambda p: p[1], reverse=True)
(a, lens) = zip(*pairs)

# actual lengths, in the sorted batch order
lens = np.array(lens)
va = Variable(torch.LongTensor(a))
vlens = Variable(torch.LongTensor(lens))
print(va, vlens)

embedding = nn.Embedding(input_size, emb_size)
gru = nn.GRU(emb_size, hidden_size, layer_num)

# The GRU is built without batch_first, so move the time dimension to the front.
inputs = va.transpose(0, 1)
print("inputs size: ", inputs.size())  # max_seq_len * batch_size
inputs = embedding(inputs)
# Report the size of the embedded batch itself; there is no separate
# `embedded` variable in this script.
print("embedded size: ", inputs.size())  # max_seq_len * batch_size * emb_size
inputs = torch.nn.utils.rnn.pack_padded_sequence(inputs, lens)

# Initial hidden state sized from the configuration rather than literals.
h0 = Variable(torch.randn(layer_num, batch_size, hidden_size))
outputs, hn = gru(inputs, h0)

# Unpack the PackedSequence back into a zero-padded tensor.
tmp = torch.nn.utils.rnn.pad_packed_sequence(outputs)
outputs, output_lengths = tmp
# outputs: max_seq_len * batch_size * hidden_size
print("after padded:")
print("outputs.size: ", outputs.size())
print("hn size: ", hn.size())

# Select, per sequence, the output at time step (length - 1).
# index_select(0, ...) would return whole time steps for *every* batch element;
# gather along dim 0 with one index per (batch, hidden) position picks a
# different time step for each sequence instead.
idxs = (vlens - 1).view(1, -1, 1).expand(1, batch_size, hidden_size)
last_outputs = outputs.gather(0, idxs).squeeze(0)  # batch_size * hidden_size
print("selected outputs size: ", last_outputs.size())
# For this unidirectional GRU the same values are also available as hn[-1],
# the final hidden state of the top layer.