Hi,
I am training a transformer model to translate English into German.
The loss is going down per epoch.
But when I run greedy decoding, I always get an empty string back:
# x is the source sentence indices, batchsize x sequence_len
x = x.transpose(0,1)
encoder_emb = self.encoder_emb(x)
encoder_emb_pos = self.pos_encoder(encoder_emb) #.transpose(0, 1)
src_mask, src_padding_mask = self.create_mask_inference(x)
encoder_output = self.encoder(encoder_emb_pos, mask=src_mask, src_key_padding_mask=src_padding_mask)
output = torch.ones(1, x.size()[1]).fill_(self.vocab.bos_id()).long().cuda()
for t in range(1, self.output_len):
tgt_emb = self.decoder_emb(output) # [:, :t]
tgt_emb_pos = self.pos_encoder(tgt_emb) # .transpose(0, 1)
tgt_mask = torch.nn.Transformer().generate_square_subsequent_mask(t).cuda()
decoder_output = self.decoder(tgt=tgt_emb_pos, memory=encoder_output, tgt_mask=tgt_mask)
pred_proba_t = self.predictor(decoder_output)[-1, :, :]
_, next_token = torch.max(pred_proba_t, dim=1)
output = torch.cat([output, torch.unsqueeze(next_token, 0)], dim=0)
return output
Is there something wrong with the decoding code in the block above? Thank you for your help.