Something wrong with this official PyTorch tutorial?

I followed the Language Modeling with nn.Transformer and TorchText tutorial exactly, but it doesn’t seem to work at all. After three epochs of training (as in the tutorial code), the model’s loss sits at about 4.0 and its predictions (even on the training set) are gibberish:

# predict ONE batch

j = 0  # batch offset into train_data (a multiple of bptt)

src_mask = generate_square_subsequent_mask(bptt).to(device)
data, targets = get_batch(train_data, j)   # data: [bptt, batch_size], targets: flattened
output = model(data, src_mask)             # output: [bptt, batch_size, vocab_size]

# pick one column (sequence) of the batch to inspect
i = 6

# model prediction: most likely token at every position
out = torch.argmax(output, dim=2)[:, i].tolist()
string = ' '.join(vocab.lookup_tokens(out))
print(i)
print(string)

# targets, un-flattened back to [bptt, batch_size] (here bptt=35, batch_size=20)
out = targets.reshape(35, 20)[:, i].tolist()
string = ' '.join(vocab.lookup_tokens(out))
print(i)
print(string)

# inputs
out = data[:, i].tolist()
string = ' '.join(vocab.lookup_tokens(out))
print(i)
print(string)

6
<unk> <unk> <unk> <unk> the <unk> <unk> of <unk> to of the <unk> <unk> the was a a the <unk> <unk> <unk> <unk> <unk> of <unk> to <unk> up to <unk> <unk> the <unk> to
6
<unk> was <unk> for a big loss <unk> losing some of the <unk> that he had gained with the previous play . <unk> state was unable to pick up another first down and sent in
6
<unk> <unk> was <unk> for a big loss <unk> losing some of the <unk> that he had gained with the previous play . <unk> state was unable to pick up another first down and sent

Even if I train the model for 300 epochs, the loss doesn’t go below about 3.6, and the predictions, even on the training set, remain gibberish. So the question is:

  • Is there a bug in the tutorial code?
  • Or is it true that this model simply can’t be expected to produce any sensible predictions, and is only intended to illustrate how to build a minimalistic encoder-only transformer?

It seems to me that the second answer is more likely, as no example prediction (after training) is shown in that tutorial, unlike, for example, the Language Translation with nn.Transformer and TorchText tutorial, in which a transformer (although with a different architecture) trained for a few epochs does translate a phrase from German into English. Any help is much appreciated!

I would suspect that something is up. While the general methodology seems more or less sane (though word splitting is usually replaced by subword splitting these days), it would seem that one needs quite a bit more training to get anywhere. A notebook that does a very similar thing on a char level (much “easier” to train) is in A. Karpathy’s minGPT: minGPT/play_char.ipynb at master · karpathy/minGPT · GitHub. That trains for a while (40 minutes for me) and then gives output that vaguely looks like Shakespeare…
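
For reference, the gist of that char-level setup is roughly the following (just a sketch in the spirit of the play_char notebook, not Karpathy’s exact code; input.txt, block_size, and get_example are placeholder names):

import torch

# char-level encoding: each character becomes an integer id
text = open('input.txt').read()                 # e.g. tiny-shakespeare
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}    # char -> integer id
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

block_size = 128                                # context length per example

def get_example(idx):
    # input is a window of characters, target is the same window shifted by one
    x = data[idx: idx + block_size]
    y = data[idx + 1: idx + block_size + 1]
    return x, y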

Best regards

Thomas


UPDATE

I just trained it on my mid-level GPU for about 3 more hours

| end of epoch 186 | time: 71.50s | valid loss 5.71 | valid ppl 302.70


| epoch 187 | 200/ 2928 batches | lr 0.02 | ms/batch 23.74 | loss 3.86 | ppl 47.48
| epoch 187 | 400/ 2928 batches | lr 0.02 | ms/batch 23.63 | loss 3.87 | ppl 47.92
| epoch 187 | 600/ 2928 batches | lr 0.02 | ms/batch 23.60 | loss 3.74 | ppl 41.98
| epoch 187 | 800/ 2928 batches | lr 0.02 | ms/batch 23.67 | loss 3.82 | ppl 45.44
| epoch 187 | 1000/ 2928 batches | lr 0.02 | ms/batch 23.65 | loss 3.84 | ppl 46.41
| epoch 187 | 1200/ 2928 batches | lr 0.02 | ms/batch 23.65 | loss 3.84 | ppl 46.75
| epoch 187 | 1400/ 2928 batches | lr 0.02 | ms/batch 23.63 | loss 3.82 | ppl 45.52
| epoch 187 | 1600/ 2928 batches | lr 0.02 | ms/batch 23.57 | loss 3.87 | ppl 47.81
| epoch 187 | 1800/ 2928 batches | lr 0.02 | ms/batch 23.61 | loss 3.87 | ppl 47.74
| epoch 187 | 2000/ 2928 batches | lr 0.02 | ms/batch 23.65 | loss 3.86 | ppl 47.52
| epoch 187 | 2200/ 2928 batches | lr 0.02 | ms/batch 23.64 | loss 3.72 | ppl 41.37
| epoch 187 | 2400/ 2928 batches | lr 0.02 | ms/batch 23.63 | loss 3.80 | ppl 44.64
| epoch 187 | 2600/ 2928 batches | lr 0.02 | ms/batch 23.62 | loss 3.82 | ppl 45.52

and it seems to improve:

# predict ONE batch (same cell as above)

j = 0  # batch offset into train_data (a multiple of bptt)

src_mask = generate_square_subsequent_mask(bptt).to(device)
data, targets = get_batch(train_data, j)
output = model(data, src_mask)

11
, , the ’ the by the ’ s brother , and his doyle that was no evidence to the that be s him the jokes for the brother words jokes — s . with
11
authorities . hornung was angered by doyle ’ s action , and told him there was no need for him to ’ butt in ’ except for his own ’ satisfaction ’ . relations between
11
military authorities . hornung was angered by doyle ’ s action , and told him there was no need for him to ’ butt in ’ except for his own ’ satisfaction ’ . relations

The top paragraph is the prediction on a sample from the training set (the targets and inputs are the second and third paragraphs, respectively). Maybe on WikiText-2 this model has to be trained for days? But then the learning rate should decay much more slowly. Please help.
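
For example, a slower decay could look like this (just a sketch, assuming the tutorial’s SGD + StepLR(gamma=0.95) setup; model here is a stand-in, not the tutorial’s TransformerModel):

import torch

model = torch.nn.Linear(8, 8)   # stand-in for the actual language model
optimizer = torch.optim.SGD(model.parameters(), lr=5.0)

# option 1: much gentler exponential decay (gamma closer to 1.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

# option 2: cosine annealing over the total number of planned epochs
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

for epoch in range(300):
    # ... run one training epoch here ...
    scheduler.step()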

It depends on your hardware, but I would expect training to take a few days, yes.

Perhaps the most classic reference for WikiText-2 language modelling is Regularizing and Optimizing LSTM Language Models | OpenReview, so I would compare your validation perplexity to the numbers there to see how you are doing.
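
Note that the perplexity in the training log is just the exponential of the average cross-entropy loss (in nats), so it maps directly onto the numbers reported in that paper:

import math

valid_loss = 5.71                 # value from the log above
print(math.exp(valid_loss))       # ~302, in line with the "valid ppl 302.70"
                                  # logged above (the log uses the unrounded loss)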

Best regards

Thomas
