I don't understand how to use the transformer decoder layer provided in PyTorch 1.2 for autoregressive decoding and beam search.
With an LSTM I don't have to worry about masking, but with a transformer the whole target sequence is fed in at once, so I really need to get the masking right. The masking in the code below is clearly wrong, yet I don't get any shape errors; the code just runs.
The code below just leads to perfect perplexity with the transformer decoder (presumably because the decoder can attend to the tokens it is supposed to predict).
```python
mask = torch.zeros(max_length)
for t in range(1, max_length):
    mask[t-1] = 1.0
    # Call Decoder
    output = self.decoder(input_tensor, src_encodings, mask)
    # Get only the first output
    output = output[t, :, :]
    # Do Softmax
    output_dist = F.softmax(output.squeeze(0), dim=1)
    # Get mle score
    mle_scores += self.criterion(output.squeeze(0), input_tensor[t]) * (input_lengths > t).float()
```
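For reference, my understanding is that the decoder's tgt_mask should be a square "subsequent" mask (0.0 where attention is allowed, -inf on future positions), like the one nn.Transformer.generate_square_subsequent_mask produces. Here is a sketch of what I mean (the helper name subsequent_mask is mine):

```python
import torch

def subsequent_mask(size):
    # Square float mask: 0.0 on and below the diagonal (attention allowed),
    # -inf strictly above the diagonal (future positions blocked).
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))

# Row t can only attend to positions <= t
print(subsequent_mask(4))
```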
What should I change here to do the masking correctly? Please note that the decoder is just the transformer decoder layer (https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html) plus a linear softmax over the vocabulary.
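In case it helps clarify the question, this is roughly how I believe the layer is supposed to be called, going by the nn.TransformerDecoderLayer signature (the shapes here are made up for illustration):

```python
import torch
import torch.nn as nn

d_model, nhead = 16, 4
tgt_len, src_len, batch = 5, 7, 2

layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)

tgt = torch.rand(tgt_len, batch, d_model)     # target embeddings (tgt_len, batch, d_model)
memory = torch.rand(src_len, batch, d_model)  # encoder outputs (src_len, batch, d_model)

# Float mask with -inf above the diagonal, blocking attention to future positions
tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)

out = layer(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([5, 2, 16])
```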
Any help would be greatly appreciated.