Yes. You need to initialize the target with a tensor containing the SOS (start-of-sentence) token. At each step, you append the new prediction to this tensor, as explained in the paper Attention Is All You Need:
At each step the model is auto-regressive [10], consuming the previously generated symbols as additional input when generating the next.
import torch
import torch.nn as nn
# --- Hyper-parameters for the demo ---------------------------------------
device = 'cpu'
d_model = 768
bs = 4
sos_idx = 0        # start-of-sentence token index
vocab_size = 10000
input_len = 10
output_len = 12

# --- Define the model -----------------------------------------------------
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4).to(device)
encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                num_layers=6).to(device)
decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=4).to(device)
decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                num_layers=6).to(device)
# BUGFIX: the embedding and the output projection must live on the same
# device as the rest of the model.
decoder_emb = nn.Embedding(vocab_size, d_model).to(device)
predictor = nn.Linear(vocab_size * 0 + d_model, vocab_size).to(device)

# A single batch x of source features.
x = torch.randn(bs, input_len, d_model).to(device)
# BUGFIX: nn.TransformerEncoder expects (seq, batch, d_model) by default
# (batch_first=False), so transpose before encoding; otherwise the memory
# batch dimension does not match the target's in cross-attention.
encoder_output = encoder(x.transpose(0, 1))  # (input_len, bs, d_model)

# Initialize the decoder input with sos_idx (start-of-sentence token index).
output = torch.ones(bs, output_len).long().to(device) * sos_idx

# Greedy autoregressive decoding: at step t, feed the t tokens generated so
# far and keep the argmax of the last position as token t.
with torch.no_grad():  # pure inference — no autograd graph needed
    for t in range(1, output_len):
        tgt_emb = decoder_emb(output[:, :t]).transpose(0, 1)  # (t, bs, d_model)
        # BUGFIX: the causal mask must NOT be transposed. Row i allows
        # attention only to positions j <= i (zeros on/below the diagonal,
        # -inf above); transposing it would let the model see the future.
        # Built with triu directly instead of constructing a throwaway
        # nn.Transformer() on every iteration.
        tgt_mask = torch.triu(
            torch.full((t, t), float('-inf'), device=device), diagonal=1)
        decoder_output = decoder(tgt=tgt_emb,
                                 memory=encoder_output,
                                 tgt_mask=tgt_mask)
        # Logits of the last generated position: (bs, vocab_size).
        pred_proba_t = predictor(decoder_output)[-1, :, :]
        output_t = pred_proba_t.argmax(dim=-1)  # greedy pick, shape (bs,)
        output[:, t] = output_t
# output: (bs, output_len) token indices; column 0 is the SOS token.
My final model looks like this (training uses teacher forcing; generation reuses the greedy loop above):
class Transformer(nn.Module):
    """Seq2seq Transformer over pre-computed source features.

    'train' mode applies teacher forcing on the given target tokens;
    'generate' mode decodes greedily from the SOS token.
    """

    def __init__(self, d_model=768, nhead=4, num_layers=6,
                 vocab_size=10000, sos_idx=0, device='cpu'):
        # BUGFIX: nn.Module subclasses must call super().__init__() before
        # registering sub-modules; the original `pass` left every attribute
        # used in forward() undefined.
        super().__init__()
        self.sos_idx = sos_idx
        self.device = device
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             num_layers=num_layers)
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer,
                                             num_layers=num_layers)
        self.decoder_emb = nn.Embedding(vocab_size, d_model)
        self.predictor = nn.Linear(d_model, vocab_size)
        self.to(device)

    def _causal_mask(self, size):
        """Float mask with -inf strictly above the diagonal: position i may
        only attend to positions j <= i."""
        return torch.triu(
            torch.full((size, size), float('-inf'), device=self.device),
            diagonal=1)

    def forward(self, x, y=None, mode='train', max_len=12):
        """
        x: (bs, input_len, d_model) source features.
        y: (bs, tgt_len) target token indices — required in 'train' mode.
        mode: 'train' (teacher forcing) or 'generate' (greedy decoding).
        max_len: length of the generated sequence in 'generate' mode.

        Returns (bs, vocab_size, tgt_len) logits in 'train' mode (ready for
        nn.CrossEntropyLoss), or (bs, max_len) token indices in 'generate'
        mode. Raises ValueError on a missing target or an unknown mode.
        """
        # BUGFIX: nn.TransformerEncoder expects (seq, batch, d_model) by
        # default, so the batch-first input must be transposed.
        encoder_output = self.encoder(x.transpose(0, 1))
        if mode == 'train':
            if y is None:
                raise ValueError("y is required in 'train' mode")
            tgt_emb = self.decoder_emb(y).transpose(0, 1)  # (tgt_len, bs, d)
            # BUGFIX: do not transpose the causal mask — that would invert
            # causality and let each position attend to the future.
            tgt_mask = self._causal_mask(tgt_emb.size(0))
            decoder_output = self.decoder(tgt=tgt_emb,
                                          tgt_mask=tgt_mask,
                                          memory=encoder_output)
            return self.predictor(decoder_output).permute(1, 2, 0)
        elif mode == 'generate':
            bs = x.size(0)
            # Start every sequence with the SOS token.
            output = torch.full((bs, max_len), self.sos_idx,
                                dtype=torch.long, device=self.device)
            with torch.no_grad():  # greedy decoding needs no autograd graph
                for t in range(1, max_len):
                    tgt_emb = self.decoder_emb(output[:, :t]).transpose(0, 1)
                    tgt_mask = self._causal_mask(t)
                    decoder_output = self.decoder(tgt=tgt_emb,
                                                  memory=encoder_output,
                                                  tgt_mask=tgt_mask)
                    # Logits of the last position: (bs, vocab_size).
                    logits_t = self.predictor(decoder_output)[-1, :, :]
                    output[:, t] = logits_t.argmax(dim=-1)
            return output
        raise ValueError(f"unknown mode: {mode!r}")