Clarification regarding the return values of nn.GRU

The PyTorch tutorial on seq2seq translation says the following about the decoder:

In the simplest seq2seq decoder we use only last output of the encoder. This last output is sometimes called the context vector as it encodes context from the entire sequence. This context vector is used as the initial hidden state of the decoder.

The documentation of nn.GRU says that it returns: output, h_n

Which of these (output or h_n) is the context vector?

I am using a GRU to classify sentences into two classes, say positive and negative. Which of output and h_n is more useful for this classification?


Short answer: You can use either.

For a bit more understanding, first check the shapes of output and h_n:

  • output.shape = (seq_len, batch, num_directions * hidden_size) with the default batch_first=False, so you get a (context) vector for each item in your sequence
  • h_n.shape = (num_layers * num_directions, batch, hidden_size)
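A quick way to confirm these shapes is to run a small GRU on random input. This is a minimal sketch assuming the default batch_first=False; the sizes are arbitrary values picked for illustration, and bidirectional=True with two layers is used so both factors in each shape are visible.

```python
import torch
import torch.nn as nn

# Arbitrary sizes chosen only for illustration
seq_len, batch, input_size, hidden_size, num_layers = 7, 4, 10, 16, 2
num_directions = 2  # because bidirectional=True below

gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)  # (seq_len, batch, input_size) layout

output, h_n = gru(x)
print(output.shape)  # torch.Size([7, 4, 32]) = (seq_len, batch, num_directions * hidden_size)
print(h_n.shape)     # torch.Size([4, 4, 16]) = (num_layers * num_directions, batch, hidden_size)
```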

Now look at the following architecture (it’s for an LSTM; for a GRU you can simply ignore c_n):

In the basic case where num_layers=1, bidirectional=False, and you don’t use PackedSequence, output[-1] equals h_n[0] (h_n only has an extra leading dimension of size 1). In other words, h_n is the hidden state of the last layer after the last sequence item (bidirectional=True makes it more complicated).
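The equality can be checked directly. This sketch uses arbitrary sizes and random input; the second part is an extra check, not from the post above, showing that with several layers (still unidirectional, no packing) output only holds the top layer, which matches h_n[-1].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 3, 8, 12  # arbitrary sizes

# Basic case: num_layers=1, bidirectional=False, no PackedSequence
gru = nn.GRU(input_size, hidden_size)
x = torch.randn(seq_len, batch, input_size)
output, h_n = gru(x)

# output[-1]: (batch, hidden_size); h_n: (1, batch, hidden_size)
print(torch.allclose(output[-1], h_n[0]))  # True

# With several layers, output holds only the top layer, which is h_n[-1]
gru3 = nn.GRU(input_size, hidden_size, num_layers=3)
output3, h_n3 = gru3(x)
print(torch.allclose(output3[-1], h_n3[-1]))  # True
```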

For Seq2Seq, most people use h_n as the context vector, since the encoder and decoder often have symmetric architectures (e.g., the same number of layers), so they can simply pass h_n from the encoder to the decoder as its initial hidden state.
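A minimal sketch of that hand-off, assuming an encoder and decoder with the same num_layers and hidden_size. The Encoder and Decoder classes and all sizes are hypothetical names for illustration, and for brevity the decoder is fed the whole target sequence in one call (teacher forcing) rather than decoding one step at a time.

```python
import torch
import torch.nn as nn

# Hypothetical sizes; encoder and decoder share num_layers and hidden_size
vocab_src, vocab_tgt, emb, hidden, layers = 100, 120, 32, 64, 2

class Encoder(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_src, emb)
        self.gru = nn.GRU(emb, hidden, num_layers=layers)

    def forward(self, src):                # src: (src_len, batch) token ids
        _, h_n = self.gru(self.embed(src))
        return h_n                         # (layers, batch, hidden): the context

class Decoder(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_tgt, emb)
        self.gru = nn.GRU(emb, hidden, num_layers=layers)
        self.out = nn.Linear(hidden, vocab_tgt)

    def forward(self, tgt, h_0):           # tgt: (tgt_len, batch) token ids
        output, h_n = self.gru(self.embed(tgt), h_0)
        return self.out(output), h_n       # logits: (tgt_len, batch, vocab_tgt)

encoder, decoder = Encoder(), Decoder()
src = torch.randint(0, vocab_src, (9, 4))
tgt = torch.randint(0, vocab_tgt, (6, 4))

context = encoder(src)                     # the encoder's h_n ...
logits, _ = decoder(tgt, context)          # ... becomes the decoder's h_0
print(context.shape, logits.shape)  # torch.Size([2, 4, 64]) torch.Size([6, 4, 120])
```

Because both GRUs have the same number of layers and hidden size, h_n fits the decoder's h_0 argument without any reshaping; with asymmetric architectures you would need a projection or a different choice of layer.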

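For classification, I usually use h_n instead of output. The reason is, as far as I understand, that when you have batches of sequences with different lengths and padding and use PackedSequence, output[-1] != h_n: for every sequence shorter than the longest one, output[-1] is padding, while h_n holds the hidden state after that sequence’s actual last item. Using output[-1] still works, of course, but in my experience it gives lower accuracy.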
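The padding effect is easy to reproduce. This sketch uses an arbitrary batch of three sequences with lengths 5, 3 and 2 (sorted descending, as pack_padded_sequence expects by default) and a plain nn.Linear as a stand-in classifier; the sizes and the classifier are assumptions for illustration only.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
input_size, hidden_size, num_classes = 8, 16, 2   # arbitrary sizes
lengths = torch.tensor([5, 3, 2])                 # true lengths, sorted descending
max_len, batch = int(lengths.max()), len(lengths)

# Padded batch: positions beyond each sequence's length are zeros
x = torch.randn(max_len, batch, input_size)
for i, n in enumerate(lengths.tolist()):
    x[n:, i] = 0.0

gru = nn.GRU(input_size, hidden_size)
packed_out, h_n = gru(pack_padded_sequence(x, lengths))
output, _ = pad_packed_sequence(packed_out)       # (max_len, batch, hidden_size)

# output[-1] is right only for the longest sequence; the rest is zero padding
print(torch.allclose(output[-1, 0], h_n[0, 0]))   # True
print(output[-1, 1:].abs().sum().item())          # 0.0

# Gathering each sequence's real last step from output recovers h_n
last_steps = output[lengths - 1, torch.arange(batch)]
print(torch.allclose(last_steps, h_n[0]))         # True

# So a classifier can simply take h_n (top layer) as its input
classifier = nn.Linear(hidden_size, num_classes)
logits = classifier(h_n[-1])                      # (batch, num_classes)
```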


Hi @vdw,

Thank you very much for the amazing explanation and the link to your code. Everything is clear now. I tried both, and they gave more or less the same accuracy, with h_n being marginally better.

Thanks again


Great! I’m happy to help.

As an additional note: once you use bidirectional=True, you almost certainly want to use h_n, since the final output of the reverse direction is in output[0], not in output[-1]; see this older post of mine (with the edit). h_n gives you the final hidden states of both directions much more conveniently.
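Here is a small check of where each direction ends up, as a sketch with arbitrary sizes and a single layer. The forward half of output (first hidden_size features) finishes at output[-1], while the reverse half finishes at output[0]. With more layers, h_n can be viewed as (num_layers, num_directions, batch, hidden_size), so h_n[-2] and h_n[-1] are the top layer’s forward and backward states.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 6, 3, 8, 10  # arbitrary sizes

gru = nn.GRU(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)
output, h_n = gru(x)                         # output: (6, 3, 20), h_n: (2, 3, 10)

forward_out = output[:, :, :hidden_size]     # forward direction at every step
backward_out = output[:, :, hidden_size:]    # reverse direction at every step

# The forward direction ends at the last step, the reverse one at the first
print(torch.allclose(forward_out[-1], h_n[0]))   # True
print(torch.allclose(backward_out[0], h_n[1]))   # True

# Concatenating both final states gives one (batch, 2 * hidden_size) feature
features = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(features.shape)  # torch.Size([3, 20])
```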
