Understanding mask size in Transformer Example


I am trying to understand the Transformer architecture by following one of the PyTorch examples (Language Modeling with nn.Transformer and TorchText — PyTorch Tutorials 1.11.0+cu102 documentation).

However, I have trouble understanding the dimension/shape of the mask that is used to limit self-attention to the sequence elements before the “current” token.

In the example, the mask size is [batch_size, batch_size]. I would have thought it would be something like [sequence_length, sequence_length]. So for each position in the sequence, there is a separate mask that indicates what other tokens the self-attention mechanism can “access”.

Running the code, I see that the additive mask has the shape [BS, BS] with the content

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

using a BS of 4.

Can someone clarify this for me? Are different masks applied to the different batches? And more generally, if I have a BS of 4 and sequences of length 6, for example, is the model not supposed to learn during training the probability of token N+1 given all N previous tokens? So does each input token position have its own mask vector?

Thank you for the clarification.

Hi. Let me share my insights. Please correct me if I’m wrong. Long but hopefully useful post coming.

Let’s start with PyTorch’s TransformerEncoder. According to the docs, its signature is forward(src, mask=None, src_key_padding_mask=None). The docs also say that the mask has shape (S,S), where S is the source sequence length, and that it is additive.

Note: The (S,S) and additive parts are found in the Transformer class docs, not in the TransformerEncoder class docs. The same is true of the decoder docs.

So, in this case, the src_mask is a square (S,S) matrix that is added to the attention scores right after the Q*Kt step, before the softmax and the subsequent multiplication by the values (V). Adding zero leaves a score unchanged, so it has no effect on the softmax or on the multiplication by V. Adding -inf, however, drives the corresponding softmax weight to zero, whatever the V values are, so that position contributes nothing to the output. Why would we want to do this? The effect carries all the way to the top of the TransformerEncoder, effectively forcing it to predict using only past information (tokens). As an interesting note, this is exactly what the Transformer’s static method generate_square_subsequent_mask does: it generates a square matrix with zeros on and below the diagonal and -inf above it. Lastly, the src_key_padding_mask does something very similar, but for masking the padding values, which clearly have no semantic meaning. In other words: please do not pay attention (attend) to the padding that was appended to the right of the shorter sentences.
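To make the "add 0 or -inf before the softmax" idea concrete, here is a minimal sketch. The mask is built by hand with torch.triu (it has the same form as the tensor printed in the question), and the random scores are only a stand-in I made up for Q*Kt; nothing here is taken from the tutorial code.

```python
import torch

S = 4  # sequence length (illustrative value)

# Additive causal mask: 0 on and below the diagonal, -inf above it.
mask = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)

# Stand-in for the raw attention scores Q @ K^T (random values, an assumption).
torch.manual_seed(0)
scores = torch.randn(S, S)

# The mask is added BEFORE the softmax: exp(-inf) = 0, so every masked
# (future) position ends up with an attention weight of exactly zero.
weights = torch.softmax(scores + mask, dim=-1)

print(mask)
print(weights)
```

Row i of weights is how much position i attends to each position j, so the upper triangle being zero means position i only ever sees positions 0..i. Note that the rows and columns are both sequence positions, not batch entries.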

Having understood the previous ideas, the following is easier to understand.

Let’s now take the TransformerDecoder class. According to the docs, its signature is forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None).

Let’s take it slowly, part by part.

First, the tgt_mask behaves identically to the src_mask, with shape (T,T), where T is the target sequence length. So, as you now know, its purpose is to mask chosen tokens of the target sentences inside the attention mechanism and, in our scenario, to mask all future target tokens.

Second, the memory_mask is interesting. It is another masking mechanism like the ones above, but what’s interesting is its shape. According to the docs, its shape is (T,S). Since S is the source sequence length and T is the target sequence length, this means that the masking takes place on the encoder’s outputs (the memory) as they are fed into the decoder’s encoder-decoder attention: entry (t,s) controls whether target position t may attend to source position s. This mask need not be square, for example if the source and target sentences have different lengths (in other words, when S != T). Why would we want to use this one? I have always regarded it as a kind of memory bandwidth. Suppose you apply a lot of -inf: then a lot of -inf will be added inside the decoder’s encoder-decoder attention, effectively making it ignore a lot of source tokens!

Third, and somewhat easier to understand, we once again have padding mask arguments: tgt_key_padding_mask and memory_key_padding_mask. As you now know, they make the transformer’s life (and its gradients) easier by telling it not to attend to the padding appended to the right of the target and memory sequences.
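Here is a small sketch of the (T,T) and (T,S) shapes in action. All the sizes and the choice to hide the last two source positions are made up by me for illustration; the layers use the default (sequence, batch, feature) layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
S, T, N, E = 6, 4, 2, 16  # source len, target len, batch size, model dim (illustrative)

layer = nn.TransformerDecoderLayer(d_model=E, nhead=2, dim_feedforward=32)
decoder = nn.TransformerDecoder(layer, num_layers=1)
decoder.eval()  # turn dropout off so two runs can be compared

memory = torch.randn(S, N, E)  # stand-in for the encoder output
tgt = torch.randn(T, N, E)

# (T, T) causal mask for the decoder self-attention.
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# (T, S) mask for the encoder-decoder attention: hide the last two
# source positions from every target position (an arbitrary choice).
memory_mask = torch.zeros(T, S)
memory_mask[:, S - 2:] = float("-inf")

with torch.no_grad():
    out = decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)

    # Overwrite the hidden source positions: the output should not change.
    memory_changed = memory.clone()
    memory_changed[S - 2:] = torch.randn(2, N, E)
    out_changed = decoder(tgt, memory_changed, tgt_mask=tgt_mask, memory_mask=memory_mask)

print(out.shape)
print(torch.allclose(out, out_changed, atol=1e-6))
```

Neither mask has a batch dimension, yet both are applied to every sequence in the batch.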
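For the padding masks, a sketch using nn.MultiheadAttention directly (the building block inside the encoder and decoder layers) makes the effect visible, because it can return the attention weights. The sizes and the padding pattern are my own assumptions; as far as I know, a boolean padding mask uses True to mean "ignore this position".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, S, E = 2, 5, 8  # batch size, sequence length, embedding dim (illustrative)

mha = nn.MultiheadAttention(embed_dim=E, num_heads=2, batch_first=True)
x = torch.randn(N, S, E)

# Shape (N, S): one row per sequence in the batch, True marks padding.
key_padding_mask = torch.tensor([
    [False, False, False, True, True],     # sequence 0: last two tokens are padding
    [False, False, False, False, False],   # sequence 1: no padding
])

# attn has shape (N, S, S): attention weights averaged over the heads.
_, attn = mha(x, x, x, key_padding_mask=key_padding_mask)

print(attn[0])  # columns 3 and 4 are zero: nobody attends to the padding
print(attn[1])  # no zeros: this sequence has no padding
```

Unlike the (S,S) mask, this one does have a batch dimension, since each sequence in the batch can have a different amount of padding.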

Lastly, as per your question, the official example is not the most self-explanatory. I had to read countless other resources to reach this level of understanding, and I still confuse myself sometimes! According to the docs, the src, tgt and memory masks have no batch dimension; only the padding masks do, which can be confusing. About this, all I can offer is my opinion: I think PyTorch is designed so that these masks disregard the batch size. They are broadcast over the leading (batch) dimensions, so a mask of shape (10,10) gives the same result on a batch as one of shape (1,1,1,1,10,10) would, because only the last two dimensions are actually worked on. I can only say this after much experimentation and practice.
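The broadcasting behaviour I mean can be checked with plain tensors. This is only a sketch of the idea with made-up sizes, not PyTorch's internal attention code:

```python
import torch

torch.manual_seed(0)
N, H, S = 3, 2, 4  # batch size, number of heads, sequence length (illustrative)

# Stand-in for per-batch, per-head attention scores (random values).
scores = torch.randn(N, H, S, S)

# A single (S, S) causal mask, with no batch or head dimension.
mask = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)

# Broadcasting lines up the trailing dimensions, so the same (S, S)
# mask is applied to every one of the N * H score matrices.
masked = scores + mask
weights = torch.softmax(masked, dim=-1)

# Adding explicit leading dimensions of size 1 gives the same result.
same = torch.equal(masked, scores + mask.view(1, 1, S, S))
print(masked.shape, same)
```

So one mask of shape (S,S) is enough for the whole batch, which is why a BS of 4 with sequences of length 6 would use a single (6,6) mask.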

Hope this helps a bit with your question.