I trained a Transformer model for machine translation using PyTorch's stock `nn.Transformer` class.
During training, we pass the inputs into the encoder and the targets into the decoder. In pseudocode, a forward pass looks like:
```python
for inputs, targets in train_loader:
    preds = transformer(src=inputs, tgt=targets)
```
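To make that loop concrete, here is a minimal, self-contained shape check of the same call (the sizes here are made up for illustration; with the default `batch_first=False`, `nn.Transformer` expects embedded inputs of shape `(seq_len, batch, d_model)`):

```python
import torch
import torch.nn as nn

# nn.Transformer operates on embeddings, not token indices.
# Defaults: d_model=512, nhead=8, batch_first=False.
model = nn.Transformer(d_model=512, nhead=8)

src = torch.rand(10, 2, 512)  # source length 10, batch size 2
tgt = torch.rand(7, 2, 512)   # target length 7, batch size 2

out = model(src=src, tgt=tgt)
print(out.shape)  # torch.Size([7, 2, 512]) -- one vector per target position
```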
My question is: what do we do with the `tgt` argument at test time? I have tried:
- Feeding in a boolean `src_mask` that is `True` at every position
- Feeding in `tgt = torch.zeros_like(targets)`
But both of these degrade the model's performance.
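For reference, here is roughly what those two attempts look like as code (tensor sizes are made up to match the earlier example). One thing worth noting: in PyTorch's boolean attention masks, `True` marks positions that are *not* allowed to be attended to, so an all-`True` `src_mask` blocks attention to every source position.

```python
import torch

# Hypothetical shapes: source length 10, target length 7,
# batch size 2, d_model 512 (batch_first=False layout).
targets = torch.rand(7, 2, 512)

# Attempt 1: an all-True boolean attention mask over source positions.
# In PyTorch, True in a boolean mask means "do not attend here",
# so this masks out the entire source sequence.
src_mask = torch.ones(10, 10, dtype=torch.bool)

# Attempt 2: an all-zeros stand-in for the decoder input.
tgt = torch.zeros_like(targets)
```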
Thank you for your help (: