Learn initial hidden state (h0) for RNN

Instead of randomly initializing the hidden state h0 (or setting it to zero), I want the model to learn the initial RNN hidden state by itself. According to the article Non-Zero Initial States for Recurrent Neural Networks, learning the initial state can speed up training and improve generalization.

Following this post, I set the initial hidden state as a parameter in the module:

self.word_lstm_init_h = Parameter(torch.randn(2, 20, word_lstm_dim), requires_grad=True).type(FloatTensor)
self.word_lstm_init_c = Parameter(torch.randn(2, 20, word_lstm_dim), requires_grad=True).type(FloatTensor)
self.word_lstm = nn.LSTM(word_lstm_input_dim, word_lstm_dim, 1,
                         bidirectional=True, batch_first=True)

For the forward function:

word_lstm_input = torch.nn.utils.rnn.pack_padded_sequence(
    word_lstm_input, seq_len, batch_first=True
)
word_lstm_out, word_lstm_h = self.word_lstm(
    word_lstm_input, (self.word_lstm_init_h, self.word_lstm_init_c)
)
word_lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(
    word_lstm_out, batch_first=True
)

I’ll skip the loss function here. After loss.backward(), I try to print out the gradients of the initial state:

print(model.word_lstm_init_h.grad)
print(model.word_lstm_init_c.grad)

It gives me None for both.

Does nn.LSTM pass gradients to the initial hidden state?


Your error is this line:

self.word_lstm_init_h = Parameter(torch.randn(2, 20, word_lstm_dim), requires_grad=True).type(FloatTensor)

What you did is x = foo.type(...). The problem here is that foo is the leaf Variable (which will have its .grad attribute correctly filled), not foo.type(...): the cast returns a new, non-leaf tensor, and that is what you stored as the attribute.

Try this:

self.word_lstm_init_h = Parameter(torch.randn(2, 20, word_lstm_dim).type(FloatTensor), requires_grad=True)
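
To see the difference in isolation, here is a minimal sketch (made-up variable names) showing that the gradient lands on the leaf tensor, while the result of a cast is non-leaf and keeps .grad as None:

import torch
from torch.nn import Parameter

p = Parameter(torch.randn(3))  # leaf tensor: backward() fills p.grad
x = p.double()                 # the cast creates a new, non-leaf tensor

x.sum().backward()
print(p.grad)  # tensor([1., 1., 1.]) -- accumulated on the leaf
print(x.grad)  # None (non-leaf tensors don't retain .grad by default)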

It works. Thank you!

Why do we need to set h_0? Won’t it be overwritten by the LSTM computation anyway?

The hidden state will certainly be overwritten at each step, but you have to initialize h_0 for the first step. Normally, h_0 is randomly initialized (or zero-initialized) for each iteration. Here, I set h_0 as a parameter and let the model learn it, so that at each iteration h_0 starts from a learned value. The hypothesis is that a learned h_0 can improve performance or speed up training.

Sorry to bring up an old post, but a related question: I’m also trying to learn the initial hidden state for an LSTM, but using a separate context network to produce it (as in the paper “Multiple Object Recognition with Visual Attention”). I thought that since the context network’s output is a non-leaf Variable, the gradients would propagate back through the LSTM into the context network, but the gradients of the context network’s parameters are still None. What would be the correct way to go about doing this?
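
For reference, gradients do flow from the LSTM’s initial state back into the network that produced it, as long as nothing detaches or rewraps the output along the way. A minimal sketch of that pattern (all names and dimensions are hypothetical, not the paper’s architecture):

import torch
import torch.nn as nn

class ContextInitLSTM(nn.Module):
    def __init__(self, ctx_dim, input_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.ctx_net = nn.Linear(ctx_dim, 2 * hidden_dim)  # emits h0 and c0
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, ctx, seq):
        # (batch, 2*hidden) -> two (1, batch, hidden) tensors; they are
        # non-leaf, so backward() reaches ctx_net's weights through them
        h0, c0 = self.ctx_net(ctx).split(self.hidden_dim, dim=-1)
        h0 = h0.unsqueeze(0).contiguous()
        c0 = c0.unsqueeze(0).contiguous()
        out, _ = self.lstm(seq, (h0, c0))
        return out

model = ContextInitLSTM(ctx_dim=8, input_dim=5, hidden_dim=16)
out = model(torch.randn(4, 8), torch.randn(4, 7, 5))
out.sum().backward()
print(model.ctx_net.weight.grad is None)  # False: the context net gets gradients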

I think there is an issue with the solution proposed here: we are fixing the model to accept only a batch size of 20, and we are learning separate values of h0 for each of the 20 batch elements.

This is how I solved it:

def __init__(...):
    # learn initial hidden state (h0, c0)
    # NOTE: use "2*" only if bidirectional
    h0 = torch.zeros(2*n_layers, 1, self.hidden_dim).to(device)
    c0 = torch.zeros(2*n_layers, 1, self.hidden_dim).to(device)
    nn.init.xavier_normal_(h0, gain=nn.init.calculate_gain('relu'))
    nn.init.xavier_normal_(c0, gain=nn.init.calculate_gain('relu'))
    self.h0 = nn.Parameter(h0, requires_grad=True)  # Parameter() to update weights
    self.c0 = nn.Parameter(c0, requires_grad=True)
    # initialize the other stuff...
def forward(self, seqs, lengths):
    packed = nn.utils.rnn.pack_padded_sequence(seqs, lengths)
    batch_size = seqs.shape[1]
    # repeat h0/c0 across the batch; gradients from every batch element still flow back into the single parameter
    outputs, (hidden, cell) = self.lstm(packed, (self.h0.repeat(1, batch_size, 1),
                                                 self.c0.repeat(1, batch_size, 1)))
    # do more stuff / return ...

I also tried using expand() instead of repeat() to avoid creating copies. However, the initial-state tensors passed to the LSTM need to be contiguous, so I think we can’t avoid the memory allocation (note that expand().contiguous() also creates a new copy). Both approaches seem to compute the gradients correctly, taking the whole batch into account.
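
A quick sanity check (dummy shapes and a dummy loss, just for illustration) that the repeated copies feed their gradients back into the single batch-size-1 parameter:

import torch

h0 = torch.nn.Parameter(torch.randn(2, 1, 4))  # (num_layers*num_dirs, 1, hidden)
batch = h0.repeat(1, 3, 1)                     # one copy per batch element
batch.pow(2).sum().backward()                  # dummy loss touching every copy
print(h0.grad.shape)  # torch.Size([2, 1, 4]) -- one gradient, summed over copies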

I’m a bit surprised that there isn’t a simpler way to do this (e.g. by providing the initial state h0 to the LSTM separately, for batch_size=1, and then feeding the whole batch without providing h0). If someone knows a cleaner solution please let me know.


Do we need to add the h0 and c0 parameters to the optimizer’s list of parameters to optimize?
Adding them to the list gives the error: “ValueError: can’t optimize a non-leaf Tensor”
Thanks!

This definitely seems odd to me.
Additionally, it almost seems wrong because you’re playing dynamically with the size of a param…

Nope, wrapping it in Parameter is enough.
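
Assigning an nn.Parameter as a module attribute registers it automatically, so model.parameters() already includes it; the “non-leaf Tensor” error shows up when you hand the optimizer a derived tensor (e.g. the result of repeat() or a cast) instead of the Parameter itself. A tiny sketch with a made-up toy module:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # attribute assignment auto-registers the parameter with the module
        self.h0 = nn.Parameter(torch.zeros(2, 1, 4))

model = Toy()
print([name for name, _ in model.named_parameters()])  # ['h0']
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # h0 is covered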

The size of the parameter doesn’t change: you only have parameters for batch size = 1. But when you create several copies of it (one per batch element), you compute the gradient relative to each batch element. So when you update the parameter, you do it using the average direction of all the elements in the batch. I debugged my code to confirm this behaviour before posting here.


Thanks @antoniogois!
Just curious, did you run any ablation study to see if it performed better?

Yeah, just a quick experiment. It showed a very modest gain on the dev set (less than 1% accuracy, I think), and it seemed more prone to overfitting after reaching that peak.
I’ve read a couple of things arguing in favour of this approach, but in terms of code it seems everyone uses zeros for h0 (PyTorch’s default behaviour).
