nn.TransformerDecoder() - Outputs are empty

Hey all,

I’ve been struggling to get nn.TransformerDecoder() to work as intended. One of two things happens, and in both cases my output is empty.

My target data is as follows:

[['this', 'is', 'an', 'example', '[EOS]', '[PAD]'],
['another', 'target', 'sentence', '[EOS]', '[PAD]', '[PAD]']]

My encoders work fine, but the decoder seems to be playing up. Inside the forward pass, after encoding my source, I prepend an [SOS] token to the target (I’m doing Variational Inference, and this is a strategy that one of the papers uses):

target_shifted = torch.cat((sos_token, target[:, :-1]), 1)
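A minimal, self-contained sketch of the shift above, assuming hypothetical token ids (0 = [PAD], 1 = [SOS], 2 = [EOS]) and a hypothetical helper shift_right. Decoder input position i holds target token i-1, so under a causal mask (which leaves the diagonal visible) position i never sees the token it has to predict. Without the shift, position i can see target token i itself and simply copy it, which would be consistent with the loss dropping to 0.

```python
import torch

# Hypothetical vocabulary ids, for illustration only
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2

def shift_right(target: torch.Tensor, sos_idx: int) -> torch.Tensor:
    """Prepend [SOS] and drop the last token, keeping the sequence length unchanged."""
    batch_size = target.size(0)
    sos = torch.full((batch_size, 1), sos_idx, dtype=target.dtype, device=target.device)
    return torch.cat((sos, target[:, :-1]), dim=1)

# Two padded target sequences, [N, S]
target = torch.tensor([[5, 6, 7, 8, EOS_IDX, PAD_IDX],
                       [9, 10, 11, EOS_IDX, PAD_IDX, PAD_IDX]])
shifted = shift_right(target, SOS_IDX)
print(shifted.tolist())  # → [[1, 5, 6, 7, 8, 2], [1, 9, 10, 11, 2, 0]]
```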

This is followed by generating some target masks:

trg_key_padding_mask = self.generate_pad_mask(target_shifted)
trg_mask = self.generate_square_subsequent_mask(target_shifted.size(-1))
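The bodies of generate_pad_mask and generate_square_subsequent_mask aren’t shown, so this is an assumption about what they should return for PyTorch. In nn.TransformerDecoder, a boolean mask uses True to mean "may NOT attend" (a float mask is instead added to the attention scores, e.g. -inf for blocked positions). Many hand-written MHA implementations use the opposite convention (1/True = keep). If the helpers follow that convention, the real tokens get masked instead of the padding, and a query row whose keys are all masked can produce NaNs. The same applies to memory_key_padding_mask=src_mask. A sketch of both helpers in PyTorch’s convention, with a hypothetical PAD_IDX:

```python
import torch

PAD_IDX = 0  # hypothetical padding id

def generate_pad_mask(tokens: torch.Tensor, pad_idx: int = PAD_IDX) -> torch.Tensor:
    # [N, S] bool: True marks padding positions the decoder should IGNORE
    return tokens == pad_idx

def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    # [S, S] bool: True strictly above the diagonal blocks attention to future positions
    return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)

tokens = torch.tensor([[1, 5, 6, 2, 0],
                       [1, 9, 2, 0, 0]])
print(generate_pad_mask(tokens).tolist())
# → [[False, False, False, False, True], [False, False, False, True, True]]
print(generate_square_subsequent_mask(3).int().tolist())
# → [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
```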

I permute the embeddings so they’re [S, N, E], add my latent information to the SOS token, and then run the data through the decoder:

target_embedding = self.get_embedding(target_shifted).permute(1, 0, 2) # [S, N, E]
target_embedding[0] = target_embedding[0] + z # z = latent information
decoder_outputs = self.transformer_decoder(target_embedding, encoder_outputs, tgt_mask=trg_mask, tgt_key_padding_mask=trg_key_padding_mask, memory_key_padding_mask=src_mask)
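A self-contained sketch of the decoder call in the default [S, N, E] layout, using random tensors in place of the real embeddings, encoder outputs and latent z (all sizes are hypothetical). It mainly shows the shape each argument expects: tgt_mask is [S, S], tgt_key_padding_mask is [N, S], and memory_key_padding_mask is [N, S_src]. Adding z out of place with torch.cat is a stylistic choice here, not a known fix.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only
S, N, E, SRC_LEN, NHEAD = 6, 2, 16, 7, 4

layer = nn.TransformerDecoderLayer(d_model=E, nhead=NHEAD)  # default batch_first=False: [S, N, E]
decoder = nn.TransformerDecoder(layer, num_layers=2)

target_embedding = torch.randn(S, N, E)       # stands in for get_embedding(...).permute(1, 0, 2)
encoder_outputs = torch.randn(SRC_LEN, N, E)  # memory, [S_src, N, E]
z = torch.randn(N, E)                         # hypothetical latent: one vector per batch item

# Add z to the [SOS] position (index 0 along the sequence dimension)
target_embedding = torch.cat((target_embedding[:1] + z.unsqueeze(0), target_embedding[1:]), dim=0)

tgt_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)  # True = blocked
tgt_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
tgt_key_padding_mask[1, 4:] = True                                     # item 1 ends in two [PAD]s
memory_key_padding_mask = torch.zeros(N, SRC_LEN, dtype=torch.bool)    # no source padding here

decoder.eval()
with torch.no_grad():
    out = decoder(target_embedding, encoder_outputs, tgt_mask=tgt_mask,
                  tgt_key_padding_mask=tgt_key_padding_mask,
                  memory_key_padding_mask=memory_key_padding_mask)
print(out.shape)                      # torch.Size([6, 2, 16])
print(torch.isnan(out).any().item())  # False
```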

I then pass these outputs to a linear layer to compute the logits for the loss, etc.:

decoder_outputs = decoder_outputs.permute(1, 0, 2) # [N, S, E]
output = self.output(decoder_outputs)
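Since PyTorch 1.9, nn.TransformerDecoderLayer also accepts batch_first=True, which keeps everything as [N, S, E] and removes both permutes. A sketch with hypothetical sizes, where output_proj stands in for self.output:

```python
import torch
import torch.nn as nn

# Hypothetical sizes; batch_first needs PyTorch >= 1.9
N, S, SRC_LEN, E, VOCAB = 2, 6, 7, 16, 50

layer = nn.TransformerDecoderLayer(d_model=E, nhead=4, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2)
output_proj = nn.Linear(E, VOCAB)  # plays the role of self.output

tgt = torch.randn(N, S, E)           # target embeddings, already [N, S, E]
memory = torch.randn(N, SRC_LEN, E)  # encoder outputs, [N, S_src, E]
tgt_mask = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)

hidden = decoder(tgt, memory, tgt_mask=tgt_mask)  # [N, S, E], no permute needed
logits = output_proj(hidden)                      # [N, S, VOCAB]
print(logits.shape)  # torch.Size([2, 6, 50])
```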

The loss is computed as follows:

output = model(batch)
target = batch["target"]
loss = self.criterion(output.reshape(-1, output.size(-1)), target.reshape(-1)) # CrossEntropyLoss
  • With target shifting (i.e. target_shifted), my loss plateaus around a non-zero value and I get no output at all (not even a repeated token).
  • If I DON’T perform the target shifting, my loss goes to 0, but I still get no output.
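How self.criterion is constructed isn’t shown. If CrossEntropyLoss is built without ignore_index, the [PAD] positions count toward the loss just like real tokens. Comparing the logits against the unshifted target while feeding target_shifted to the decoder is the right alignment. A sketch with a hypothetical PAD_IDX and vocabulary size, checking that ignore_index averages only over non-pad positions:

```python
import torch
import torch.nn as nn

PAD_IDX, VOCAB = 0, 50  # hypothetical values

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

N, S = 2, 6
logits = torch.randn(N, S, VOCAB)  # stands in for model(batch), [N, S, V]
target = torch.tensor([[5, 6, 7, 8, 2, PAD_IDX],
                       [9, 10, 11, 2, PAD_IDX, PAD_IDX]])

loss = criterion(logits.reshape(-1, VOCAB), target.reshape(-1))

# Same value by hand: mean cross-entropy over the non-pad positions only
keep = target.reshape(-1) != PAD_IDX
manual = nn.functional.cross_entropy(logits.reshape(-1, VOCAB)[keep], target.reshape(-1)[keep])
print(torch.allclose(loss, manual))  # True
```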

Annoyingly, the code works as expected with my own implementations of MHA and a Transformer decoder; the issue only occurs when I use PyTorch’s nn.TransformerDecoder().