RNN: output vs hidden state don't match up (my misunderstanding?)

EDIT: I think I found my problem. When I do output_last_step = output[-1], I get the last hidden state w.r.t. the forward pass, not the backward pass. The last hidden state w.r.t. the backward pass is part of output[0]. self.hidden is independent of seq_len and contains only the last hidden states for both directions. I got confused by the figure since it only covers the unidirectional case. If anyone could confirm this, even better.
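To make this concrete, here is a quick standalone check (a single bidirectional layer with hidden_dim=3, chosen only to mirror the 3+3 values printed below; not the exact model from my code):

import torch
import torch.nn as nn

# Minimal sketch of the relationship described in the EDIT above.
torch.manual_seed(0)
hidden_dim = 3
rnn = nn.LSTM(input_size=5, hidden_size=hidden_dim, bidirectional=True)

x = torch.randn(7, 1, 5)                  # (seq_len, batch, input_size)
output, (h_n, c_n) = rnn(x)               # output: (7, 1, 2*hidden_dim), h_n: (2, 1, hidden_dim)

# Forward direction: its last hidden state sits in the LAST time step of output.
print(torch.allclose(h_n[0], output[-1, :, :hidden_dim]))  # True
# Backward direction: its last hidden state sits in the FIRST time step of output.
print(torch.allclose(h_n[1], output[0, :, hidden_dim:]))   # True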

According to the following figure, I can get h_n^w either from the output of an LSTM (as the output of the last time step n) or from the hidden state (for the last layer w):

However, I fail to reproduce this for bidirectional LSTMs/GRUs (the unidirectional case works); below I provide a minimal example. The output always looks like this:

The following 2 tensors should be equal(?)
tensor([ 0.0283,  0.4886,  0.0097, -0.1139,  0.0552, -0.0287])
tensor([ 0.0283,  0.4886,  0.0097, -0.2138,  0.1073, -0.0219])

i.e., the first 3 values representing the forward direction are indeed identical. The values for the backward direction don’t match up. In the unidirectional case, both tensors are the same (obviously with only 3 values).

Moreover, when I do a print(self.hidden[0].data) to check the complete hidden state, the values -0.1139, 0.0552, -0.0287 are nowhere to be found. Using a GRU yields the same issue.

What am I missing here? Do I misunderstand the return values of LSTM/GRU? Do I make any mistakes in calculating h_n^w using the output or the hidden state?

Here’s the complete code to reproduce it:

import torch
import torch.nn as nn


class RnnClassifier(nn.Module):

    def __init__(self, bidirectional=True):
        super(RnnClassifier, self).__init__()
        self.bidirectional = bidirectional

        self.embed_dim = 5
        self.hidden_dim = 3
        self.num_layers = 4

        self.word_embeddings = nn.Embedding(100, self.embed_dim)
        self.num_directions = 2 if bidirectional else 1
        self.rnn = nn.LSTM(self.embed_dim, self.hidden_dim, num_layers=self.num_layers, bidirectional=bidirectional)
        self.hidden = None


    def init_hidden(self, batch_size):
        return (torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_dim),
                torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_dim))


    def forward(self, inputs):
        batch_size, seq_len = inputs.shape
        # Push through embedding layer and transpose for RNN layer (batch_first=False)
        X = self.word_embeddings(inputs).transpose(0, 1)
        # Push through RNN layer
        output, self.hidden = self.rnn(X, self.hidden)
        # output.shape = (seq_len, batch_size, num_directions*hidden_dim)
        # self.hidden[0].shape = (num_layers*num_directions, batch_size, hidden_dim)

        # Get h_n^w directly from output of the last time step
        output_last_step = output[-1] # (batch_size, num_directions*hidden_dim)

        # Get h_n^w from hidden state
        hidden = self.hidden[0].view(self.num_layers, self.num_directions, batch_size, self.hidden_dim)
        hidden_last_layer = hidden[-1] # (num_directions, batch_size, hidden_dim)

        if self.bidirectional:
            direction_1, direction_2 = hidden_last_layer[0], hidden_last_layer[1]
            direction_full = torch.cat((direction_1, direction_2), 1)
        else:
            direction_full = hidden_last_layer.squeeze(0)

        print("The following 2 tensors should be equal(?)")
        print(output_last_step[0].data)
        print(direction_full[0].data)

        print(self.hidden[0].data)

if __name__ == '__main__':

    model = RnnClassifier(bidirectional=True)

    inputs = torch.LongTensor([[1, 2, 4, 6, 4, 2, 3]])

    model.hidden = model.init_hidden(inputs.shape[0])

    model(inputs)

Thanks for the updated post. I am building a binary classification model for sequences using the following code.

class RNNClassifier(nn.Module):

    def __init__(self):
        super(RNNClassifier, self).__init__()

        ...
        self.hidden = None
        ...

    def forward(self, inputs):
        ...
        output, self.hidden = self.rnn(X, self.hidden)
        ...

        # self.hidden[0][-1] is passed to a FC layer and finally a sigmoid activation

    # assuming one layer and one direction
    def init_hidden(self, batch_size):
        return (torch.zeros(1, batch_size, self.hidden_dim),
                torch.zeros(1, batch_size, self.hidden_dim))


model = RNNClassifier()
model.hidden = model.init_hidden(batch_size)

for X, y in train_ldr:
    yhat = model(X)
    loss = criterion(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The error I get is:

“RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.”

It seems like this post is trying to address the issue –


The hidden.detach_(); hidden = hidden.detach() did not work for me, and I am not even sure where to put it.

Did you run into this issue? You will need to go through a couple of mini-batches in the first epoch before you run into it.

I’m not sure why detach() is not working. However, if your sequences do not depend on each other, you can simply call init_hidden for each batch, i.e., directly before model(input).
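For example, a training loop along these lines (just a sketch reusing the names from your snippet above, not verified against your full code) re-initializes the hidden state per batch so that no computation graph is carried over:

for X, y in train_ldr:
    # fresh, graph-free hidden state for every batch
    # (assuming the batch dimension comes first in X)
    model.hidden = model.init_hidden(X.shape[0])
    yhat = model(X)
    loss = criterion(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Alternatively, if the hidden state should carry over between batches (stateful RNN),
# detach it from the old graph after each step, e.g.:
#     model.hidden = tuple(h.detach() for h in model.hidden)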

@vdw I can confirm that output[0] contains the last computed value in the reverse direction of the bidirectional LSTM.

Additionally, the hidden state tensor is laid out so that every alternate element comes from the forward and reverse passes respectively. For example, if you consider a batched hidden state of shape (D x num_layers, N, H_out), then the following elements are the hidden states from the forward direction:

h[0], h[2], h[4], etc…

and the following are the hidden states from the reverse direction:

h[1], h[3], h[5], etc…

Basically, I want to clarify that the hidden states aren’t stacked on top of each other (all forward states, then all backward states) but are layered like a puff pastry, so that every alternate hidden state comes from the forward and reverse direction respectively. This is why you can retrieve the last hidden states of the forward and reverse directions using h[-2] and h[-1] respectively.
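For reference, here is a small sketch (mine, with illustrative sizes; not part of the original post) that makes the alternating layout explicit for a 3-layer bidirectional LSTM:

import torch
import torch.nn as nn

# Even indices of h_n hold the forward states, odd indices the backward states,
# one pair per layer (illustrative sizes).
torch.manual_seed(0)
num_layers, batch, hidden = 3, 4, 16
lstm = nn.LSTM(input_size=8, hidden_size=hidden, num_layers=num_layers, bidirectional=True)
x = torch.randn(12, batch, 8)                # (seq_len, batch, input_size)
_, (h_n, _) = lstm(x)                        # h_n: (num_layers * 2, batch, hidden)

forward_states = h_n[0::2]                   # h[0], h[2], h[4] -> forward, layers 0..2
backward_states = h_n[1::2]                  # h[1], h[3], h[5] -> backward, layers 0..2

# The same separation via view(num_layers, num_directions, batch, hidden):
h = h_n.view(num_layers, 2, batch, hidden)
print(torch.equal(forward_states, h[:, 0]))  # True
print(torch.equal(backward_states, h[:, 1])) # True
# In particular, h_n[-2] and h_n[-1] are the last layer's forward and backward states.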

I still adhere to the old documentation that specifies how to split the D and num_layers dimensions:

h_n = h_n.view(num_layers, num_directions, batch, hidden_size)

but I’m not sure if this is still correct, since the old docs also say that the shape of h_n is (num_layers * num_directions, batch, hidden_size).

The latest docs give a shape of (num_directions x num_layers, batch, hidden_size) – note the switch in the order of num_directions and num_layers. Unfortunately, the docs no longer specify the view() call to correctly separate them. I’ve even made a post about it, but nobody replied.

In any case, the latest docs state that

h_n will contain a concatenation of the final forward and reverse hidden states

So h_n[0] will contain the forward and reverse hidden state for all sequences in your batch at position 0 – that is, the first hidden state of the forward direction and the last hidden state of the backward direction.

In other words, you cannot use the index like you described to split into forward and backward direction.

@vdw Here’s some code I wrote to try to reverse engineer the contents of the output and hidden states of the LSTM layer in PyTorch. I’ve provided my interpretation based on what I see, and would like to know if that interpretation sounds reasonable to you.

hs = []
ys = []
for bid in (False, True):
    torch.manual_seed(21)
    x = torch.randn(10, 5, 32)
    lstm = nn.LSTM(32, 64, 3, dropout=0.5, bidirectional=bid, batch_first=True)
    y, (h, c) = lstm(x)
    D = 2 if bid else 1
    print(f"Bidirectional: {bid}", x.shape, y.shape, h.shape, c.shape)
    # h = h.reshape(2, 3, )
    hs.append(h)
    
    y = y.view(10, 5, D, 64)
    print(y.shape)
    ys.append(y)
# end for

for (i, j) in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    # print(f"{i}, {hs[i].size(0)}")
    for x in range(hs[i].size(0)):
        for l in (0, -1):
            for d in range(ys[j].size(2)):
                # print(x)
                tagi, tagj = "uni" if i == 0 else "bi", "uni" if j == 0 else "bi"
                if torch.allclose(hs[i][x][0], ys[j][0][l][d]):
                    print(f"All Close hs[{tagi}][DL={x}][N=0], ys[{tagj}][N=0][L={l}][D={d}]")
                          
print("Done with checking hs and ys")

This prints the following (newlines added for clarity):

Bidirectional: False torch.Size([10, 5, 32]) torch.Size([10, 5, 64]) torch.Size([3, 10, 64]) torch.Size([3, 10, 64])
torch.Size([10, 5, 1, 64])
Bidirectional: True torch.Size([10, 5, 32]) torch.Size([10, 5, 128]) torch.Size([6, 10, 64]) torch.Size([6, 10, 64])
torch.Size([10, 5, 2, 64])

# For unidirectional LSTM
All Close hs[uni][DL=2][N=0], ys[uni][N=0][L=-1][D=0]

# For bidirectional LSTM
All Close hs[bi][DL=4][N=0], ys[bi][N=0][L=-1][D=0]
All Close hs[bi][DL=5][N=0], ys[bi][N=0][L=0][D=1]

Done with checking hs and ys

My interpretation:

[1] Unidirectional case: This is simple. The hidden state at the last layer (index=2 for a 3-layer LSTM) is the same as the output at the last time step (L=-1), i.e., the last output going from left to right.

[2] Bidirectional case: This one is more complicated. Here, the hidden state at index=4 (or index=-2 for a 3-layer LSTM) is the same as the last output from left to right (L=-1) in the forward direction (D=0). The hidden state at index=5 (or index=-1) is the same as the first output (L=0) in the right-to-left direction (D=1), if you reshape the output by the direction dimension.
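The same interpretation can be distilled into two direct checks (a sketch reusing the shapes from the exploration code above; dropout left out to keep it minimal):

import torch
import torch.nn as nn

# Distilled check of interpretation [2]: batch_first=True, 3 layers, hidden size 64.
torch.manual_seed(21)
x = torch.randn(10, 5, 32)
lstm = nn.LSTM(32, 64, 3, bidirectional=True, batch_first=True)
y, (h, c) = lstm(x)
y = y.view(10, 5, 2, 64)                   # (batch, seq, direction, hidden)

print(torch.allclose(h[-2], y[:, -1, 0]))  # forward:  final hidden state == output at last step
print(torch.allclose(h[-1], y[:, 0, 1]))   # backward: final hidden state == output at first step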

@dhruvbird Motivated by your post, I’ve actually checked and replied to my old post. Maybe you can have a look to see if this makes sense.

The important bit is that when bidirectional=True, you get both hidden states for a sequence item in concatenated form: first the one for the forward direction, then the one for the backward direction. This then also implies that the respective last hidden states are on opposite ends. Again, I give a concrete example in my reply linked above.

It’s not really complicated, as it is consistent and intuitive. But yes, the bidirectional case requires a bit more care to “extract” the correct bits of h_n to be used for further layers. This also depends on what you’re actually trying to train.
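For instance, a common pattern for a classification head on top of a bidirectional LSTM looks roughly like this (a sketch with made-up sizes, not code from this thread):

import torch
import torch.nn as nn

# Sketch: feed the last layer's final forward and backward states into a classifier head.
hidden_dim, num_classes = 16, 2
fc = nn.Linear(2 * hidden_dim, num_classes)

def classify(h_n):
    # h_n: (num_layers * num_directions, batch, hidden_dim) from a bidirectional LSTM
    last = torch.cat((h_n[-2], h_n[-1]), dim=1)  # (batch, 2 * hidden_dim)
    return fc(last)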

Thank you! Commented on that post! :slight_smile: