Missing or conflicting documentation between versions?

I don’t use PyTorch as often as I should, so I always need to consult the documentation. Now I’ve come across an issue that was well documented in previous versions but, as far as I can tell, not in the current one.

The question is about handling the last hidden state h_n of an nn.LSTM layer (the same applies to nn.GRU). The issue is that one dimension is the product of num_layers and num_directions. The documentation for PyTorch version 1.0.0 is pretty clear:

h_n of shape (num_layers*num_directions, batch, hidden_size): tensor containing the hidden state for t=seq_len. Like output, the layers can be separated using h_n.view(num_layers, num_directions, batch, hidden_size) and similarly for c_n.

It gives a concrete example of how to separate the num_layers and num_directions dimensions, and this is what I have always used in my implementations. However, the documentation for PyTorch version 1.1.3 reads as follows:

h_n: tensor of shape (D*num_layers, H_out) for unbatched input or (D*num_layers, N, H_out) containing the final hidden state for each element in the sequence. When bidirectional=True, h_n will contain a concatenation of the final forward and reverse hidden states, respectively.
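To see both shapes from the new docs side by side, here is a minimal sketch. The sizes are arbitrary, and it assumes a recent PyTorch version whose nn.GRU accepts unbatched input, as the quoted docs describe:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes: L = sequence length, N = batch size
L, N, H_in, H_out, num_layers = 5, 3, 4, 2, 2
D = 2  # bidirectional=True, so two directions

gru = nn.GRU(input_size=H_in, hidden_size=H_out,
             num_layers=num_layers, bidirectional=True)

# Batched input (batch_first=False): shape (L, N, H_in)
_, h_n_batched = gru(torch.rand(L, N, H_in))
print(h_n_batched.shape)    # torch.Size([4, 3, 2]) -> (D*num_layers, N, H_out)

# Unbatched input: shape (L, H_in)
_, h_n_unbatched = gru(torch.rand(L, H_in))
print(h_n_unbatched.shape)  # torch.Size([4, 2]) -> (D*num_layers, H_out)
```

Only the leading D*num_layers dimension matters for the question here; the unbatched case simply drops N.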

While mapping the names (D = num_directions, N = batch, H_out = hidden_size) is straightforward, the documentation no longer shows how to split D and num_layers. It’s tempting to adopt the old method:

h_n.view(num_layers, D, N, H_out)

but note that the order of D and num_layers is now flipped: (D*num_layers, …) vs. (num_layers*num_directions, …). Does this mean I now have to do:

h_n.view(D, num_layers, N, H_out)

I’m pretty sure the order does matter, but I can’t derive for certain which version is the correct one. Or what am I missing here?
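One way to convince yourself that the order matters: build a fake h_n whose entries are simply their own row index (a made-up tensor with hypothetical sizes, not real GRU output) and check which rows each view() treats as "both directions of the last layer". This only shows that the two readings disagree; it does not tell us which one PyTorch actually uses.

```python
import torch

# Hypothetical sizes, chosen small so the row indices are easy to follow
num_layers, D, N, H_out = 3, 2, 1, 1

# Fake h_n: each entry equals its row index 0..5 along the first dimension
h_n = torch.arange(num_layers * D, dtype=torch.float32).view(num_layers * D, N, H_out)

a = h_n.view(num_layers, D, N, H_out)   # reading 1: layer-major
b = h_n.view(D, num_layers, N, H_out)   # reading 2: direction-major

# "Both directions of the last layer" under each reading
print(a[-1].flatten().tolist())     # [4.0, 5.0]
print(b[:, -1].flatten().tolist())  # [2.0, 5.0]
```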

Since nobody replied and the question has come up again, I’ve checked it myself. In short, it has to be

h_n.view(num_layers, D, N, H_out)

Here’s a minimal example. First, let’s create an example batch as well as a nn.GRU layer, and push the batch through the layer.

import torch
import torch.nn as nn

torch.manual_seed(0)
torch.set_printoptions(precision=2)

# Define parameters
N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
# Create random batch
batch = torch.rand(N, L, H_in)
# Create GRU layer (IMPORTANT: batch_first=True in this example!)
gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers, bidirectional=True, batch_first=True)
# Push batch through GRU layer
output, h_n = gru(batch)
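Before indexing into anything, it may help to print the shapes. This sketch repeats the setup above. Note that, per the nn.GRU docs, batch_first=True only applies to the input and output tensors, not to the hidden state, so h_n keeps its (D*num_layers, N, H_out) layout:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same sizes as in the example above
N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
batch = torch.rand(N, L, H_in)
gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             bidirectional=True, batch_first=True)
output, h_n = gru(batch)

# batch_first=True applies to output: (N, L, D*H_out)
print(output.shape)  # torch.Size([2, 4, 6])
# ...but not to h_n, which stays (D*num_layers, N, H_out)
print(h_n.shape)     # torch.Size([20, 2, 3])
```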

Now let’s say we look at the hidden states of the first sequence item of the first sample with output[0][0]. This gives us:

[ 0.01, -0.02, -0.03, -0.27, 0.58, -0.45]

Here, [ 0.01, -0.02, -0.03] is the first hidden state of the forward direction and [ -0.27, 0.58, -0.45] is the last hidden state of the backward direction (the backward direction processes the sequence from the last item to the first, so it finishes at the first item). Similarly, we can look at the hidden states of the last sequence item of the first sample with output[0][-1]:

[ 0.02, 0.05, -0.10, -0.05, 0.29, -0.18]

Here, [ 0.02, 0.05, -0.10] is the last hidden state of the forward direction, and [ -0.05, 0.29, -0.18] is the first hidden state of the backward direction.
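The same split can be done for every time step at once by viewing the last dimension of output as (2, H_out). A minimal sketch reusing the setup above (forward_out and backward_out are just illustrative names):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
batch = torch.rand(N, L, H_in)
gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             bidirectional=True, batch_first=True)
output, h_n = gru(batch)

# output is (N, L, 2*H_out): forward half first, then backward half
out_dirs = output.view(N, L, 2, H_out)
forward_out = out_dirs[:, :, 0]   # (N, L, H_out)
backward_out = out_dirs[:, :, 1]  # (N, L, H_out)

# output[0][0] is just the forward and backward parts of step 0 side by side
print(forward_out[0, 0], backward_out[0, 0])
```

Since the forward half comes first, this is equivalent to slicing output[..., :H_out] and output[..., H_out:].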

Of course, in practice, we generally want the last hidden states with respect to the forward and backward directions. These are easily obtained from h_n:

h_n = h_n.view(num_layers, 2, N, H_out)
h_n[-1,:,0]

Here, the -1 selects the final layer (in case num_layers > 1), and the 0 again refers to the first sample. The output is:

[[ 0.02, 0.05, -0.10],
 [-0.27, 0.58, -0.45]]

which are the 2 respective last hidden states as we had already seen above using output.
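Instead of comparing printed numbers by eye, the same check can be written with torch.allclose. A minimal sketch reusing the setup above; features is a hypothetical name for a per-sample vector one might feed into, e.g., a classifier head:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
batch = torch.rand(N, L, H_in)
gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             bidirectional=True, batch_first=True)
output, h_n = gru(batch)

# Layer-major split, as concluded above
last_layer = h_n.view(num_layers, 2, N, H_out)[-1]  # (2, N, H_out)

# Forward direction finishes at the last time step, backward at the first
fwd_ok = torch.allclose(last_layer[0], output[:, -1, :H_out])
bwd_ok = torch.allclose(last_layer[1], output[:, 0, H_out:])
print(fwd_ok, bwd_ok)  # True True

# The direction-major reading picks a different row for the forward direction
wrong = h_n.view(2, num_layers, N, H_out)[:, -1]
print(torch.allclose(wrong[0], output[:, -1, :H_out]))

# One feature vector per sample: final forward and backward states concatenated
features = torch.cat([last_layer[0], last_layer[1]], dim=1)  # (N, 2*H_out)
```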

Again, I find it (a) odd that this information is now missing from the docs and (b) confusing that the documented shape of h_n has actually changed from

(num_layers * num_directions, batch, hidden_size)

to

(D * num_layers, N, H_out)

I would prefer to change the notation to (num_layers * D, N, H_out). This is in line with the old docs and better matches the required view() command, particularly since that command is now missing from the docs. Ideally, I think it should be there.

@ptrblck: Is there a way to get feedback on this? Maybe there is a very good reason for removing the view() command from the docs that I’m just missing.
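With the num_layers * D reading, the flat row index in h_n can be written as layer * D + direction. The sketch below (random stand-in tensor, hypothetical sizes) only restates what view() does; that real GRU/LSTM states follow this order is what the experiment with output above established:

```python
import torch

# Hypothetical sizes; h_n is a random stand-in with the documented shape
num_layers, D, N, H_out = 3, 2, 4, 5
h_n = torch.randn(num_layers * D, N, H_out)

split = h_n.view(num_layers, D, N, H_out)

# Row layer * D + d of h_n is layer `layer`, direction `d` (0 = forward, 1 = backward)
for layer in range(num_layers):
    for d in range(D):
        assert torch.equal(h_n[layer * D + d], split[layer, d])
print("row index = layer * D + direction")
```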

I’ve provided some clarifications based on my understanding. Please let me know if this makes sense to you. Also, thanks for really digging into this; I learned a lot from your diagram(s) above.

My interpretation: [ 0.01, -0.02, -0.03] is the first hidden state of the forward direction and [ -0.27, 0.58, -0.45] is the last hidden state of the backward direction for the first input sequence. Similarly, we can look at the hidden states of the last sequence item of the first sample with output[0][-1]:

My interpretation: [ 0.02, 0.05, -0.10] is the last hidden state of the forward direction, and [ -0.05, 0.29, -0.18] is the last hidden state of the backward direction. That is, h_n always holds the last hidden state irrespective of which direction you’re moving in. This is necessary so that you can feed the generated h_n into the next forward() call of the LSTM layer in case you wish to maintain state across calls.

Yes, this seems consistent with my understanding of the hidden state returned from LSTM! Thanks for putting it crisply!
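The point about feeding h_n into the next forward() call can be sketched as follows. This is an illustration under assumptions: it uses a unidirectional nn.GRU, because splitting a sequence into chunks only reproduces the full-sequence result when there is no backward direction (the backward pass over one chunk cannot see later chunks). For an nn.LSTM, the state passed in would be the tuple (h_n, c_n).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 2
gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             batch_first=True)  # unidirectional on purpose, see above

seq = torch.rand(N, 2 * L, H_in)
chunk1, chunk2 = seq[:, :L], seq[:, L:]

# Process the sequence in two chunks, carrying h_n over as the initial state
_, h_n = gru(chunk1)
out2, h_n2 = gru(chunk2, h_n)  # h_0 uses the same (D*num_layers, N, H_out) layout

# Compare with running the whole sequence in one call
out_full, h_full = gru(seq)
print(torch.allclose(out2, out_full[:, L:], atol=1e-5))  # True
print(torch.allclose(h_n2, h_full, atol=1e-5))           # True
```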

Thanks for the explanation and for raising valid concerns! If I understand it correctly, the current order of dimensions in the docs causes confusion, so (num_layers * D, N, H_out) would better reflect the underlying layout?

Both (D * num_layers, N, H_out) and (num_layers * D, N, H_out) are fine; only the latter matches the old docs. Again, not a big deal, but it might cause less confusion, since the view() command to separate the directions and layers is now missing. My suggestion would actually be to add the view() command again, e.g.:

h_n: tensor of shape (D*num_layers, H_out) for unbatched input or (D*num_layers, N, H_out) containing the final hidden state for each element in the sequence. When bidirectional=True, h_n will contain a concatenation of the final forward and reverse hidden states, respectively. Like output, the layers can be separated using h_n.view(num_layers, D, N, H_out) and similarly for c_n.

with the last sentence adapted from the old docs.

I think that would be very helpful, particularly for beginners. Correctly handling the output of LSTMs/GRUs is not completely trivial, as the many related posts here show.
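For completeness, the same view() applies to the cell state c_n of an nn.LSTM, as the "similarly for c_n" in the old docs suggests. A minimal sketch, assuming the default proj_size=0 (with a projection, the last dimension of h_n would be proj_size instead of H_out):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 3
lstm = nn.LSTM(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
               bidirectional=True, batch_first=True)
output, (h_n, c_n) = lstm(torch.rand(N, L, H_in))

# h_n and c_n share the (D*num_layers, N, H_out) layout, so the same view applies
h_split = h_n.view(num_layers, 2, N, H_out)
c_split = c_n.view(num_layers, 2, N, H_out)

# Final layer: index 0 = forward direction, index 1 = backward direction
h_fwd, h_bwd = h_split[-1, 0], h_split[-1, 1]   # each (N, H_out)
c_fwd, c_bwd = c_split[-1, 0], c_split[-1, 1]   # each (N, H_out)

# Sanity check against output, as in the GRU example
print(torch.allclose(h_fwd, output[:, -1, :H_out]),
      torch.allclose(h_bwd, output[:, 0, H_out:]))  # True True
```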
