How to implement an attention mask / unidirectional attention in TransformerDecoder

Hi guys,

I’m learning about nn.Transformer in PyTorch these days and I’m a bit confused about how the attention mask is implemented in the decoder. Say we’re doing a machine translation task with a Transformer: at inference time, the output at each time step can only “see” the tokens before it. However, at training time we simply feed the whole correct sequence to the decoder, so a token in the decoded sequence can see the tokens both before and after it, which I guess is not suitable for a robust model.

I know that GPT adopts unidirectional attention when decoding (and I guess that’s appropriate for a decoder), but I’m wondering whether the APIs in nn.Transformer can provide such a feature (e.g. using attention masks to prevent each token from attending to later tokens)? If so, what should I do?

Any help would be appreciated, thanks in advance!

You should set up an attention mask for the decoder so that each token is masked from seeing the tokens after it.

Thanks for the reply! Could you please offer more concrete hints, e.g. how to set up the attention masks, some sample code, or a link that covers these? Many thanks!
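
Sure. Here is a minimal sketch (the model sizes, layer counts, and dummy tensors below are just placeholders, not from any real setup) showing how to build a causal mask with torch.triu and pass it as tgt_mask to nn.TransformerDecoder. The built-in helper nn.Transformer.generate_square_subsequent_mask produces the same kind of mask, so you can use that instead of the hand-rolled function.

```python
import torch
import torch.nn as nn

def generate_causal_mask(sz: int) -> torch.Tensor:
    # Additive mask: positions that must NOT be attended to get -inf,
    # allowed positions get 0. Row i can attend to columns 0..i only,
    # i.e. no token can look at later ("future") tokens.
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

# Placeholder model and data just to show where the mask goes.
d_model, nhead = 512, 8
decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

tgt_len, src_len, batch = 10, 12, 4
tgt = torch.rand(tgt_len, batch, d_model)     # decoder input (shifted target)
memory = torch.rand(src_len, batch, d_model)  # encoder output

tgt_mask = generate_causal_mask(tgt_len)      # shape (tgt_len, tgt_len)
out = decoder(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([10, 4, 512])
```

With tgt_mask passed this way, the self-attention inside every decoder layer is unidirectional during training as well, so the training setup matches what happens at inference time when you decode token by token.
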

Many thanks for the code, it’s much clearer now!