I’ve been struggling to get nn.TransformerDecoder() to work as intended - one of two things happens, and both leave my output empty.
My target data is as follows:
```python
[['this', 'is', 'an', 'example', '[EOS]', '[PAD]'], ['another', 'target', 'sentence', '[EOS]', '[PAD]', '[PAD]']]
```
My encoders work fine, but the decoder seems to be playing up slightly. Inside the forward pass, after encoding my source, I prepend an [SOS] token to the target (I’m doing Variational Inference and this is a strategy one of the papers uses):
```python
target_shifted = torch.cat((sos_token, target[:, :-1]), 1)
```
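For anyone following along, here is a minimal, self-contained sketch of what that shift does. The vocab ids (1 = [SOS], 2 = [EOS], 0 = [PAD]) and the token values are made up purely for illustration:

```python
import torch

# Assumed ids for illustration: 1 = [SOS], 2 = [EOS], 0 = [PAD]
target = torch.tensor([[5, 6, 7, 2, 0],
                       [8, 9, 2, 0, 0]])           # [N, T]
sos_token = torch.full((target.size(0), 1), 1)     # [N, 1] column of [SOS]

# Drop the last position and prepend [SOS]: decoder input at step t
# is the ground-truth token from step t-1 (teacher forcing).
target_shifted = torch.cat((sos_token, target[:, :-1]), 1)
# target_shifted == [[1, 5, 6, 7, 2],
#                    [1, 8, 9, 2, 0]]
```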
This is followed by generating some target masks:
```python
trg_key_padding_mask = self.generate_pad_mask(target_shifted)
trg_mask = self.generate_square_subsequent_mask(target_shifted.size(-1))
```
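The two mask helpers aren’t shown above; this is a sketch of what they’re assumed to do, matching what nn.TransformerDecoder expects (a boolean [N, T] key-padding mask with True at pad positions, and a float [T, T] causal mask with -inf above the diagonal). The PAD_ID value is an assumption:

```python
import torch

PAD_ID = 0  # assumed pad index

def generate_pad_mask(seq):
    # True where the token is padding; shape [N, T]
    return seq == PAD_ID

def generate_square_subsequent_mask(sz):
    # -inf above the diagonal (future positions), 0 elsewhere; shape [T, T]
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

seq = torch.tensor([[1, 5, 6, 0]])
pad_mask = generate_pad_mask(seq)        # [[False, False, False, True]]
causal = generate_square_subsequent_mask(4)
```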
I permute the embeddings to [S, N, E], add my latent information to the [SOS] token, and then run the data through the decoder:
```python
target_embedding = self.get_embedding(target_shifted).permute(1, 0, 2)  # [S, N, E]
target_embedding = target_embedding + z  # z = latent information
decoder_outputs = self.transformer_decoder(
    target_embedding,
    encoder_outputs,
    tgt_mask=trg_mask,
    tgt_key_padding_mask=trg_key_padding_mask,
    memory_key_padding_mask=src_mask,
)
```
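As a sanity check on the shape conventions: the decoder wants the target as [T, N, E], the encoder memory as [S, N, E], and the key-padding masks batch-first ([N, T] and [N, S]). A toy run with made-up dimensions (T, S, N, E are arbitrary here, not from my model):

```python
import torch
import torch.nn as nn

# Toy shapes, just for a shape check: T = target len, S = source len
T, S, N, E = 5, 7, 2, 16
layer = nn.TransformerDecoderLayer(d_model=E, nhead=4)
decoder = nn.TransformerDecoder(layer, num_layers=2)

tgt = torch.randn(T, N, E)      # target embeddings, [T, N, E]
memory = torch.randn(S, N, E)   # encoder outputs,   [S, N, E]
tgt_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
tgt_kpm = torch.zeros(N, T, dtype=torch.bool)   # [N, T], True = ignore
mem_kpm = torch.zeros(N, S, dtype=torch.bool)   # [N, S]

out = decoder(tgt, memory, tgt_mask=tgt_mask,
              tgt_key_padding_mask=tgt_kpm,
              memory_key_padding_mask=mem_kpm)
print(out.shape)  # torch.Size([5, 2, 16])
```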
I then pass these outputs through a linear layer before computing the loss:
```python
decoder_outputs = decoder_outputs.permute(1, 0, 2)  # [N, S, E]
output = self.output(decoder_outputs)
```
Loss is as follows:
```python
output = model(batch)
target = batch["target"]
loss = self.criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))  # CrossEntropyLoss
```
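One thing worth noting: flattening like this also scores the [PAD] positions unless the criterion is told to skip them. A minimal sketch of masking them out with `ignore_index` (PAD_ID and all tensor values are made up for illustration; I’m not claiming this is the bug, just that it’s worth ruling out):

```python
import torch
import torch.nn as nn

PAD_ID = 0  # assumed pad index
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# Fake logits/targets: N=2, T=3, vocab=5
output = torch.randn(2, 3, 5)
target = torch.tensor([[3, 2, PAD_ID],
                       [4, PAD_ID, PAD_ID]])

loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
# Equivalent to averaging cross-entropy over the non-pad positions only.
```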
- With target shifting (i.e. using `target_shifted`), my loss plateaus at a non-zero value and I get no output (not even a repeated token) back.
- If I DON’T perform the target shifting, my loss goes to 0 and I still get no output back.
Annoyingly, the code works as expected with my own implementation of MHA and a Transformer decoder - the issue only appears when using PyTorch’s nn.TransformerDecoder().