In the tutorial "Translation with a Sequence to Sequence Network and Attention", the author trains the seq2seq model with a batch size of 1.
I want to train a seq2seq model with a batch size larger than 1.
To handle input sequences of different lengths, we can use a PackedSequence as the input.
However, I found that a PackedSequence can only be used as the input to RNN modules.
Some layers, such as nn.Embedding in the encoder and decoder, do not accept a PackedSequence as input.
I think the Variable could be packed into a PackedSequence just before it enters the RNN layer, and torch.nn.utils.rnn.pad_packed_sequence could then be used to pad the output of the RNN layer.
But here’s the problem: because of the attention mechanism, the decoder has to process the sequence word by word. That means I would need to split the PackedSequence into smaller PackedSequences of size (1, batch_size, *), which is much less intuitive.
Does anyone have an idea how to solve this problem?
Thanks for answering!
There’s a work-in-progress batched version of the tutorial coming up. The easiest way I could find was to do the packing and unpacking within the encoder:
    import torch
    import torch.nn as nn

    class EncoderRNN(nn.Module):  # class name assumed; the header line is missing from the post
        def __init__(self, input_size, hidden_size, n_layers=1, dropout=0.1):
            super().__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.n_layers = n_layers
            self.dropout = dropout
            self.embedding = nn.Embedding(input_size, hidden_size)
            self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)

        def forward(self, input_seqs, input_lengths, hidden=None):
            # input_seqs: (max_len, batch) padded token indices, sorted by decreasing length
            # Note: we run this all at once (over the whole batch of sequences), not word by word
            embedded = self.embedding(input_seqs)
            packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
            outputs, hidden = self.gru(packed, hidden)
            outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs)  # unpack (back to padded)
            outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]  # sum bidirectional outputs
            return outputs, hidden
It would be nice if other layers had support for PackedSequence…
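For anyone who wants to see the pack/unpack round trip in isolation, here is a minimal sketch. The vocabulary size, the padding index PAD = 0 and the toy batch are made up for illustration; only the torch.nn.utils.rnn calls come from the encoder above.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

PAD = 0                          # hypothetical padding index
vocab_size, hidden_size = 10, 4  # toy sizes, chosen only for the demo

# Two sequences of lengths 3 and 2, padded to (max_len=3, batch=2) and
# sorted by decreasing length, which pack_padded_sequence expects by default.
input_seqs = torch.tensor([[5, 7],
                           [3, 2],
                           [8, PAD]])
input_lengths = [3, 2]

embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=PAD)
gru = nn.GRU(hidden_size, hidden_size)

embedded = embedding(input_seqs)                        # ordinary padded tensor, (3, 2, 4)
packed = pack_padded_sequence(embedded, input_lengths)  # pack only right before the RNN
packed_out, hidden = gru(packed)
outputs, output_lengths = pad_packed_sequence(packed_out)  # back to a padded tensor

print(outputs.shape)    # torch.Size([3, 2, 4])
print(output_lengths)   # tensor([3, 2])
print(outputs[2, 1])    # zeros: step 2 of the shorter sequence is padding
```

So the embedding layer never sees a PackedSequence, and the final hidden state of the shorter sequence is taken at its last real step rather than at the padding.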
@spro thanks for your answer!
But I have a couple more questions:
- The first output of an encoder with a bidirectional RNN should depend on the first word and on the whole input sequence in reverse order. However, if the encoder input is fed word by word, then by the time the encoder produces its first output it has only seen the first word of the input sequence. The other words haven’t been fed into the encoder yet. Am I misunderstanding something?
- Because the RNN is bidirectional, with n_layers=1 the last hidden state has shape (2, batch_size, hidden_size). But the RNN in the decoder is not bidirectional, so the hidden state fed into the decoder should have shape (1, batch_size, hidden_size). So what should be done with the encoder's hidden state? Should it be
hidden = hidden[0, :, :] + hidden[1, :, :]
- In the encoder above the sequence is run through all at once, so the bidirectional part works.
- I have found that passing just the forward half of the hidden state works well:
decoder_hidden = encoder_hidden[:decoder.n_layers]
There’s a (not totally finished) version of the batched seq2seq model here: https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb
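To illustrate the first point, here is a small check (a sketch with random inputs and a made-up hidden size, not code from the notebook): when the whole sequence is passed to a bidirectional GRU in one call, changing only the last time step leaves the forward half of the first output untouched but changes its backward half.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 4  # toy size for the demo
gru = nn.GRU(hidden_size, hidden_size, bidirectional=True)

seq_a = torch.randn(3, 1, hidden_size)  # (seq_len, batch, features)
seq_b = seq_a.clone()
seq_b[-1] += 1.0                        # change only the last time step

with torch.no_grad():
    out_a, _ = gru(seq_a)
    out_b, _ = gru(seq_b)

# Output at time step 0: first half is the forward direction, second half the backward one.
fwd_same = torch.allclose(out_a[0, :, :hidden_size], out_b[0, :, :hidden_size])
bwd_same = torch.allclose(out_a[0, :, hidden_size:], out_b[0, :, hidden_size:])
print(fwd_same, bwd_same)
```

The forward half at step 0 has only seen the first word, while the backward half has already read the whole sequence from the end.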
@spro Thanks for sharing the code.
I noticed that in your current implementation of the Attn class, all sentences in the same batch have the same length for the attention scores (the attn_energies variable in your code). For example, a sentence of length 3 will have an attention score vector of length 20 if it is in the same batch as a 20-word sentence. This should not be a problem for training the model, but I think it might be troublesome when visualizing the attention, since the attention scores over the real words of a short sentence won’t add up to one. Maybe pass the input_lengths variable to Attn?
As you said, it would be great if other layers (nn.Softmax in this case) had support for PackedSequence.
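Something like the following is what I have in mind. It is only a sketch: masked_softmax is a hypothetical helper, not part of the notebook, and it assumes attn_energies has shape (batch, max_len) and that every length is at least 1 (an all-masked row would give NaN).

```python
import torch
import torch.nn.functional as F

def masked_softmax(attn_energies, input_lengths):
    """Softmax over the real (non-padded) positions only.

    attn_energies: (batch, max_len) raw attention scores
    input_lengths: (batch,) tensor of true sequence lengths, each >= 1
    """
    max_len = attn_energies.size(1)
    positions = torch.arange(max_len, device=attn_energies.device)
    mask = positions.unsqueeze(0) < input_lengths.unsqueeze(1)  # True at real positions
    # Padded positions get -inf, so they receive exactly zero weight after softmax.
    masked = attn_energies.masked_fill(~mask, float("-inf"))
    return F.softmax(masked, dim=1)

energies = torch.randn(2, 5)
lengths = torch.tensor([5, 3])
weights = masked_softmax(energies, lengths)
print(weights.sum(dim=1))  # each row sums to 1
print(weights[1, 3:])      # padded positions of the shorter sentence are 0
```

With this, the weights plotted for a short sentence sum to one over its own words, regardless of how long the other sentences in the batch are.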
As your answer here says, the hidden states of a bidirectional RNN alternate between the layers.
Is decoder_hidden = encoder_hidden[:decoder.n_layers] really the forward half of the hidden state?
I'm just getting more confused.
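A quick way to check this is to compare the final hidden state against the top-layer outputs. The sketch below uses random inputs and made-up sizes, and assumes the decoder has the same number of layers as the encoder. It relies on the layout described in the nn.GRU documentation, where h_n has shape (num_layers * num_directions, batch, hidden_size) and the directions alternate within each layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_layers, batch_size, hidden_size, seq_len = 2, 3, 4, 5  # toy sizes
gru = nn.GRU(hidden_size, hidden_size, n_layers, bidirectional=True)

with torch.no_grad():
    outputs, hidden = gru(torch.randn(seq_len, batch_size, hidden_size))

# hidden: (n_layers * 2, batch, hidden) -> (n_layers, 2, batch, hidden)
by_layer = hidden.reshape(n_layers, 2, batch_size, hidden_size)
forward_hidden = by_layer[:, 0]   # forward direction of every layer
backward_hidden = by_layer[:, 1]  # backward direction of every layer

# Sanity check against the top-layer outputs: the forward state is the last
# time step, the backward state is the first time step.
top_fwd_ok = torch.allclose(forward_hidden[-1], outputs[-1, :, :hidden_size])
top_bwd_ok = torch.allclose(backward_hidden[-1], outputs[0, :, hidden_size:])
print(top_fwd_ok, top_bwd_ok)

# With n_layers = 2, hidden[:n_layers] is layer 0 forward plus layer 0 backward,
# not the forward states of both layers.
naive = hidden[:n_layers]
print(torch.equal(naive[1], backward_hidden[0]))

# One alternative: sum the two directions per layer, giving (n_layers, batch, hidden).
summed = forward_hidden + backward_hidden
```

So under these assumptions the slice only coincides with the forward half when n_layers is 1. For deeper encoders, indexing by direction as above (or summing the two directions, as suggested earlier in the thread) keeps one state per layer.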