Decoder always predicts the same token

I have the following decoder, which after a few training steps only predicts the EOS token. Because of this, it cannot even overfit a tiny dummy dataset, so it seems there is a serious bug in the code.

  (embedding): Embeddings(
    (word_embeddings): Embedding(30002, 768, padding_idx=3)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (ffn1): FFN(
    (dense): Linear(in_features=768, out_features=512, bias=False)
    (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (activation): GELU()
  )
  (rnn): GRU(512, 512, batch_first=True, bidirectional=True)
  (ffn2): FFN(
    (dense): Linear(in_features=1024, out_features=512, bias=False)
    (layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.5, inplace=False)
    (activation): GELU()
  )
  (selector): Sequential(
    (0): Linear(in_features=512, out_features=30002, bias=True)
    (1): LogSoftmax(dim=-1)
  )

The forward is relatively straightforward (see what I did there?): pass the input_ids through the embedding and an FFN, then feed that representation to the RNN with the given sembedding as the initial hidden state. Pass the output through another FFN and apply a log-softmax. Return the logits (log-probabilities) and the last hidden states of the RNN. In the next step, use those hidden states as the new hidden states, and the highest-scoring predicted token as the new input.

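def forward(self, input_ids, sembedding):
    embedded = self.embedding(input_ids)           # (bs, seq_len, 768)
    output = self.ffn1(embedded)                   # project to the GRU size: (bs, seq_len, 512)
    output, hidden = self.rnn(output, sembedding)  # sembedding is h0: (num_directions, bs, 512)
    output = self.ffn2(output)                     # (bs, seq_len, 512 * num_directions) -> (bs, seq_len, 512)
    logits = self.selector(output)                 # log-probabilities over the vocabulary: (bs, seq_len, 30002)

    return logits, hidden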
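For anyone who wants to reproduce this, here is a minimal, hypothetical reconstruction of the decoder from the printout above, with sizes shrunk so it runs quickly on CPU. The internal order of `FFN` (Linear, LayerNorm, GELU, Dropout) and the flattened `embedding` block are assumptions, since the printout only lists the submodules, not how they are chained.

```python
import torch
from torch import nn


class FFN(nn.Module):
    """Linear -> LayerNorm -> GELU -> Dropout (assumed order; the printout only lists submodules)."""

    def __init__(self, in_dim: int, out_dim: int, p: float = 0.5):
        super().__init__()
        self.dense = nn.Linear(in_dim, out_dim, bias=False)
        self.layernorm = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(p)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.activation(self.layernorm(self.dense(x))))


class Decoder(nn.Module):
    def __init__(self, vocab_size=30002, emb_dim=768, hidden=512, pad_id=3, bidirectional=True):
        super().__init__()
        self.num_directions = 2 if bidirectional else 1
        # Simplified stand-in for the `Embeddings` block (word embeddings + LayerNorm + dropout)
        self.embedding = nn.Sequential(
            nn.Embedding(vocab_size, emb_dim, padding_idx=pad_id),
            nn.LayerNorm(emb_dim),
            nn.Dropout(0.5),
        )
        self.ffn1 = FFN(emb_dim, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True, bidirectional=bidirectional)
        self.ffn2 = FFN(hidden * self.num_directions, hidden)
        self.selector = nn.Sequential(nn.Linear(hidden, vocab_size), nn.LogSoftmax(dim=-1))

    def forward(self, input_ids, sembedding):
        embedded = self.embedding(input_ids)           # (bs, seq, emb_dim)
        output = self.ffn1(embedded)                   # (bs, seq, hidden)
        output, hidden = self.rnn(output, sembedding)  # (bs, seq, hidden * dirs), (dirs, bs, hidden)
        output = self.ffn2(output)                     # (bs, seq, hidden)
        logits = self.selector(output)                 # (bs, seq, vocab), log-probabilities
        return logits, hidden


# Tiny sizes for a quick shape check: one decoding step for a batch of 4
model = Decoder(vocab_size=100, emb_dim=16, hidden=8)
ids = torch.randint(0, 100, (4, 1))
h0 = torch.zeros(2, 4, 8)  # (num_directions, bs, hidden)
logits, h = model(ids, h0)
print(logits.shape, h.shape)  # torch.Size([4, 1, 100]) torch.Size([2, 4, 8])
```

With `bidirectional=True` the GRU output has `2 * hidden` features, which is why `ffn2` takes 1024 inputs in the original model.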

sembedding is the initial hidden state for the RNN. This is similar to an encoder-decoder architecture, except that here we do not train the encoder; we only have access to pretrained encoder representations.
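A small sketch of the hidden-state layout a GRU expects, which is what the stacking in `step` has to produce. It assumes one `sembedding` vector per example with the same size as the GRU hidden state; all tensors are random stand-ins.

```python
import torch
from torch import nn

batch_size, hidden = 4, 8
rnn = nn.GRU(hidden, hidden, batch_first=True, bidirectional=True)

# Hypothetical stand-in for the pretrained encoder output: one vector per example, (bs, hidden)
sembedding = torch.randn(batch_size, hidden)

# Even with batch_first=True, h0 is laid out as (num_layers * num_directions, bs, hidden).
# Index 0 is the forward direction (starts from sembedding), index 1 the backward one (zeros).
if rnn.bidirectional:
    h0 = torch.stack((sembedding, torch.zeros_like(sembedding)))  # (2, bs, hidden)
else:
    h0 = sembedding.unsqueeze(0)                                   # (1, bs, hidden)

x = torch.randn(batch_size, 1, hidden)  # one decoding step: (bs, seq=1, hidden)
output, h_n = rnn(x, h0)
print(h0.shape, output.shape, h_n.shape)
# torch.Size([2, 4, 8]) torch.Size([4, 1, 16]) torch.Size([2, 4, 8])
```

Note that `batch_first=True` only changes the layout of the input and output sequences, not of `h0`/`h_n`.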

In my training loop, I start each batch with an SOS token and feed the top predicted token to the next step until target_len is reached. I also randomly switch teacher forcing on or off for each batch.

def step(self, batch, teacher_forcing_ratio=0.5):
    batch_size, target_len = batch["input_ids"].size()[:2]
    # Init first decoder input with SOS (BOS) token
    decoder_input = torch.tensor([[self.tokenizer.bos_token_id]] * batch_size).to(self.device)
    batch["input_ids"] = batch["input_ids"].to(self.device)

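    # Init first decoder hidden state with shape (num_directions, bs, hidden): sembedding for the
    # forward direction plus a zeroed copy for the backward direction if the RNN is bidirectional
    decoder_hidden = torch.stack((batch["sembedding"],
                                  torch.zeros_like(batch["sembedding"]))).to(self.device) \
        if self.model.num_directions == 2 \
        else batch["sembedding"].unsqueeze(0).to(self.device)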

    loss = torch.tensor([0.]).to(self.device)

    use_teacher_forcing = random.random() < teacher_forcing_ratio
    # contains tuples of predicted and correct words
    tokens = []
    for i in range(target_len):
        # overwrite previous decoder_hidden
        output, decoder_hidden = self.model(decoder_input, decoder_hidden)
        batch_correct_ids = batch["input_ids"][:, i]

        # NLLLoss: compute the loss between the predicted log-probabilities (bs x classes) and the correct classes for _this word_;
        # the criterion is set to ignore the padding index
        loss += self.criterion(output[:, 0, :], batch_correct_ids)

        batch_predicted_ids = output.topk(1).indices.squeeze(1).detach()

        # with teacher forcing: feed the current correct token as the next input
        # without teacher forcing: feed the current prediction as the next input
        decoder_input = batch_correct_ids.unsqueeze(1) if use_teacher_forcing else batch_predicted_ids

    return loss, loss.item() / target_len
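To rule out shape problems in the loop, it can help to run the same decoding logic against a toy model. The sketch below is a hypothetical, self-contained version of `step`: `ToyDecoder` stands in for the real model, and it uses `argmax(dim=-1)`, which yields the same `(bs, 1)` shape as `topk(1).indices.squeeze(1)` here.

```python
import random

import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical minimal decoder with the same call signature: (input_ids, h) -> (log_probs, h)."""

    def __init__(self, vocab=20, hidden=8):
        super().__init__()
        self.emb = nn.Embedding(vocab, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.out = nn.Sequential(nn.Linear(hidden, vocab), nn.LogSoftmax(dim=-1))

    def forward(self, input_ids, h):
        output, h = self.rnn(self.emb(input_ids), h)
        return self.out(output), h


def step(model, targets, h0, bos_id, criterion, teacher_forcing_ratio=0.5):
    """Greedy decoding loop mirroring the question's `step`; returns (summed loss, mean loss)."""
    batch_size, target_len = targets.shape
    decoder_input = torch.full((batch_size, 1), bos_id, dtype=torch.long)  # (bs, 1)
    hidden = h0
    loss = torch.zeros(())
    use_teacher_forcing = random.random() < teacher_forcing_ratio  # decided once per batch
    for i in range(target_len):
        output, hidden = model(decoder_input, hidden)  # output: (bs, 1, vocab)
        correct = targets[:, i]                        # (bs,)
        loss = loss + criterion(output[:, 0, :], correct)
        predicted = output.argmax(dim=-1).detach()     # (bs, 1)
        decoder_input = correct.unsqueeze(1) if use_teacher_forcing else predicted
    return loss, loss.item() / target_len


torch.manual_seed(0)
model = ToyDecoder()
targets = torch.randint(0, 20, (4, 5))  # made-up target ids: (bs=4, target_len=5)
h0 = torch.zeros(1, 4, 8)               # (num_directions=1, bs, hidden)
loss, mean_loss = step(model, targets, h0, bos_id=1, criterion=nn.NLLLoss())
loss.backward()
```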

I also clip the gradients after each step:

clip_grad_norm_(self.model.parameters(), 1.0)
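For reference, clipping only has an effect if it sits between `backward()` and `optimizer.step()`. A minimal sketch with a throwaway linear model (model, data and learning rate are arbitrary):

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10) * 100, torch.randn(32, 1)  # large inputs -> large gradients

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()                                                  # 1. compute gradients
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)   # 2. rescale them in place; returns the pre-clipping norm
optimizer.step()                                                 # 3. update the weights with the clipped gradients
```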

At first, successive predictions are already fairly similar, but after a few iterations there is a bit more variation. Quite quickly, however, ALL predictions turn into other words (but always the same ones) and eventually into EOS tokens. (Edit: after changing the activation to ReLU, a different token is always predicted; it seems to be a random token that keeps getting repeated.) Note that this already happens after 80 steps (batch_size 128).

I found that the returned hidden state of the RNN contains a lot of zeros. I am not sure if that is the problem but it seems like it could be related.

tensor([[[  3.9874e-02,  -6.7757e-06,   2.6094e-04,  ...,  -1.2708e-17,
            4.1839e-02,   7.8125e-03],
         [ -7.8125e-03,  -2.5341e-02,   7.8125e-03,  ...,  -7.8125e-03,
           -7.8125e-03,  -7.8125e-03],
         [ -0.0000e+00, -1.0610e-314,   0.0000e+00,  ...,   0.0000e+00,
            0.0000e+00,   0.0000e+00],
         [  0.0000e+00,   0.0000e+00,   0.0000e+00,  ...,   0.0000e+00,
           -0.0000e+00,  1.0610e-314]]], device='cuda:0', dtype=torch.float64,

I have no idea what might be going wrong, although I suspect the issue lies in my step rather than in the model. I have already tried playing with the learning rate, disabling some layers (LayerNorm, dropout, ffn2), using pretrained embeddings and freezing or unfreezing them, disabling teacher forcing, and using a bidirectional vs. unidirectional GRU. The end result is always the same.

If you have any pointers, that would be very helpful. I have googled many things concerning neural networks always predicting the same item and I have tried all the suggestions that I could find. Any new ones, no matter how crazy, are welcome!

Not really a solution – I cannot see any obvious issues in your code – just some ideas you might want to consider. They essentially aim to simplify things toward a more tried-and-tested decoder architecture:

  • I’ve never seen a linear layer between an embedding and an RNN layer. Not sure what this means semantically, but I would skip that one for a start.

  • I don’t think it makes sense for the decoder to be bidirectional. Since you feed it one token/word at a time, the output should be the same anyway. I would remove the stacking to eliminate a potential source of error.

  • I would start with the basic decoder (without attention; not relevant for you anyway) of the basic Seq2Seq tutorial. Again, just because it’s tried and tested… including by myself :).

  • To match the Seq2Seq tutorial, I would also work with batches of size 1. You seem to use larger batches, but that’s not straightforward for a decoder. I only do it when I know that my target sequences in a batch have the same length (e.g., for an autoencoder).
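A basic decoder along the lines of the `DecoderRNN` in the PyTorch seq2seq translation tutorial (the older, batch-size-1 version, written from memory, so details may differ from the current tutorial): embedding, ReLU, unidirectional GRU, linear layer, log-softmax, one token at a time.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DecoderRNN(nn.Module):
    """Basic one-token-at-a-time GRU decoder (batch size 1, no attention)."""

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)  # unidirectional, sequence-first layout
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)  # (seq=1, batch=1, hidden)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))     # (1, output_size) log-probabilities
        return output, hidden

    def init_hidden(self):
        return torch.zeros(1, 1, self.hidden_size)


decoder = DecoderRNN(hidden_size=16, output_size=50)
hidden = decoder.init_hidden()
token = torch.tensor([1])  # e.g. the SOS index (arbitrary here)
log_probs, hidden = decoder(token, hidden)
```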

In short, I would start with the most basic architecture/setup using tested code (snippets) as far as possible (and batch_size=1) to get it working (= easy overfitting on a small dataset). Once that seems fine, I add complexity to improve performance and/or accuracy – or to modify the basic model to my specific use case.
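The overfitting check can be as small as a single sequence. The sketch below (toy vocabulary, made-up target ids, teacher forcing throughout) trains an embedding + GRU + linear head until it reproduces the sequence; if a model cannot manage this, that points to a bug in the code rather than in the data. It uses `nn.CrossEntropyLoss` on raw scores, which is equivalent to `LogSoftmax` followed by `NLLLoss`.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab, hidden, bos_id = 12, 32, 0
# One hypothetical target sequence (batch size 1); the ids are arbitrary
target = torch.tensor([[5, 7, 2, 9, 11, 1]])

emb = nn.Embedding(vocab, hidden)
gru = nn.GRU(hidden, hidden, batch_first=True)
head = nn.Linear(hidden, vocab)
params = [*emb.parameters(), *gru.parameters(), *head.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-2)
criterion = nn.CrossEntropyLoss()

# Teacher forcing: the input is BOS followed by the target shifted right by one
inputs = torch.cat([torch.full((1, 1), bos_id, dtype=torch.long), target[:, :-1]], dim=1)

losses = []
for _ in range(200):
    optimizer.zero_grad()
    out, _ = gru(emb(inputs))  # (1, T, hidden)
    logits = head(out)         # (1, T, vocab)
    loss = criterion(logits.reshape(-1, vocab), target.reshape(-1))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

with torch.no_grad():
    predicted = head(gru(emb(inputs))[0]).argmax(dim=-1)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.4f}, predicted {predicted.tolist()}")
```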

Thanks. I started from the seq2seq tutorial, but somewhere along the way things apparently went wrong. Unfortunately, I hadn’t put my code into version control yet, so it is hard to trace back my steps. As a last resort I’ll try again from scratch.

  • In my case I need a linear layer between the embedding and the RNN because the word embeddings and the sembedding are both pretrained but not the same size (768 vs. 512), so I need a linear transformation from one to the other.
  • You are right, but disabling bidirectionality does not change anything.
  • From what I have read, batched decoding is not an issue: in the loss function you just ignore the index of the padding token. Yes, there is some computational overhead for items in the batch that have already reached their EOS, but that is generally worth it given the speedup from batch sizes of 64 and more. It does require more memory, though.
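A small check of the padding argument: with `ignore_index`, padded positions drop out of both the sum and the count, so the per-step loss is the mean over real tokens only. The pad id 3 is taken from `padding_idx=3` in the printout and assumes the tokenizer uses the same id; the log-probabilities are random.

```python
import torch
from torch import nn

PAD_ID = 3  # padding_idx from the printout (assumed to match the tokenizer's pad id)
criterion = nn.NLLLoss(ignore_index=PAD_ID)

torch.manual_seed(0)
# Random log-probabilities for one decoding step: (bs=4, vocab=10)
log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)
# Correct ids at this step; sequences 1 and 3 have already ended, so they hold padding
targets = torch.tensor([2, PAD_ID, 7, PAD_ID])

loss = criterion(log_probs, targets)

# The same value by hand: average the negative log-likelihood over non-padding rows only
keep = (targets != PAD_ID).nonzero(as_tuple=True)[0]  # tensor([0, 2])
manual = -log_probs[keep, targets[keep]].mean()
print(loss.item(), manual.item())
```

One caveat: if every target at some step is padding (possible when batches are padded to a fixed maximum length rather than to the longest sequence in the batch), the mean-reduced loss for that step can come out as `nan`.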

Thanks for your input!

In my case, the issue turned out to be that the initial hidden state had dtype double (float64) while the input was float (float32). I don’t quite understand why that is a problem, but casting the hidden state to float solved it. If you have any intuition about why this might be a problem for PyTorch, do let me know in the comments.
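A defensive version of that fix, assuming `sembedding` arrives as float64 (for example because it was created from a NumPy array, which defaults to float64): cast the initial hidden state to the input's dtype and device before calling the GRU. I can't say for certain why the mismatch caused this behaviour, so the assert simply fails fast if the dtypes ever diverge again.

```python
import torch
from torch import nn

rnn = nn.GRU(8, 8, batch_first=True)
x = torch.randn(4, 1, 8)                             # float32 input (PyTorch's default dtype)
sembedding = torch.randn(4, 8, dtype=torch.float64)  # hypothetical float64 encoder output

# Tensor.to(other) matches both the dtype and the device of `other`
h0 = sembedding.unsqueeze(0).to(x)                   # (1, bs, hidden), now float32
assert h0.dtype == x.dtype == next(rnn.parameters()).dtype

output, h_n = rnn(x, h0)
```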


Interesting issue. Unfortunately, I have no idea, but good to know in the future.

I posted a separate issue about this in case anyone has an idea. Is it required that the input and hidden state of a GRU have the same dtype? It seems that PyTorch should at least give a warning if the hidden state must be float32.

If you use a batch size of one, do you still run optimizer.step after every step (so after every single datapoint) or do you accumulate?