GRU model not training properly

I’ve tried reimplementing a simple GRU language model using just a GRU and a linear layer; the code is at https://www.kaggle.com/alvations/gru-language-model-not-training-properly

But when the model is predicting, we see that it only predicts “the” and the comma “,”.

Anyone spot something wrong with my code? Or hyperparameters?

I’ve replied to your Stackoverflow post. Maybe not a proper/full answer, but some food for thought.


From @Chris’ Stackoverflow post:

I’m by no means a PyTorch expert, but that snippet looks fishy to me:

    # Put the embedded inputs into the GRU.
    output, hidden = self.gru(embedded, hidden)
    # Matrix manipulation magic.
    batch_size, sequence_len, hidden_size = output.shape
    # Technically, linear layer takes a 2-D matrix as input, so more manipulation...
    output = output.contiguous().view(batch_size * sequence_len, hidden_size)
  • When GRU is not instantiated with batch_first=True, then the output shape is (seq_len, batch, num_directions * hidden_size) – note that seq_len and batch_size are flipped. For the view command it technically doesn’t matter, but that’s my main issue here.
  • view(batch_size * sequence_len, hidden_size) doesn’t look right at all. Say you start with a batch of size 32; after the view you have a batch of size 32*seq_len. Usually, only the output of the last step is used (or the average or the max over all steps).

Something like this should work:

    # Put the embedded inputs into the GRU.
    output, hidden = self.gru(embedded, hidden)
    # Not needed, just to show the true output shape order
    seq_len, batch_size, hidden_size = output.shape
    # Given the shape of output, this is the last step
    output = output[-1]
    # output.shape = (batch_size, hidden_size) <-- What you want

Two personal words of warning:

  • view() is a dangerous command! PyTorch or any other framework only throws errors when the dimensions of the tensors do not match up. But just because the dimensions fit after view() does not mean the reshaping was done correctly, i.e., that the values are in the right spot of the output tensor. For example, if you have to flatten a shape (seq_len, batch_size, hidden_size) to (batch_size, seq_len*hidden_size), you cannot simply do view(batch_size, -1); you first have to do transpose(1,0) to get a shape of (batch_size, seq_len, hidden_size). With or without transpose(), view() will work and the dimensions will be correct. But only with transpose() are the values at the right position after view() (see the short sketch after this list).
  • Since this is such an easy mistake to make, I’ve seen many examples on GitHub and elsewhere where, in my opinion, it’s not done correctly. The problem is that the network often still learns something anyway. In short, I’m now much more careful when looking at and adopting code snippets, and the view() command is in my opinion the biggest trap.
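To make the transpose() point concrete, here is a tiny self-contained sketch (the sizes are made up) showing that the naive view() and the transpose()-then-view() version give the same shape but a different value layout:

    import torch

    seq_len, batch_size, hidden_size = 3, 2, 4
    out = torch.arange(seq_len * batch_size * hidden_size).view(seq_len, batch_size, hidden_size)

    # Naive flatten: the dimensions fit, but each row mixes values from different sequences.
    wrong = out.view(batch_size, seq_len * hidden_size)

    # Correct flatten: move the batch dimension to the front first, then flatten.
    right = out.transpose(0, 1).contiguous().view(batch_size, seq_len * hidden_size)

    print(torch.equal(wrong, right))  # False -- same shape, different values per row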

If it helps, here’s the forward method of a GRU classifier network:

def forward(self, batch, method='last_step'):
    embeds = self.word_embeddings(batch)
    x = torch.transpose(embeds, 0, 1)
    x, self.hidden = self.gru(x, self.hidden)

    if method == 'last_step':
        x = x[-1]
    elif method == 'average_pooling':
        x = torch.sum(x, dim=0) / len(batch[0])
    elif method == 'max_pooling':
        x, _ = torch.max(x, dim=0)
    else:
        raise Exception('Unknown method.')
    # A series of Linear layers with ReLU and Dropout
    for l in self.linears:
        x = l(x)
    log_probs = F.log_softmax(x, dim=1)
    return log_probs

Sorry, it’s a little tough to comment on StackOverflow so I’m going to discuss the answer here.

As for PyTorch, according to https://pytorch.org/docs/stable/nn.html#gru the default for batch_first is already set to true, so this line should be right:

batch_size, sequence_len, hidden_size = output.shape

I totally agree that .view() is sort of dangerous, but it was the fastest hack to modify the shape to what I need.

As for this line:

output = output.contiguous().view(batch_size * sequence_len, hidden_size)

The hidden_size is what’s declared in the linear initialization, self.classifier = nn.Linear(hidden_size, vocab_size), which does make sense, since the feedforward layer should take the hidden_size of the GRU as input and spit out the vocabulary size for the language modelling task.

My suspicion is that when I restructure the output for self.classifier(output), I might not be using the right shape and should be re-permuting it to the right dimensions instead.
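For what it’s worth, nn.Linear only acts on the last dimension of its input, so (as a sketch with made-up sizes, not the notebook’s actual code) the classifier can be applied to the full batch_first output and only flattened afterwards for the loss:

    import torch
    import torch.nn as nn

    # Hypothetical sizes, just to illustrate the shapes involved.
    batch_size, seq_len, hidden_size, vocab_size = 2, 5, 8, 100

    output = torch.randn(batch_size, seq_len, hidden_size)  # GRU output with batch_first=True
    classifier = nn.Linear(hidden_size, vocab_size)

    # nn.Linear is applied to the last dimension, so no reshaping is needed here.
    logits = classifier(output)                              # (batch_size, seq_len, vocab_size)

    # For nn.CrossEntropyLoss, batch and time can then be flattened into one dimension.
    logits_flat = logits.reshape(batch_size * seq_len, vocab_size)
    print(logits_flat.shape)                                 # torch.Size([10, 100])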


BTW, I’ve updated the code on https://www.kaggle.com/alvations/gru-language-model-not-training-properly and it looks like with very very careful tuning, I’m able to get the model to generate something sort of meaningful.


hyperparams = Hyperparams(embed_size=250, hidden_size=250, num_layers=1,
                          loss_func=nn.CrossEntropyLoss,
                          learning_rate=0.0003, optimizer=optim.Adam, batch_size=200)

dataloader, model, optimizer, criterion = initialize_data_model_optim_loss(hyperparams)

train(5000, dataloader, model, criterion, optimizer)

generate_example(model)

[out]:

the null hypothesis is never true . </s>

But it still baffles me why that would happen when the hyperparams are “suitable” and the model is trained long enough.

Huh, this is what’s written in the docs for GRU:

  • batch_first – If True , then the input and output tensors are provided as (batch, seq, feature). Default: False

That’s exactly the concern I have with a lot of code I see: just tweak it so PyTorch won’t complain :). Your self.classifier = nn.Linear(hidden_size, vocab_size) will take anything as input that is of shape (x, hidden_size), where x can be basically anything, since it’s treated as the batch size.

Did you try output = output[-1] and check the shape afterwards? It should be (batch_size, hidden_size) and raise no error when calling forward(). Please note that this only works if batch_first=False; otherwise you need to transpose() first or something (I just saw that you use batch_first=True now).
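In case it helps, here’s a quick stand-alone check (the sizes are arbitrary) of how batch_first changes the output shape and where the last step lives:

    import torch
    import torch.nn as nn

    x = torch.randn(7, 3, 10)                                # (seq_len, batch, input_size)

    gru = nn.GRU(input_size=10, hidden_size=16)              # batch_first=False (the default)
    out, h = gru(x)
    print(out.shape)                                         # torch.Size([7, 3, 16]); out[-1] is (batch, hidden)

    gru_bf = nn.GRU(input_size=10, hidden_size=16, batch_first=True)
    out_bf, h_bf = gru_bf(x.transpose(0, 1))                 # input is (batch, seq_len, input_size)
    print(out_bf.shape)                                      # torch.Size([3, 7, 16]); out_bf[:, -1] is (batch, hidden)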

What is the network trying to learn, in layman’s terms? Both Kaggle links have no further comments or descriptions. What does an individual training data item look like? Is X a sequence of words and y the next predicted word?

Ah, whoops I really need to get more sleep =)
Yes, in the kaggle code, I’ve explicitly put batch_first=True now.

In layman’s terms, the network is being trained to generate a sequence of words.

E.g. given the sentence: <s> this is a foo bar sentence </s>

X: <s> this is a foo bar sentence
Y: this is a foo bar sentence </s>
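A minimal sketch of how these shifted pairs could be built from token ids (the toy token-to-id mapping below is made up; the notebook builds its own vocabulary):

    # Toy vocabulary, just for illustration.
    tokens = ['<s>', 'this', 'is', 'a', 'foo', 'bar', 'sentence', '</s>']
    token2id = {tok: i for i, tok in enumerate(tokens)}
    ids = [token2id[tok] for tok in tokens]

    X = ids[:-1]   # <s> this is a foo bar sentence
    Y = ids[1:]    # this is a foo bar sentence </s>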

Let me check the output size again.

With batch_first=True, the output size of the forward() is (batch_size * sequence_len, hidden_size). The output[-1] is then just (hidden_size).

I’ve tried the code snippet you’ve posted:


    def forward(self, batch, method='last_step', use_softmax=False):
        embeds = self.embedding(batch)
        x = torch.transpose(embeds, 0, 1)
        x, self.hidden = self.gru(x, self.hidden)

        if method == 'last_step':
            x = x[-1]
        elif method == 'average_pooling':
            x = torch.sum(x, dim=0) / len(batch[0])
        elif method == 'max_pooling':
            x, _ = torch.max(x, dim=0)
        else:
            raise Exception('Unknown method.')
        x = self.classifier(x)
        return (F.softmax(x,dim=1), self.hidden) if use_softmax else (x, self.hidden)

It looks like it’s predicting just a single label for each of the data points in the batch.

The expected input/output should be, e.g., with a max sequence length of 5 and a batch size of 3:

x = [[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]

y = [[2, 3, 4, 0], [6, 7, 0, 0], [9, 0, 0, 0]]
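With targets padded like that, one way to score every step while ignoring the padding is a per-token cross-entropy with ignore_index; this is just a sketch with a made-up vocab_size and random logits, assuming 0 is the padding id:

    import torch
    import torch.nn as nn

    vocab_size = 10
    y = torch.tensor([[2, 3, 4, 0], [6, 7, 0, 0], [9, 0, 0, 0]])  # (batch, seq_len) targets
    logits = torch.randn(3, 4, vocab_size)                        # (batch, seq_len, vocab) model output

    criterion = nn.CrossEntropyLoss(ignore_index=0)               # don't score the padded positions
    loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
    print(loss.item())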

Hm, I’ve never seen a language model being set up like this then – obviously doesn’t mean it’s wrong, though!

Most of the time, X is a sequence of the current words and y is the next word, i.e., a many-to-one model. Here, the batch size never changes. Now I can kind of see that your code extends this to a many-to-many setup.

I managed to solve the problem. It’s the last-layer softmax causing instability… Got to check gradients… the gradients at the last layer become really small due to the softmax; switching it off works.
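For anyone hitting the same thing: nn.CrossEntropyLoss already applies a log-softmax internally, so adding an extra softmax on the logits squashes them and shrinks the gradients. A tiny illustration of the effect (random logits, 100 classes):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    logits = torch.randn(4, 100, requires_grad=True)
    targets = torch.randint(0, 100, (4,))
    criterion = nn.CrossEntropyLoss()

    # Raw logits into CrossEntropyLoss (which applies log-softmax itself).
    grad_raw, = torch.autograd.grad(criterion(logits, targets), logits)

    # The buggy version: an extra softmax before CrossEntropyLoss.
    grad_sm, = torch.autograd.grad(criterion(F.softmax(logits, dim=1), targets), logits)

    print(grad_raw.abs().mean(), grad_sm.abs().mean())  # the second is much smaller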

Pushing the latest notebook in a few minutes.

Here it is, version 11 should work. https://www.kaggle.com/alvations/gru-language-model

Thanks Chris for looking into the problem!

Nice! No problem, learned something new myself here. Happy coding!
