How to use nn.TransformerDecoder() at inference time

Hello. I am using the nn.TransformerDecoder() module to train a language model. During training, the model is given the target tgt and a tgt_mask, so at each step the decoder is conditioned on the true previous tokens.
However, for text generation (at inference time), the model should not see the true labels; it should condition on the tokens it predicted in the previous steps. Can we do that with nn.TransformerDecoder()? Or should I reimplement the module so that the newly predicted tokens are appended at each step?
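
To make the question concrete, here is a rough sketch of the training-time (teacher-forcing) setup I mean; the names decoder_emb and predictor are just placeholders of mine, not part of any PyTorch API:

import torch
import torch.nn as nn

d_model, nhead, vocab_size = 512, 8, 10000

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
decoder_emb = nn.Embedding(vocab_size, d_model)  # placeholder token embedding
predictor = nn.Linear(d_model, vocab_size)       # placeholder output projection

# teacher forcing: the whole target sequence is fed at once, and tgt_mask
# stops each position from attending to future positions
tgt = torch.randint(0, vocab_size, (20, 4))   # (tgt_len, bs) true target token ids
memory = torch.randn(10, 4, d_model)          # (src_len, bs, d_model) encoder output
tgt_mask = nn.Transformer().generate_square_subsequent_mask(tgt.size(0))

out = decoder(tgt=decoder_emb(tgt), memory=memory, tgt_mask=tgt_mask)
logits = predictor(out)                       # (tgt_len, bs, vocab_size)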


Were you able to get an answer on this? Is there an example?

There is an example with the nn.Transformer module for a word language model here, or a tutorial here.

The example you posted is for the TransformerEncoder. I’m looking for an explicit usage example of the TransformerDecoder.
Can you provide one?

So for TransformerDecoder, what are your target and memory sequences?

Yes. You need to initialize the target with a tensor of SOS (start-of-sentence) tokens. At each step, you append the prediction to this tensor, as explained in the paper Attention Is All You Need:

At each step the model is auto-regressive[10], consuming the previously generated symbols as additional input when generating the next

import torch
import torch.nn as nn

device = 'cpu'
d_model = 768
bs = 4
sos_idx = 0
vocab_size = 10000
input_len = 10
output_len = 12

# Define the model
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4).to(device)

encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                num_layers=6).to(device)

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4).to(device)

decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                num_layers=6).to(device)

decoder_emb = nn.Embedding(vocab_size, d_model).to(device)

predictor = nn.Linear(d_model, vocab_size).to(device)

# for a single batch x (random features standing in for the embedded source sequence)
x = torch.randn(input_len, bs, d_model).to(device)  # (input_len, bs, d_model): sequence-first layout
encoder_output = encoder(x)  # (input_len, bs, d_model)

# initialize the input of the decoder with sos_idx (start-of-sentence token idx)
output = torch.ones(bs, output_len).long().to(device) * sos_idx
for t in range(1, output_len):
    tgt_emb = decoder_emb(output[:, :t]).transpose(0, 1)  # (t, bs, d_model)
    # causal mask: position i may only attend to positions <= i
    tgt_mask = torch.nn.Transformer().generate_square_subsequent_mask(
        t).to(device)
    decoder_output = decoder(tgt=tgt_emb,
                             memory=encoder_output,
                             tgt_mask=tgt_mask)

    pred_proba_t = predictor(decoder_output)[-1, :, :]  # logits for the last position, (bs, vocab_size)
    output_t = pred_proba_t.argmax(dim=-1)              # greedy choice of the next token, (bs,)
    output[:, t] = output_t

# output now holds the generated token indices, shape (bs, output_len)

My final model looks like this:

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=768, nhead=4, num_layers=6, device='cpu'):
        super().__init__()
        self.device = device
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.decoder_emb = nn.Embedding(vocab_size, d_model)
        self.predictor = nn.Linear(d_model, vocab_size)

    def forward(self, x, y=None, mode='train'):
        encoder_output = self.encoder(x)  # x: (src_len, bs, d_model)

        if mode == 'train':
            # teacher forcing: feed the whole target sequence with a causal mask
            tgt_emb = self.decoder_emb(y).transpose(0, 1)  # (tgt_len, bs, d_model)
            tgt_mask = torch.nn.Transformer().generate_square_subsequent_mask(
                tgt_emb.size(0)).to(self.device)
            decoder_output = self.decoder(tgt=tgt_emb,
                                          tgt_mask=tgt_mask,
                                          memory=encoder_output)
            # (bs, vocab_size, tgt_len) -- the layout nn.CrossEntropyLoss expects
            return self.predictor(decoder_output).permute(1, 2, 0)
        elif mode == 'generate':
            # the greedy loop from the solution I gave before
            return
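
For completeness, here is a rough sketch of how the 'train' branch could be wired into a training step. The constructor arguments, data shapes and optimizer below are made up for illustration, and I am ignoring the usual shift between the decoder input and the loss target; the point is that permute(1, 2, 0) gives the (N, C, L) layout nn.CrossEntropyLoss expects:

# hypothetical usage sketch for one training step with the model above
model = Transformer(vocab_size=10000)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(10, 4, 768)           # (src_len, bs, d_model) source features
y = torch.randint(0, 10000, (4, 12))  # (bs, tgt_len) target token ids

logits = model(x, y, mode='train')    # (bs, vocab_size, tgt_len)
loss = criterion(logits, y)           # CrossEntropyLoss over (N, C, L) vs (N, L)
loss.backward()
optimizer.step()
optimizer.zero_grad()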
        


Do you need tgt_mask at inference time? I am wondering whether the for loop used during inference does the same job as the subsequent mask. Also, have you run experiments with nn.TransformerDecoder? I have some trouble using it: at inference it always returns the same token. I would really appreciate it if you could help! Thank you!