[solved] Train initial hidden state of RNNs

I want an RNN with a trainable initial hidden state h_0. Other packages such as Lasagne allow this via a flag. I implemented the following:

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        hidden0 = Variable(torch.zeros(1, 1, hidden_size), requires_grad=True)
        if use_cuda:
            self.hidden0 = hidden0.cuda()
        else:
            self.hidden0 = hidden0

        self.embedding = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden):
        output = self.embedding(input).view((1, 1, -1))
        for i in range(self.n_layers):
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)
        return output, hidden

Despite setting requires_grad=True, hidden0 does not appear in the model's parameter list. How can I force PyTorch to train h_0?

After searching, I found the solution to my question: use nn.Parameter, as follows:

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        hidden0 = torch.zeros(1, 1, hidden_size)
        if use_cuda:
            hidden0 = hidden0.cuda()

        self.hidden0 = nn.Parameter(hidden0, requires_grad=True)
        self.embedding = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden):
        output = self.embedding(input).view((1, 1, -1))
        for i in range(self.n_layers):
            output = F.relu(output)
            output, hidden = self.gru(output, hidden)
        return output, hidden

Of course, we can make initialization a bit prettier.
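For example, a minimal sketch of a tidier version (assuming the module is later moved to the GPU with model.cuda(), which transfers registered parameters automatically, so the use_cuda branch is no longer needed):

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        # nn.Parameter has requires_grad=True by default and is registered
        # automatically, so it appears in model.parameters() and follows .cuda()
        self.hidden0 = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.embedding = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)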


It doesn't look like you have used self.hidden0 in your forward computation. It's not clear how self.hidden0 will be learned unless you supply it as the hidden argument to forward. But why not just do: output, hidden = self.gru(output, self.hidden0)?

I supply it at the beginning of the sequence, when I pass the entire sequence through.

The piece of code shown here is not the unrolled version; it is more like a single cell.

Yeah, this doesn't seem to be what you want to do.

I think what you want is what is done in the original Torch implementation of an NTM (Neural Turing Machine).

You want to initialize your memory matrix as a vanilla variable (either drawn from a normal distribution or all constant values), then pass it through a linear layer. You thereby learn a layer that can initialize your hidden state.
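A rough sketch of that idea (the names LearnedInit, const and to_hidden are illustrative, not taken from the NTM code): keep a constant input and learn the linear layer that maps it to the initial hidden state.

class LearnedInit(nn.Module):
    def __init__(self, n_layers, hidden_size):
        super(LearnedInit, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        # constant (non-trainable) input; registered so it moves with .cuda()
        self.const = nn.Parameter(torch.ones(1, 1), requires_grad=False)
        # the linear layer that produces h_0 is what actually gets trained
        self.to_hidden = nn.Linear(1, n_layers * hidden_size)

    def forward(self, batch_size):
        h0 = self.to_hidden(self.const)                   # (1, n_layers * hidden_size)
        h0 = h0.view(self.n_layers, 1, self.hidden_size)
        return h0.repeat(1, batch_size, 1)                # (n_layers, batch, hidden_size)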

I completely agree with your revision, but do you know whether the following code would initialize only the first hidden state or all hidden states in the RNN layer? I would like to initialize only the first hidden state and have the rest be adjusted by the learning process (hence requires_grad=True, I believe), but I'm afraid that I will end up initializing all hidden states with this statement.

(I don’t provide here all the code)
class Encoder(nn.Module):

    def __init__(self, rnn_type, ntoken, emb_dim, nhid, nlayers, dropout=0.5):

        hidden0 = torch.Tensor(ntoken, emb_dim)  # size
        nn.init.uniform(hidden0)
        self.hidden0 = nn.Parameter(hidden0, requires_grad=False)

    def forward(self, input, hidden):

        output, hidden = self.rnn(input, self.hidden0)

        output = self.drop(output)
        decoded1 = self.decoder1(output.view(output.size(0) * output.size(1), output.size(2)))
        decoded2 = self.decoder2(decoded1)

        return decoded2.view(output.size(0), output.size(1), decoded2.size(1)), hidden

Does this do what I would like?

That would use the initial hidden state for every call of the forward pass regardless of the value of the hidden argument passed to forward.

How about this?

def forward(self, input, hidden):
    if hidden is None:
        hidden = self.hidden0
    output, hidden = self.rnn(input, hidden)
    ...

And your training loop would look a little like this… (assuming you pass the data to your Encoder one timestep at a time).

hidden = None
for timestep in inputs:
    output, hidden = model(timestep, hidden)
    ...

I figured it out myself.
All I needed to do was:

def init_hidden(self, bsz):
    weight = next(self.parameters()).data
    # first hidden unit drawn from a normal distribution (mean=-1, std=1), the rest zero
    a = weight.new(self.nlayers, bsz, 1).normal_(-1, 1)
    b = weight.new(self.nlayers, bsz, self.nhid - 1).zero_()
    return Variable(torch.cat([a, b], 2))
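
Called from the training loop, it looks roughly like this (a sketch, assuming forward takes the hidden state as its second argument):

hidden = model.init_hidden(batch_size)      # fresh initial state for this batch
output, hidden = model(input, hidden)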

Thanks for the help

Hi Taha,

So is it correct to say that at training time you did this?

encoder = EncoderRNN(n_words, hidden_size, n_layers)

outputs, hidden = encoder(input_batches, input_lengths, hidden=encoder.hidden0)

Also, how did you make the initial hidden state work for different batch sizes?

hidden0 = torch.zeros(n_layers, 1, hidden_size)

Thank you!

I use tensor.repeat() to create replicates for each sample in the batch.

In particular, given hidden0 = torch.zeros(n_layers, 1, hidden_size), I use hidden0.repeat(1, B, 1) where B is the batch size.
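
In the forward pass that looks roughly like the following sketch (assuming self.hidden0 is an nn.Parameter of shape (n_layers, 1, hidden_size) and the GRU was built with batch_first=True):

def forward(self, input, hidden=None):
    # input: (B, T, hidden_size) because of batch_first=True
    if hidden is None:
        B = input.size(0)
        hidden = self.hidden0.repeat(1, B, 1)   # (n_layers, B, hidden_size)
    output, hidden = self.gru(input, hidden)
    return output, hidden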
