Implementing Transformer Decoder for Machine Translation

Hi,

I don't understand how to use the transformer decoder layer provided in PyTorch 1.2 for autoregressive decoding and beam search.

With an LSTM I don't have to worry about masking, but a transformer takes the whole target sequence at once, so I really need to make sure the masking is correct. The masking in the code below is clearly wrong, yet I get no shape errors; the code just runs, but…

With the transformer decoder, the code below gives a perfect perplexity.

        mask = torch.zeros(max_length)

        for t in range(1, max_length):
            mask[t-1] = 1.0
            # Call Decoder
            output = self.decoder(input_tensor, src_encodings, mask)
            # Get the output at position t
            output = output[t, :, :]
            # Do Softmax
            output_dist = F.softmax(output.squeeze(0), dim=1)
            # Get mle score
            mle_scores += self.criterion(output.squeeze(0), input_tensor[t]) * (input_lengths > t).float()
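For comparison, here is a minimal sketch of how teacher-forced training with a causal mask might look. It assumes the seq-first (length, batch, feature) layout that nn.TransformerDecoder uses by default; the sizes, pad_idx and the helper name causal_mask are hypothetical, and positional encodings are omitted for brevity. Two things differ from the loop above. First, the whole target goes through the decoder in one call with a square (T, T) tgt_mask. Second, the targets are shifted by one so that position t predicts token t+1. The loop above scores output[t] against input_tensor[t], which position t can always see (even a correct causal mask lets a position attend to itself), so the model can simply copy its input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def causal_mask(sz: int) -> torch.Tensor:
    # Float mask: 0.0 where attention is allowed, -inf above the diagonal,
    # so position i can only attend to positions 0..i.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

# Hypothetical sizes, for illustration only.
vocab_size, d_model, nhead, pad_idx = 100, 32, 4, 0
tgt_len, src_len, batch = 7, 9, 3

embed = nn.Embedding(vocab_size, d_model)  # positional encoding omitted
decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead), num_layers=2)
generator = nn.Linear(d_model, vocab_size)  # projection onto the vocabulary

# Seq-first layout: (length, batch) token ids and (length, batch, d_model) memory.
target = torch.randint(1, vocab_size, (tgt_len, batch))  # stand-in for BOS ... EOS
src_encodings = torch.randn(src_len, batch, d_model)     # encoder output ("memory")

# Teacher forcing: feed tokens 0..T-2 and predict tokens 1..T-1.
dec_input = target[:-1]
gold = target[1:]

hidden = decoder(embed(dec_input), src_encodings,
                 tgt_mask=causal_mask(dec_input.size(0)))
logits = generator(hidden)  # (T-1, batch, vocab)

# cross_entropy applies log-softmax internally, so no explicit softmax is needed.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), gold.reshape(-1),
                       ignore_index=pad_idx)
```

Doing it in one call also removes the need for a Python loop over time steps during training; the loop is only needed at inference time.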

What should I change here, and how do I do the masking correctly? Note that the decoder is just the transformer decoder layer (https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html) followed by a linear layer and a softmax over the vocabulary.

Any help would be greatly appreciated 🙂

You can pass the mask as tgt_mask in nn.TransformerDecoder.
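A quick way to check that a tgt_mask actually does its job is to perturb only the last target position and see whether the earlier outputs change. This sketch assumes the default seq-first layout and uses arbitrary sizes; dropout is disabled with eval() so the comparison is deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, nhead = 16, 4
decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead), num_layers=2)
decoder.eval()  # disable dropout so repeated calls are deterministic

T, S, N = 5, 6, 2
tgt = torch.randn(T, N, d_model)
memory = torch.randn(S, N, d_model)

# 0.0 on and below the diagonal, -inf above it.
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# Change only the last target position.
tgt_changed = tgt.clone()
tgt_changed[-1] += 1.0

with torch.no_grad():
    masked_a = decoder(tgt, memory, tgt_mask=tgt_mask)
    masked_b = decoder(tgt_changed, memory, tgt_mask=tgt_mask)
    unmasked_a = decoder(tgt, memory)
    unmasked_b = decoder(tgt_changed, memory)

# With the mask, positions 0..T-2 cannot see position T-1, so they are unchanged.
print(torch.allclose(masked_a[:-1], masked_b[:-1], atol=1e-6))      # True
# Without the mask, every position attends to the future and its output changes.
print(torch.allclose(unmasked_a[:-1], unmasked_b[:-1], atol=1e-6))  # False
```

The same check works as a sanity test for any hand-built mask.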

Thanks, I am already passing the mask (the transformer decoder is just a stack of the decoder layers I am using: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html).

I don't understand how to construct the mask.

Should I prepare a new mask for each decoding step?

No, all the decoder layers share the same mask.
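For inference, one way to handle this is a minimal greedy-decoding sketch, assuming the same seq-first layout: build the causal mask once for a hypothetical max_len and slice it to the current prefix length L at each step. nn.TransformerDecoder then applies that one slice in every layer. BOS, EOS, max_len and the model sizes are made-up placeholders, the model is untrained, and nothing is cached between steps, so each step re-runs the whole prefix. Beam search would follow the same pattern while keeping the k best prefixes instead of one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical vocabulary/model sizes and special-token ids.
vocab_size, d_model, nhead = 50, 16, 4
BOS, EOS, max_len = 1, 2, 10

embed = nn.Embedding(vocab_size, d_model)
decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(d_model, nhead), num_layers=2)
generator = nn.Linear(d_model, vocab_size)

# Causal mask built once for the longest allowed output; each step slices it.
full_mask = torch.triu(torch.full((max_len, max_len), float("-inf")), diagonal=1)

@torch.no_grad()
def greedy_decode(memory: torch.Tensor) -> torch.Tensor:
    """memory: (src_len, batch, d_model). Returns (out_len, batch) token ids."""
    decoder.eval()  # switch off dropout for inference
    batch = memory.size(1)
    ys = torch.full((1, batch), BOS, dtype=torch.long)  # every sequence starts with BOS
    finished = torch.zeros(batch, dtype=torch.bool)
    for _ in range(max_len - 1):
        L = ys.size(0)
        # One (L, L) slice of the mask, shared by all decoder layers.
        hidden = decoder(embed(ys), memory, tgt_mask=full_mask[:L, :L])
        # Only the newest position is needed to choose the next token.
        next_tok = generator(hidden[-1]).argmax(dim=-1)  # (batch,)
        # Sequences that already produced EOS keep emitting EOS.
        next_tok = next_tok.masked_fill(finished, EOS)
        ys = torch.cat([ys, next_tok.unsqueeze(0)], dim=0)
        finished |= next_tok == EOS
        if finished.all():
            break
    return ys

out = greedy_decode(torch.randn(7, 3, d_model))
print(out.shape)  # (out_len, 3), with out_len <= max_len
```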

Thanks, yes, but I am passing the wrong mask to a single decoder layer. How should I create the masks for autoregressive decoding?

Any update on this issue? If so, can you share an example?

You can use the built-in function:

        tgt_mask = torch.nn.Transformer().generate_square_subsequent_mask(length)

where length is the length of the target sequence (tgt).
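To see what that mask contains, here is a small hand-written equivalent (the name square_subsequent_mask is just for illustration). It is a float (length, length) matrix with 0.0 on and below the diagonal and -inf above it; the mask is added to the attention scores before the softmax, so future positions end up with zero attention weight.

```python
import torch

def square_subsequent_mask(sz: int) -> torch.Tensor:
    # 0.0 where query position i may attend to key position j (j <= i),
    # -inf where j > i, i.e. the "future" tokens are blocked.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

mask = square_subsequent_mask(4)
print(mask)
```

Building it by hand also avoids constructing a full default nn.Transformer just to get the mask, which the one-liner above does. In PyTorch releases where generate_square_subsequent_mask is a static method, nn.Transformer.generate_square_subsequent_mask(length) avoids that as well.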
