I am trying to understand how masking works with scaled dot-product attention; I'm using the implementation in torch.nn.functional.scaled_dot_product_attention.
To test how the masks work, I create three tensors to simulate the queries, keys, and values. In this setting we have a batch size of 3, a sequence length of 7, and a d_model of 5. I want to stick with a very basic setting so I can understand the output. I create the attn_mask as it is done in the documentation, but I also want to simulate the case where the first and last sentences have two padding tokens (at the end) and the middle sentence has only one padding token (also at the end).
import torch

random_q = torch.rand(3, 7, 5)
random_k = torch.rand(3, 7, 5)
random_v = torch.rand(3, 7, 5)

# causal mask of shape (batch, L, S); True = attend, False = mask out
attn_mask = torch.ones(3, 7, 7, dtype=torch.bool).tril(diagonal=0)
# mask the padded positions at the end of each sequence
attn_mask[0, -2:] = False
attn_mask[1, -1:] = False
attn_mask[2, -2:] = False

torch.nn.functional.scaled_dot_product_attention(random_q, random_k, random_v, attn_mask)
We notice that when we apply scaled dot-product attention, the rows associated with the padding tokens are all NaN. That makes sense: the mask sets the attention scores (the logits, before the softmax) for those tokens to -infinity, so the softmax is taken over a row that is entirely -infinity.
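To double-check where the NaNs come from, I also tried two things: confirming that a softmax over all -infinity logits is NaN, and masking the padded positions along the key dimension (columns) instead of zeroing out whole query rows. The column-masking layout below is just my assumption of the usual approach, not something from the docs, but with it every query row still has at least one attendable position, so the output stays finite:

```python
import torch
import torch.nn.functional as F

# A softmax over a row of all -inf logits produces NaN,
# which matches the NaN rows seen above.
row_logits = torch.full((7,), float("-inf"))
print(torch.softmax(row_logits, dim=-1))  # all NaN

q = torch.rand(3, 7, 5)
k = torch.rand(3, 7, 5)
v = torch.rand(3, 7, 5)

# Same causal mask, but padding is masked along the *key* axis
# (last dimension), so padded positions can't be attended to,
# while every query row still sees at least position 0.
attn_mask = torch.ones(3, 7, 7, dtype=torch.bool).tril(diagonal=0)
attn_mask[0, :, -2:] = False  # sentence 0: last two tokens are padding
attn_mask[1, :, -1:] = False  # sentence 1: last token is padding
attn_mask[2, :, -2:] = False  # sentence 2: last two tokens are padding

out = F.scaled_dot_product_attention(q, k, v, attn_mask)
print(torch.isnan(out).any())  # tensor(False)
```

With this layout the rows for the padding tokens still produce (meaningless) finite values rather than NaN, which is why I'm wondering how the Transformer module itself treats them.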
I wanted to know how these rows are dealt with later on, for example within the Transformer module. I couldn't figure that out from the source code, so I'm reaching out to the experts here.
Thank you in advance!