Question about Encoder

In The Annotated Encoder-Decoder, the encoder module is this:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
    """Encodes a sequence of word embeddings."""
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.rnn = nn.GRU(input_size, hidden_size, num_layers,
                          batch_first=True, bidirectional=True, dropout=dropout)

    def forward(self, x, mask, lengths):
        """
        Applies a bidirectional GRU to a sequence of embeddings x.
        The input mini-batch x needs to be sorted by length.
        x should have dimensions [batch, time, dim].
        """
        packed = pack_padded_sequence(x, lengths, batch_first=True)
        output, final = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)

        # we need to manually concatenate the final states for both directions
        fwd_final = final[0:final.size(0):2]
        bwd_final = final[1:final.size(0):2]
        final = torch.cat([fwd_final, bwd_final], dim=2)  # [num_layers, batch, 2*dim]

        return output, final

and I don’t understand the last comment in the forward method, the one about the final hidden states. What is going on there, and why does it need to be done manually?


Because the encoder is bidirectional. This means that one RNN runs in one direction across your inputs, and the other runs in the opposite direction. To represent each timestep as informed both by what comes before it and by what comes after it, we represent each state as the concatenation of the two RNNs’ states at that timestep.
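You can see this in the shapes PyTorch returns. A minimal sketch (the sizes here are made up for illustration): for a bidirectional GRU, the `final` hidden state stacks the two directions along dim 0, interleaved per layer as `[fwd_layer0, bwd_layer0, fwd_layer1, bwd_layer1, ...]`, which is why the slicing with a stride of 2 works.

```python
import torch
import torch.nn as nn

# hypothetical sizes, just to show the shapes
rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=2,
             batch_first=True, bidirectional=True)
x = torch.randn(4, 5, 8)  # [batch, time, dim]

output, final = rnn(x)

# output concatenates both directions per timestep: [batch, time, 2*hidden_size]
print(output.shape)  # torch.Size([4, 5, 32])
# final stacks directions along dim 0: [num_layers * 2, batch, hidden_size]
print(final.shape)   # torch.Size([4, 4, 16])
```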

e.g. if we have the phrase ‘the quick brown fox’, we encode it in both directions as:
[the, quick, brown, fox]
[fox, brown, quick, the]
where each state carries information from the words seen so far in its direction. Then, and to the point of your question, the states representing these words need to be concatenated pairwise. Because you are taking only the final states, you concatenate just those two RNN states, which together carry information about everything in the sequence going left-to-right AND right-to-left.


Thank you very much! 🙂
