How to perform inference with nn.Transformer

I trained a Transformer model using the default nn.Transformer class to perform machine translation.

During training, we pass the inputs to the encoder and the targets to the decoder. In pseudocode, a forward pass looks like:

for inputs, targets in train_loader:
    preds = transformer(src=inputs, tgt=targets)

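For context, training this way uses teacher forcing: the decoder is fed the ground-truth target (usually shifted right by one position) and learns to predict the next token at every position. A minimal sketch of such a step, where src_embed, tgt_embed, and generator are hypothetical stand-ins for the embedding layers (including positional encoding) and the final vocabulary projection that wrap nn.Transformer:

import torch.nn as nn

for inputs, targets in train_loader:    # token ids, shape (seq_len, batch)
    tgt_input = targets[:-1]            # decoder input: <bos> w1 ... w_{n-1}
    tgt_expected = targets[1:]          # expected output: w1 ... w_n <eos>
    # Causal mask so position i cannot attend to positions > i.
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(
        tgt_input.size(0)).to(inputs.device)
    out = transformer(src=src_embed(inputs),
                      tgt=tgt_embed(tgt_input),
                      tgt_mask=tgt_mask)
    logits = generator(out)             # (tgt_len, batch, vocab_size)
    loss = nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                       tgt_expected.reshape(-1))
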
My question is: what do we do with the tgt argument at test time? I have tried:

  • Feeding in a boolean src_mask that is True at every position
  • Feeding in tgt = torch.zeros_like(targets)

But both of these degrade the model's performance.

Thank you for your help (:

At inference time, you are generating the output text autoregressively: there is no ground-truth tgt to feed in. Instead, you give the decoder the tokens generated so far (beginning with the start token), predict the next token, append it, and repeat. At each step you can either pick the single highest-probability token (greedy decoding) or maintain a list of candidate hypotheses (beam search).
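To make that concrete, here is a minimal greedy-decoding sketch. The names src_embed, tgt_embed, and generator (the final linear layer over the vocabulary) are assumptions about the model around nn.Transformer, which itself only consumes already-embedded tensors of shape (seq_len, batch, d_model):

import torch
import torch.nn as nn

@torch.no_grad()
def greedy_decode(model, src_embed, tgt_embed, generator,
                  src_tokens, bos_idx, eos_idx, max_len=50):
    # src_tokens: token ids of shape (src_len, 1), i.e. batch size 1.
    model.eval()
    device = src_tokens.device
    # Encode the source once; reuse the memory at every decoding step.
    memory = model.encoder(src_embed(src_tokens))
    # Start from just the beginning-of-sequence token.
    ys = torch.full((1, 1), bos_idx, dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        # Causal mask so position i only attends to positions <= i.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            ys.size(0)).to(device)
        out = model.decoder(tgt_embed(ys), memory, tgt_mask=tgt_mask)
        # Project the last decoder state onto the vocabulary and take
        # the highest-scoring token (the greedy choice).
        next_token = generator(out[-1]).argmax(dim=-1)
        ys = torch.cat([ys, next_token.unsqueeze(0)], dim=0)
        if next_token.item() == eos_idx:
            break
    return ys

ys then holds the generated token ids. Beam search would instead keep the k highest-scoring partial hypotheses at each step rather than committing to a single argmax.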

Also, I am guessing your full code passes more than just the input and target, in order to handle masking (ignoring keys at padded positions, preventing the decoder from attending to positions after the one currently being decoded, etc.).
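
For completeness, a hedged sketch of the usual mask plumbing; PAD_IDX is an assumption about your vocabulary, and shapes assume nn.Transformer's default (seq_len, batch) layout:

import torch.nn as nn

PAD_IDX = 1  # assumption: whatever index your vocab uses for padding

def make_masks(src_tokens, tgt_tokens):
    # Float causal mask with -inf above the diagonal: the decoder may
    # only look at already-decoded positions.
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(
        tgt_tokens.size(0)).to(tgt_tokens.device)
    # Boolean key-padding masks, True where a position is padding and
    # should be ignored as an attention key; shape (batch, seq_len).
    src_key_padding_mask = (src_tokens == PAD_IDX).transpose(0, 1)
    tgt_key_padding_mask = (tgt_tokens == PAD_IDX).transpose(0, 1)
    return tgt_mask, src_key_padding_mask, tgt_key_padding_mask

These feed the tgt_mask, src_key_padding_mask, tgt_key_padding_mask, and memory_key_padding_mask arguments of nn.Transformer's forward (the memory padding mask is normally just the source padding mask again).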