When training a language model, if an entire sequence is fed into an LSTM layer, is teacher forcing (using the ground-truth label at the current time step as the input for the next time step) implemented implicitly? I tried to find the answer in the PyTorch docs but couldn't. The only thing I found was one Stack Overflow post saying that
If this is true, then to make predictions without teacher forcing it seems we have to iterate through the sequence one step at a time, using the output at the current time step as the input at the next time step. But isn't this very inefficient?
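One way to see what is going on: nn.LSTM itself knows nothing about labels; it simply runs its recurrence over whatever input tensor it is given. If that tensor holds the (embedded) ground-truth tokens, every step is fed the true previous token, which is teacher forcing by construction. The sketch below, with arbitrary toy sizes and random tensors standing in for embedded tokens, checks that one call over the whole sequence computes the same thing as an explicit step-by-step loop fed with the same inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes, chosen only for illustration.
batch, seq_len, input_size, hidden_size = 2, 5, 4, 8
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

# Stand-in for the embedded ground-truth tokens y_0 .. y_{T-1}.
inputs = torch.randn(batch, seq_len, input_size)

# 1) One call over the whole sequence.
full_out, (h_full, c_full) = lstm(inputs)

# 2) The same computation as an explicit loop, one time step at a time.
#    Hidden/cell states are (num_layers, batch, hidden_size) even with batch_first=True.
h = torch.zeros(1, batch, hidden_size)
c = torch.zeros(1, batch, hidden_size)
step_outs = []
for t in range(seq_len):
    # Feed the ground-truth input for step t, regardless of what step t-1 predicted.
    out_t, (h, c) = lstm(inputs[:, t:t + 1, :], (h, c))
    step_outs.append(out_t)
loop_out = torch.cat(step_outs, dim=1)

print(torch.allclose(full_out, loop_out, atol=1e-5))
```

The two results should agree up to floating-point rounding. Without teacher forcing this shortcut is not available, because the input at step t+1 is only known after step t has produced its output.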
Maybe I misunderstand your problem. Teacher Forcing is usually applied to the decoder in case of Sequence-to-Sequence models, where you generate, say, a sentence. For example, the prediction of the 4th word depends on the prediction of the 3rd word (no teacher forcing) or the ground truth of the 3rd word (teacher forcing).
A language model is usually not a Sequence-to-Sequence model but more like a Sequence-to-NextWord model, basically a simple classifier. So you don't have a decoder where Teacher Forcing is applicable. I don't see the point of applying Teacher Forcing to the encoder, i.e., the RNN for the input sequence.
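For comparison, a minimal sketch of such a Sequence-to-NextWord model, assuming a word-level vocabulary: the LSTM reads the context and a linear layer classifies its final hidden state into the next word. The class name NextWordModel and all sizes are hypothetical, and random indices stand in for real data. There is no decoder here, and hence nothing to apply Teacher Forcing to.

```python
import torch
import torch.nn as nn


class NextWordModel(nn.Module):
    """Hypothetical sequence-to-next-word model: LSTM encoder plus a linear classifier."""

    def __init__(self, vocab_size: int, embed_dim: int = 32, hidden_size: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_size, batch_first=True)
        self.classifier = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq_len) word indices of the context
        _, (h_n, _) = self.lstm(self.embed(tokens))
        # h_n: (num_layers, batch, hidden_size); classify the last layer's final state
        return self.classifier(h_n[-1])  # (batch, vocab_size) logits for the next word


torch.manual_seed(0)
vocab_size = 100
model = NextWordModel(vocab_size)
context = torch.randint(0, vocab_size, (3, 7))  # 3 contexts of 7 words each
target = torch.randint(0, vocab_size, (3,))     # the word following each context
logits = model(context)
loss = nn.functional.cross_entropy(logits, target)
loss.backward()
print(logits.shape)
```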
Thanks a lot for the reply! Actually it’s an image captioning problem and I am talking about decoder. Sorry for the misunderstanding!
Ah, OK… got it. Yes, in this case you use the LSTM as a decoder.
Anyway, generating the words step by step is the standard approach for any X-to-sequence model. During training you can, of course, give the whole sequence to the LSTM in one go when using Teacher Forcing. But training the whole network with Teacher Forcing only gives, I think, poor results. In practice Teacher Forcing is often applied only with a certain probability (e.g., 50%); it has been shown to make training more stable and faster.
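A minimal sketch of Teacher Forcing applied with a probability during training, assuming an image-captioning setup in which a CNN feature vector initialises the LSTM state. CaptionDecoder, its sizes, and teacher_forcing_ratio are illustrative names, not a PyTorch API; random tensors stand in for real image features and tokenised captions. The coin flip here is made per time step; flipping once per batch is another common choice.

```python
import random

import torch
import torch.nn as nn


class CaptionDecoder(nn.Module):
    """Hypothetical LSTM caption decoder; image features initialise the hidden state."""

    def __init__(self, vocab_size, feat_dim=16, embed_dim=32, hidden_size=64):
        super().__init__()
        self.init_h = nn.Linear(feat_dim, hidden_size)
        self.init_c = nn.Linear(feat_dim, hidden_size)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.cell = nn.LSTMCell(embed_dim, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, feats, captions, teacher_forcing_ratio=0.5):
        # feats: (batch, feat_dim); captions: (batch, T), first column is a start token
        h, c = self.init_h(feats), self.init_c(feats)
        inp = captions[:, 0]
        logits = []
        for t in range(1, captions.size(1)):
            h, c = self.cell(self.embed(inp), (h, c))
            step_logits = self.out(h)
            logits.append(step_logits)
            if random.random() < teacher_forcing_ratio:
                inp = captions[:, t]             # teacher forcing: feed the ground truth
            else:
                inp = step_logits.argmax(dim=1)  # feed the model's own prediction
        return torch.stack(logits, dim=1)        # (batch, T-1, vocab_size)


random.seed(0)
torch.manual_seed(0)
vocab_size, batch, T = 50, 4, 6
decoder = CaptionDecoder(vocab_size)
feats = torch.randn(batch, 16)                       # stand-in for CNN image features
captions = torch.randint(0, vocab_size, (batch, T))  # stand-in for tokenised captions
logits = decoder(feats, captions, teacher_forcing_ratio=0.5)
# Each step t predicts captions[:, t], so the targets are the captions shifted by one.
loss = nn.functional.cross_entropy(logits.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))
loss.backward()
```

If teacher_forcing_ratio were always 1, the loop could be replaced by a single nn.LSTM call over the whole (shifted) caption, which is the whole-sequence shortcut for training described above.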
At inference time you have to generate your output sequence step by step in a loop anyway. I don't think the loop is the bottleneck. The heavy lifting is still done by the LSTM, and giving a whole sequence to the LSTM just wraps the loop; the loop is still there.
All RNN-based decoders I have seen so far have the loop in their forward() method to process the words.
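At inference time, a greedy decoding loop might look like the following sketch (beam search is a common alternative). The START/END token ids, the class name, and max_len are assumptions for illustration; the loop lives in a separate generate() method here, though many implementations put it in forward(). Each predicted token becomes the next input, which is exactly why this cannot be collapsed into a single nn.LSTM call.

```python
import torch
import torch.nn as nn

START, END = 1, 2  # hypothetical special-token ids


class GreedyCaptionDecoder(nn.Module):
    """Hypothetical LSTM caption decoder used only for greedy generation."""

    def __init__(self, vocab_size, feat_dim=16, embed_dim=32, hidden_size=64):
        super().__init__()
        self.init_h = nn.Linear(feat_dim, hidden_size)
        self.init_c = nn.Linear(feat_dim, hidden_size)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.cell = nn.LSTMCell(embed_dim, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    @torch.no_grad()
    def generate(self, feats, max_len=20):
        h, c = self.init_h(feats), self.init_c(feats)
        inp = torch.full((feats.size(0),), START, dtype=torch.long)
        done = torch.zeros(feats.size(0), dtype=torch.bool)
        tokens = []
        for _ in range(max_len):
            h, c = self.cell(self.embed(inp), (h, c))
            inp = self.out(h).argmax(dim=1)  # the prediction becomes the next input
            tokens.append(inp)
            done |= inp == END
            if done.all():                   # stop once every sequence has emitted END
                break
        return torch.stack(tokens, dim=1)    # (batch, generated_len)


torch.manual_seed(0)
decoder = GreedyCaptionDecoder(vocab_size=50).eval()
feats = torch.randn(3, 16)  # stand-in for CNN image features
caption_ids = decoder.generate(feats, max_len=10)
print(caption_ids.shape)
```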
I found an implementation demonstrating this in this PyTorch tutorial.