How to use nn.TransformerDecoder() at inference time

Yes. You need to initialize the target with a tensor containing the SOS (start of sentence) token index. At each step, you append the predicted token to this tensor, as explained in the paper Attention Is All You Need:

At each step the model is auto-regressive [10], consuming the previously generated symbols as additional input when generating the next.

import torch
import torch.nn as nn

device = 'cpu'
d_model = 768
bs = 4
sos_idx = 0
vocab_size = 10000
input_len = 10
output_len = 12

# Define the model
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4).to(device)

encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                num_layers=6).to(device)

decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4).to(device)

decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                num_layers=6).to(device)

decoder_emb = nn.Embedding(vocab_size, d_model).to(device)

predictor = nn.Linear(d_model, vocab_size).to(device)

# for a single batch x
# note: by default nn.TransformerEncoder/Decoder expect (seq_len, bs, d_model)
x = torch.randn(input_len, bs, d_model).to(device)
encoder_output = encoder(x)  # (input_len, bs, d_model)

# initialize the decoder input with sos_idx (start of sentence token index)
output = torch.full((bs, output_len), sos_idx, dtype=torch.long, device=device)
for t in range(1, output_len):
    # embed the tokens generated so far: (bs, t) -> (t, bs, d_model)
    tgt_emb = decoder_emb(output[:, :t]).transpose(0, 1)
    # causal mask: -inf above the diagonal, so position i only attends to j <= i
    # (this is the same mask that nn.Transformer.generate_square_subsequent_mask
    # produces; note that it must NOT be transposed)
    tgt_mask = torch.triu(torch.full((t, t), float('-inf'), device=device),
                          diagonal=1)
    decoder_output = decoder(tgt=tgt_emb,
                             memory=encoder_output,
                             tgt_mask=tgt_mask)

    # greedy decoding: pick the most likely token at the last position
    pred_proba_t = predictor(decoder_output)[-1, :, :]  # (bs, vocab_size)
    output_t = pred_proba_t.argmax(dim=-1)
    output[:, t] = output_t

# output now holds the generated token indices: (bs, output_len)
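
In practice you would also run this loop under torch.no_grad() and stop early once every sequence in the batch has emitted an EOS (end of sentence) token. Here is a minimal sketch of that variant, reusing the modules defined above; eos_idx is a hypothetical token index, not something defined earlier:

eos_idx = 1  # hypothetical end-of-sentence token index

with torch.no_grad():
    output = torch.full((bs, output_len), sos_idx, dtype=torch.long, device=device)
    finished = torch.zeros(bs, dtype=torch.bool, device=device)
    for t in range(1, output_len):
        tgt_emb = decoder_emb(output[:, :t]).transpose(0, 1)
        tgt_mask = torch.triu(torch.full((t, t), float('-inf'), device=device),
                              diagonal=1)
        decoder_output = decoder(tgt=tgt_emb,
                                 memory=encoder_output,
                                 tgt_mask=tgt_mask)
        output_t = predictor(decoder_output)[-1, :, :].argmax(dim=-1)
        # keep already-finished sequences padded with eos_idx
        output[:, t] = torch.where(finished,
                                   torch.full_like(output_t, eos_idx),
                                   output_t)
        finished |= output_t == eos_idx
        if finished.all():
            break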

My final model looks like this:

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        # instantiate self.encoder, self.decoder, self.decoder_emb,
        # self.predictor and set self.device, as in the snippet above
    
    def forward(self, x, y=None, mode='train'):
        encoder_output = self.encoder(x)

        if mode == 'train':
            # teacher forcing: feed the full target sequence at once
            tgt_emb = self.decoder_emb(y).transpose(0, 1)  # (tgt_len, bs, d_model)
            tgt_len = tgt_emb.size(0)
            # causal mask, as in the decoding loop above (again, no transpose)
            tgt_mask = torch.triu(
                torch.full((tgt_len, tgt_len), float('-inf'), device=self.device),
                diagonal=1)
            decoder_output = self.decoder(tgt=tgt_emb,
                                          tgt_mask=tgt_mask,
                                          memory=encoder_output)
            # (bs, vocab_size, tgt_len), ready for nn.CrossEntropyLoss
            return self.predictor(decoder_output).permute(1, 2, 0)
        elif mode == 'generate':
            # the solution I gave before (sketched below)
            return
        
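For completeness, here is a sketch of what the 'generate' branch could look like, reusing the greedy loop from above. It assumes __init__ also stores self.output_len and self.sos_idx (hypothetical attributes, not shown in the skeleton):

        elif mode == 'generate':
            # greedy decoding, exactly as in the standalone loop above
            output = torch.full((x.size(1), self.output_len), self.sos_idx,
                                dtype=torch.long, device=self.device)
            for t in range(1, self.output_len):
                tgt_emb = self.decoder_emb(output[:, :t]).transpose(0, 1)
                tgt_mask = torch.triu(
                    torch.full((t, t), float('-inf'), device=self.device),
                    diagonal=1)
                decoder_output = self.decoder(tgt=tgt_emb,
                                              memory=encoder_output,
                                              tgt_mask=tgt_mask)
                output[:, t] = self.predictor(decoder_output)[-1, :, :].argmax(dim=-1)
            return output  # (bs, output_len)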
