Since nobody replied and the question has popped up again, I've checked it myself. In short, it has to be (with D = 2 for a bidirectional GRU and D = 1 otherwise):
```python
h_n.view(num_layers, D, N, H_out)
```
Here's a minimal example. First, let's create an example batch and an nn.GRU layer, and push the batch through the layer.
```python
import torch
import torch.nn as nn

torch.manual_seed(0)
torch.set_printoptions(precision=2)

# Define parameters
N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10

# Create random batch of shape (N, L, H_in)
batch = torch.rand(N, L, H_in)

# Create GRU layer (IMPORTANT: batch_first=True in this example!)
gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers, bidirectional=True, batch_first=True)

# Push batch through GRU layer
# output: (N, L, D * H_out) = (2, 4, 6); h_n: (D * num_layers, N, H_out) = (20, 2, 3)
output, h_n = gru(batch)
```
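To double-check the shapes, here is a small sketch that rebuilds the same setup and prints them. One thing worth knowing: batch_first=True only applies to the input and output tensors; the PyTorch docs state that it does not apply to the hidden state, so h_n keeps its (D * num_layers, N, H_out) layout. D is just a local variable here for the number of directions (the docs use it as notation, not as a GRU attribute).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
D = 2  # bidirectional=True, so two directions

gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             bidirectional=True, batch_first=True)
output, h_n = gru(torch.rand(N, L, H_in))

# output follows batch_first: (N, L, D * H_out)
print(tuple(output.shape))  # (2, 4, 6)
# h_n ignores batch_first: (D * num_layers, N, H_out)
print(tuple(h_n.shape))     # (20, 2, 3)
```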
Now let's say we look at the hidden states of the first sequence item of the first sample with output[0][0]. Since output has shape (N, L, D * H_out), this is a vector of 2 * 3 = 6 values:
[ 0.01, -0.02, -0.03, -0.27, 0.58, -0.45]
Here, [ 0.01, -0.02, -0.03] is the first hidden state of the forward direction and [ -0.27, 0.58, -0.45] is the last hidden state of the backward direction (the backward pass reaches the first item last). Similarly, we can look at the hidden states of the last sequence item of the first sample with output[0][-1]:
[ 0.02, 0.05, -0.10, -0.05, 0.29, -0.18]
with [ 0.02, 0.05, -0.10] being the last hidden state of the forward direction and [ -0.05, 0.29, -0.18] being the first hidden state of the backward direction.
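The reading of output above can be made explicit by splitting its last dimension into (direction, hidden). A minimal sketch, assuming the same parameters as above; forward_out, backward_out, last_forward and last_backward are just illustrative names:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
D = 2  # number of directions (bidirectional=True)

gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             bidirectional=True, batch_first=True)
output, _ = gru(torch.rand(N, L, H_in))

# Split the last dimension into (direction, hidden): index 0 = forward, 1 = backward
# (reshape instead of view, in case output is not contiguous)
out_dirs = output.reshape(N, L, D, H_out)
forward_out = out_dirs[:, :, 0]   # (N, L, H_out)
backward_out = out_dirs[:, :, 1]  # (N, L, H_out)

# Equivalent to slicing the concatenated features
assert torch.equal(forward_out, output[..., :H_out])
assert torch.equal(backward_out, output[..., H_out:])

# Forward direction: its final state sits at the last time step (t = L - 1).
# Backward direction: its final state sits at the first time step (t = 0).
last_forward = forward_out[:, -1]   # (N, H_out)
last_backward = backward_out[:, 0]  # (N, H_out)
```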
Of course, in practice we generally want to work with the last hidden states of the forward and backward directions. These are easily obtained from h_n:
```python
h_n = h_n.view(num_layers, 2, N, H_out)
h_n[-1, :, 0]
```
Here, the -1 selects the final layer (relevant when num_layers > 1), the : keeps both directions, and the 0 again refers to the first sample. The output is:
[[ 0.02, 0.05, -0.10],
 [-0.27, 0.58, -0.45]]
which are exactly the two last hidden states (forward, then backward) that we had already seen above in output.
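This correspondence can be checked directly: after the view(), the last layer's forward state should equal output at the last time step (first H_out features), and its backward state should equal output at the first time step (last H_out features). A sketch assuming the same setup as above and full-length sequences (no padding or packing); summary is a hypothetical name for one common way to combine both directions per sample:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, L, H_in, H_out, num_layers = 2, 4, 6, 3, 10
D = 2  # number of directions (bidirectional=True)

gru = nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=num_layers,
             bidirectional=True, batch_first=True)
output, h_n = gru(torch.rand(N, L, H_in))

h = h_n.view(num_layers, D, N, H_out)
last_layer = h[-1]  # (D, N, H_out)

# Forward final state == output at the last time step, first H_out features
fwd_matches = torch.allclose(last_layer[0], output[:, -1, :H_out])
# Backward final state == output at the first time step, last H_out features
bwd_matches = torch.allclose(last_layer[1], output[:, 0, H_out:])
print(fwd_matches, bwd_matches)  # True True

# Concatenate both directions per sample: (N, 2 * H_out)
summary = torch.cat([last_layer[0], last_layer[1]], dim=-1)
```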
Again, I find it (a) odd that this view() call is now missing from the docs, and (b) confusing that the documented shape of h_n has changed from
(num_layers * num_directions, batch, hidden_size)
to
(D * num_layers, N, H_out)
I would prefer the notation (num_layers * D, N, H_out). This is in line with the old docs and matches the order of dimensions in the required view() call, which matters all the more now that the call itself is missing from the docs. Ideally, I think it should be added back.
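To illustrate why the factor order matters, here is a small sketch that uses torch.arange as a stand-in for the first dimension of h_n, so each entry simply records its own flat index. It assumes the layer-major layout verified above (flat index = layer * D + direction):

```python
import torch

num_layers, D = 10, 2

# Stand-in for the first dimension of h_n: each entry holds its own flat index
flat = torch.arange(D * num_layers)

# Correct grouping: layer-major, i.e. flat index = layer * D + direction
by_layer = flat.view(num_layers, D)
print(by_layer[-1].tolist())  # [18, 19] -> last layer, forward and backward

# Reading "D * num_layers" literally as view(D, num_layers) groups the wrong rows
by_direction = flat.view(D, num_layers)
print(by_direction[:, -1].tolist())  # [9, 19] -> row 9 is layer 4 (backward), not the last layer's forward state
```

Row 19 happens to come out right either way, which is why a wrong view() can go unnoticed if you only check the backward direction.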
@ptrblck: Is there a way to get feedback on this? Maybe there was a very good reason for removing the view() command from the docs that I'm just missing.