How can I know which part of *h_n* of a bidirectional RNN is for the backward direction?

Hi, I’m building a model with a bidirectional GRU, so I use nn.GRU(bidirectional=True).

From the doc, I see that its outputs are:

Outputs: output, h_n
    - **output** (seq_len, batch, hidden_size * num_directions): tensor containing the output features h_t from
      the last layer of the RNN, for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been given as the
      input, the output will also be a packed sequence.
    - **h_n** (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t=seq_len

And the first dim of h_n covers all the GRU cells, because the model contains (num_layers * num_directions) of them.
But I wonder whether h_n is laid out like
[
layer0_forward
layer0_backward
layer1_forward
layer1_backward
layer2_forward
layer2_backward

] or
[
layer0_forward
layer1_forward
layer2_forward
layer0_backward
layer1_backward
layer2_backward

]

Does anybody know? Or how can I figure it out?


After checking the GRU implementation, I think I can probably find the answer in the implementation of self._backend.RNN:

        func = self._backend.RNN(
            self.mode,
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            batch_sizes=batch_sizes,
            dropout_state=self.dropout_state
        )

But how can I see this code?
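If you just want to read the Python side of it, the inspect module will point you at it; a small sketch (the backend RNN function itself dispatches to the C/cuDNN kernels, so only the Python wrapper in torch/nn/modules/rnn.py is readable this way):

import inspect
import torch.nn as nn

# Show where nn.GRU is defined and print its Python source;
# the forward logic lives in the shared RNNBase class in the same file.
print(inspect.getsourcefile(nn.GRU))
print(inspect.getsource(nn.GRU))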

Your first guess is correct: within h_n, the forward and backward directions alternate layer by layer.
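Given that ordering, you can split the two directions apart with a view; a minimal sketch using the current API (tensors go straight into the module, no Variable wrapper needed in recent versions):

import torch
import torch.nn as nn

num_layers, num_directions = 3, 2
batch_size, seq_len, input_size, hidden_size = 1, 6, 1, 10

rnn = nn.GRU(input_size, hidden_size, num_layers,
             batch_first=True, bidirectional=True)
out, h_n = rnn(torch.rand(batch_size, seq_len, input_size))

# (num_layers * num_directions, batch, hidden) -> (num_layers, num_directions, batch, hidden)
h_n = h_n.view(num_layers, num_directions, batch_size, hidden_size)
h_last_fwd = h_n[-1, 0]  # last layer, forward direction
h_last_bwd = h_n[-1, 1]  # last layer, backward direction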


Alright, but I still have a question about the last-time-step output of the GRU versus the GRU’s final state.
They are expected to be the same, right? I print them out and find they are not quite the same.
I run the code below.

# coding:utf-8
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable

seq_len = 6
batch_size = 1
input_size = 1
hidden_size = 10
num_layers = 3
batch_first = True
bidirectional = True

torch.manual_seed(1)

# **input** (batch, seq_len, input_size) because batch_first=True
x = torch.rand(batch_size, seq_len, input_size)
x_rev = torch.from_numpy(np.flip(x.numpy(), 1).copy())  # time-reversed copy (unused below)
# **h0** (num_layers * num_directions, batch, hidden_size)
if bidirectional:
    num_directions = 2
else:
    num_directions = 1
h0 = torch.rand(num_layers*num_directions, batch_size, hidden_size)

rnn = nn.GRU(input_size, hidden_size, num_layers,
             batch_first=batch_first, bidirectional=bidirectional)
print(rnn)

out, ht = rnn(Variable(x), Variable(h0))

print(out.data.size())
print(ht.data.size())

assert out.data.numpy().shape == (batch_size, seq_len, hidden_size*num_directions)
assert ht.data.numpy().shape == (num_layers*num_directions, batch_size, hidden_size)

print("output:")
print(out.data.numpy()[0, seq_len-1, :])  # [hidden_size*num_directions]
print("====================================")
print("ht:")
print(ht.data.numpy()[:, 0, :])  # [num_layers*num_directions, hidden_size]

So I get the output like this:

>>>

GRU(1, 10, num_layers=3, batch_first=True, bidirectional=True)
(1L, 6L, 20L)
(6L, 1L, 10L)
output:
[ **0.04090527  0.04467951  0.06184166 -0.119278   -0.07899605 -0.17775261**
 **-0.25711796 -0.0560216  -0.06801324 -0.62566853**  0.09493496 -0.00143968
  0.25473037  0.59195685  0.08295314  0.61662054  0.39969781  0.52175015
  0.43700069 -0.04902107]
====================================
ht:
[[ 0.41622654  0.05891414  0.24079823 -0.20317592  0.20570976  0.07495184
   0.31944707 -0.3336893  -0.17610091 -0.01868644]
 [ 0.188013   -0.27898508  0.13432087 -0.079565    0.19181061 -0.28547999
  -0.19238529  0.08653103 -0.33994722  0.12975907]
 [-0.1610465  -0.1817638  -0.07482101 -0.04572783  0.27683198  0.16544969
   0.10135207 -0.43468314 -0.46809191 -0.00571362]
 [-0.27692401 -0.04289184  0.14566612  0.12111901  0.12315567  0.35866803
   0.0838761  -0.08178325  0.40468279 -0.1950635 ]
[ **0.04090527  0.04467951  0.06184166 -0.119278   -0.07899605 -0.17775261**
**-0.25711796 -0.0560216  -0.06801324 -0.62566853**]
 [ 0.04620094 -0.34189698  0.08069657  0.39240748 -0.09260736  0.61043888
   0.26960379  0.2404768  -0.13964601  0.07339926]]

As we can see, output[:hidden_size] == ht[4], BUT output[hidden_size:] =

[ 0.09493496 -0.00143968  0.25473037  0.59195685  0.08295314  0.61662054
  0.39969781  0.52175015  0.43700069 -0.04902107]

has no match anywhere in ht.

I checked the doc

Outputs: output, h_n
    - **output** (seq_len, batch, hidden_size * num_directions): tensor containing the output features h_t from
      the last layer of the RNN, for each t. If a :class:`torch.nn.utils.rnn.PackedSequence` has been given as the
      input, the output will also be a packed sequence.
    - **h_n** (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t=seq_len

and I don’t know what I’ve missed. We’ve got the final-step forward output in h_n, but why isn’t the final-step backward output there too? h_n should contain num_directions final hidden states, right?


I’m observing a similar issue. Only half of the elements of the final time step of the output match the hidden state.


I believe this is answered here:
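If I understand the linked answer correctly: for the backward direction, h_n holds the state after that direction has read the whole sequence, i.e. its state at t=0, not at t=seq_len-1. So the second half of the last output time step simply has no counterpart in h_n; it shows up at the first time step instead. A quick check, reusing out, ht and hidden_size from the script above (batch_first=True):

o = out.data.numpy()
h = ht.data.numpy()

fwd_last = o[0, -1, :hidden_size]   # forward output at the last time step
bwd_first = o[0, 0, hidden_size:]   # backward output at the FIRST time step

print(np.allclose(fwd_last, h[-2, 0]))   # True: matches layer2_forward
print(np.allclose(bwd_first, h[-1, 0]))  # True: matches layer2_backward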


@scoinea Yes, thanks!