nn.MultiheadAttention with triangular mask

Hi PyTorch family :grinning: :heart:

I am currently trying to implement a traditional transformer decoder trained on a next-token prediction task, using a single attention mask that encodes both the padding mask and the triangular (causal) mask at once.

According to the documentation for nn.MultiheadAttention, the attn_mask parameter can be a boolean tensor in which True means that attention between the two tokens is forbidden.
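To make that convention concrete, here is a minimal sketch of a purely causal boolean mask, with no padding yet. The name `seq_len` and its value of 6 are just assumptions chosen to match my example below.

```python
import torch

seq_len = 6  # assumed sequence length, matching the example in this post

# True strictly above the diagonal: token i is forbidden from attending
# to any later token j > i. False elsewhere means attention is allowed.
causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
print(causal_mask)
```

A 2D mask like this one is shared by every sample in the batch, which is why I went for a 3D mask to add per-sample padding.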

Hence, for a batch of embedded tokens with 2 samples of 6 tokens each, where the last token of the second sample is padding, I built the following mask (1 = attention allowed, 0 = forbidden):

mask = torch.tensor([[[1, 0, 0, 0, 0, 0],
                      [1, 1, 0, 0, 0, 0],
                      [1, 1, 1, 0, 0, 0],
                      [1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1, 0],
                      [1, 1, 1, 1, 1, 1]],

                     [[1, 0, 0, 0, 0, 0],
                      [1, 1, 0, 0, 0, 0],
                      [1, 1, 1, 0, 0, 0],
                      [1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1, 0],
                      [0, 0, 0, 0, 0, 0]]])

But it gives me a NaN output embedding for the last token of the second sample (the padded token), with the following code:

import torch
import torch.nn as nn

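# Shape (seq_len, batch, embed_dim), since batch_first defaults to False
embedded_tokens = torch.rand(2, 6, 512).transpose(0, 1)
decoder = nn.MultiheadAttention(512, 1)

# mask uses 1 = allowed, so invert it: attn_mask expects True = forbidden
attn_mask = ~mask.bool()
# The forward call returns a tuple (attn_output, attn_weights)
updated_tokens, attn_weights = decoder(embedded_tokens, embedded_tokens, embedded_tokens, attn_mask=attn_mask)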
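My guess at where the NaN comes from, as a minimal sketch: I am assuming that forbidden positions are filled with -inf before the softmax over each row of attention scores. The score values below are made up for illustration.

```python
import torch

scores = torch.tensor([0.3, -1.2, 0.8])  # made-up attention scores for one query

# One key left unmasked: the softmax is well defined.
partly_masked = scores.masked_fill(torch.tensor([True, True, False]), float("-inf"))
partly_probs = torch.softmax(partly_masked, dim=-1)
print(partly_probs)  # tensor([0., 0., 1.])

# Every key masked: there is nothing left to normalise over,
# so the softmax returns NaN for the whole row.
fully_masked = scores.masked_fill(torch.tensor([True, True, True]), float("-inf"))
fully_probs = torch.softmax(fully_masked, dim=-1)
print(fully_probs)  # tensor([nan, nan, nan])
```

If that assumption holds, the all-zero last row of my second sample is exactly the fully masked case.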

However, if I allow the padded token to attend to itself, such that:

mask = torch.tensor([[[1, 0, 0, 0, 0, 0],
                      [1, 1, 0, 0, 0, 0],
                      [1, 1, 1, 0, 0, 0],
                      [1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1, 0],
                      [1, 1, 1, 1, 1, 1]],

                     [[1, 0, 0, 0, 0, 0],
                      [1, 1, 0, 0, 0, 0],
                      [1, 1, 1, 0, 0, 0],
                      [1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1, 0],
                      [0, 0, 0, 0, 0, 1]]])

then it works correctly and avoids introducing NaN.
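For reference, this is roughly how I would build that second mask for arbitrary sequence lengths. It is only a sketch: `build_allowed_mask` and `lengths` are hypothetical names of mine, not part of PyTorch, and I am assuming right-padding.

```python
import torch

def build_allowed_mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Hypothetical helper: 1 where attention is allowed, 0 where forbidden."""
    causal = torch.ones(seq_len, seq_len).tril().bool()            # (L, L)
    is_real = torch.arange(seq_len)[None, :] < lengths[:, None]    # (N, L)
    # A query/key pair is allowed only if it is causal and both tokens are real.
    allowed = causal[None] & is_real[:, :, None] & is_real[:, None, :]
    # Let every token attend to itself so that no row is fully masked.
    allowed = allowed | torch.eye(seq_len).bool()[None]
    return allowed.long()                                          # (N, L, L)

# Two samples of 6 positions; the second has 5 real tokens and 1 padded token.
mask = build_allowed_mask(torch.tensor([6, 5]), seq_len=6)
attn_mask = ~mask.bool()  # True = forbidden, as attn_mask expects
print(mask)
```

With a single head the (N, L, L) result can be passed directly as attn_mask. With more heads, the documentation asks for a 3D mask of shape (N * num_heads, L, L), so it would have to be expanded per head.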

Could you please tell me whether my implementation is correct? If not, what would be the best way to implement it?

Thank you all so much for your help. I cannot wait to read your answers and discuss this with passionate PyTorch developers! :heart: :grinning: :smiling_face_with_three_hearts: :heart_eyes: