I am trying to train a Transformer model using nn.Transformer (see the Transformer page of the PyTorch 1.9.0 documentation), which takes an input of shape (batch_size, 73, 640) and outputs a sequence of shape (batch_size, 2, 640). I have initialized the transformer with d_model set to 640 and kept the rest at their defaults. During training, the model requires two inputs: src and tgt. I assume src is my input (batch_size, 73, 640) and tgt is the ground truth (batch_size, 2, 640). However, in the validation/test phase I may not have ground truths for prediction. What should be done in this case? Or am I misunderstanding tgt?
Any help would be greatly appreciated.
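A minimal sketch of the setup described in the question, assuming a PyTorch version whose nn.Transformer accepts batch_first (the 1.9 docs list it; with the default batch_first=False the tensors would be read as (seq, batch, feature)). The batch size of 4 is arbitrary. It shows that the output has one vector per tgt position, which is why tgt has to be supplied at all.

```python
import torch
import torch.nn as nn

# Setup from the question: d_model=640, everything else default.
# batch_first=True so that tensors are (batch, seq, feature) instead of (seq, batch, feature).
model = nn.Transformer(d_model=640, batch_first=True)

src = torch.randn(4, 73, 640)   # (batch_size, 73, 640)
tgt = torch.randn(4, 2, 640)    # (batch_size, 2, 640)
out = model(src, tgt)
print(out.shape)  # torch.Size([4, 2, 640]): one output vector per tgt position
```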
Well, I would recommend looking at some hands-on code to understand the concept.
Let me first state the problem I think you are talking about: predicting a sequence given a context.
Imagine you have N elements. When you train, you assume that to predict element n_i the network knows all the previous (ground-truth) elements.
For example, to predict n3 you know n2, n1 and n0.
To predict n4 you know n3, n2, n1 and n0.
You achieve this with a mask on the target that hides the future elements, and thanks to that mask all positions are trained in parallel, in a single forward pass.
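To see why the mask allows parallel training, here is a small sketch using a single nn.MultiheadAttention layer rather than the full transformer; the sizes (16 features, 2 heads, 5 elements) and the causal_mask helper are illustrative choices. The mask built with torch.triu has the same 0/-inf pattern that nn.Transformer's generate_square_subsequent_mask returns. Changing only the last element leaves the outputs for all earlier positions untouched.

```python
import torch
import torch.nn as nn

def causal_mask(sz):
    # 0.0 on and below the diagonal, -inf above it:
    # position i may only attend to positions 0..i.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
attn.eval()

x = torch.randn(1, 5, 16)   # one sequence of 5 elements, all processed at once
mask = causal_mask(5)
out, _ = attn(x, x, x, attn_mask=mask)

# Change only the last element: outputs 0..3 stay the same,
# because none of those positions was allowed to look at position 4.
x2 = x.clone()
x2[:, 4] = torch.randn(16)
out2, _ = attn(x2, x2, x2, attn_mask=mask)
print(torch.allclose(out[:, :4], out2[:, :4]))  # True
print(torch.allclose(out[:, 4], out2[:, 4]))    # False
```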
By contrast, with LSTMs you need to process the sequence element by element.
At test time, however, the problem is different. Instead of making all the predictions in parallel thanks to the masks, you have to run the transformer element by element.
That is, you first predict n0. Then you feed n0 to get n1, then feed n0 and n1 to get n2, and so on.
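A minimal sketch of that loop for the question's continuous 640-dimensional outputs. Assumptions that are not part of nn.Transformer: a start vector standing in for a start-of-sequence token (zeros here, learned in practice) and a linear output head; both names are hypothetical. The model is untrained and uses 2 layers instead of the default 6 only to keep the demo light. The source is encoded once, and each new prediction is appended to the decoder input; the causal mask is still passed so that every prefix position is computed as during training.

```python
import torch
import torch.nn as nn

D_MODEL = 640
# Fewer layers than the question's defaults, only to keep the demo light.
model = nn.Transformer(d_model=D_MODEL, num_encoder_layers=2,
                       num_decoder_layers=2, batch_first=True)
start = torch.zeros(1, 1, D_MODEL)   # hypothetical start vector (learned in practice)
head = nn.Linear(D_MODEL, D_MODEL)   # hypothetical output head

@torch.no_grad()
def generate(src, n_steps):
    model.eval()
    memory = model.encoder(src)                  # encode the source once
    dec_in = start.expand(src.size(0), -1, -1)   # begin with only the start vector
    preds = []
    for _ in range(n_steps):
        tgt_mask = model.generate_square_subsequent_mask(dec_in.size(1))
        out = model.decoder(dec_in, memory, tgt_mask=tgt_mask)
        y = head(out[:, -1:])                    # prediction for the newest position
        preds.append(y)
        dec_in = torch.cat([dec_in, y], dim=1)   # feed it back with all previous ones
    return torch.cat(preds, dim=1)

src = torch.randn(3, 73, D_MODEL)
print(generate(src, n_steps=2).shape)  # torch.Size([3, 2, 640])
```

No ground truth is used anywhere in generate, which is what makes it usable at validation/test time.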
Thank you. I am getting the gist of the idea, even though I still have some questions. So, as I understand it, I should generate masks for both src and tgt so that they hide the elements ahead (not yet predicted).
Well, it depends on the field and the task. You have to think about the real pipeline.
For example, generating speech from text.
Do you have access to the whole sentence? Yes, so you don't need to mask it.
Do you have access to the whole speech? No, because you are predicting it, so you have to mask it.
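A sketch of a training step that follows this reasoning: the source is passed without a mask and only the target gets the causal mask. It assumes teacher forcing with the target shifted right by one position behind a start vector, a linear head and an MSE loss for the continuous targets; start, head, shift_right and train_step are hypothetical names, and 2 layers instead of the default 6 are used only to keep the demo light.

```python
import torch
import torch.nn as nn

D_MODEL, SRC_LEN, TGT_LEN, BATCH = 640, 73, 2, 4

model = nn.Transformer(d_model=D_MODEL, num_encoder_layers=2,
                       num_decoder_layers=2, batch_first=True)
# Hypothetical extras, not part of nn.Transformer: a learned start vector
# playing the role of a start-of-sequence token, and a linear output head.
start = nn.Parameter(torch.zeros(1, 1, D_MODEL))
head = nn.Linear(D_MODEL, D_MODEL)
opt = torch.optim.Adam(list(model.parameters()) + [start] + list(head.parameters()), lr=1e-4)

def shift_right(tgt):
    # Decoder input is [start, y_0, ..., y_{T-2}], so position i predicts y_i
    # from the ground-truth elements before it (teacher forcing).
    return torch.cat([start.expand(tgt.size(0), -1, -1), tgt[:, :-1]], dim=1)

def train_step(src, tgt):
    model.train()
    dec_in = shift_right(tgt)
    # Only the target is masked: the whole source is known in advance,
    # the target is what is being predicted.
    tgt_mask = model.generate_square_subsequent_mask(dec_in.size(1))
    out = model(src, dec_in, tgt_mask=tgt_mask)   # src_mask left as None
    loss = nn.functional.mse_loss(head(out), tgt)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

src = torch.randn(BATCH, SRC_LEN, D_MODEL)
tgt = torch.randn(BATCH, TGT_LEN, D_MODEL)
print(train_step(src, tgt))
```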
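The key point is that at inference no future elements exist yet: you predict one element, then feed it back together with all the previous ones. Even so, with a multi-layer decoder it is safest to keep passing the same causal mask, so that each earlier position is computed exactly as it was during training.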
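Whether to keep the mask at inference is easy to check. The sketch below uses a small model with 2 decoder layers and random inputs (all sizes chosen arbitrarily) and compares training-style parallel decoding with step-by-step decoding. With the causal mask, the step-by-step outputs match the parallel ones; without it they do not, because in a stacked decoder the intermediate outputs of earlier positions then see later elements.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# 2 decoder layers, so earlier positions feed later ones through stacked self-attention.
model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=1,
                       num_decoder_layers=2, batch_first=True)
model.eval()   # disables dropout so the comparison is deterministic

src = torch.randn(1, 7, 32)
dec_in = torch.randn(1, 4, 32)   # stands in for [start, y_0, y_1, y_2]

with torch.no_grad():
    memory = model.encoder(src)
    causal = model.generate_square_subsequent_mask(4)
    # Training style: all positions at once, with the causal mask.
    parallel = model.decoder(dec_in, memory, tgt_mask=causal)

    # Step by step: at step k only the first k+1 inputs exist; keep the last output.
    masked_steps = torch.cat([
        model.decoder(dec_in[:, :k + 1], memory,
                      tgt_mask=model.generate_square_subsequent_mask(k + 1))[:, -1:]
        for k in range(4)], dim=1)
    unmasked_steps = torch.cat([
        model.decoder(dec_in[:, :k + 1], memory)[:, -1:]
        for k in range(4)], dim=1)

print(torch.allclose(parallel, masked_steps, atol=1e-5))    # True
print(torch.allclose(parallel, unmasked_steps, atol=1e-5))  # False
```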
But the best thing you can do is look at some code; that always helps. Even if it does not use PyTorch's nn.Transformer, the key idea is the same.