Hi. Let me share my insights; please correct me if I’m wrong. This is a long post, but hopefully a useful one.
Let’s start with PyTorch’s TransformerEncoder. According to the docs, it says
forward(src, mask=None, src_key_padding_mask=None). It also says that the mask’s shape is (S, S), where S is the source sequence length, and that the mask is additive.
Note: the (S, S) and additive parts are found in the Transformer class docs, not in the TransformerEncoder class docs. A similar thing happens with the decoder docs.
So, in this case, src_mask is a square (S, S) matrix that is added right after the Q*Kt part of the attention mechanism, i.e. before the softmax and before the multiplication by the Values (V). This way, if we add zeros, there is no effect on the subsequent softmax and V multiplication. However, if we add -inf at some position, the softmax will output exactly zero there, whatever the V values are. Why would we want to do this? The effect carries all the way up through the TransformerEncoder, effectively forcing it to predict by seeing only past info (tokens). As an interesting note, this is exactly what the Transformer’s static method generate_square_subsequent_mask does: it generates a square matrix of zeros and -inf.
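To make this concrete, here is a small sketch (assuming a recent PyTorch; the scores tensor is a toy stand-in for the real attention logits) that builds the same causal mask and shows the -inf/softmax effect:

```python
import torch
import torch.nn as nn

S = 4  # source sequence length

# The same matrix that nn.Transformer.generate_square_subsequent_mask(S)
# produces: zeros on and below the diagonal, -inf above it.
mask = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)

# Toy pre-softmax attention scores (standing in for Q*Kt / sqrt(d)).
scores = torch.randn(S, S)

# Adding the mask and then softmaxing drives the masked positions to
# exactly 0, so position i can only attend to positions <= i.
attn = torch.softmax(scores + mask, dim=-1)
print(attn)
```

Each row of attn still sums to 1, but all the weight sits on the current and earlier positions.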
src_key_padding_mask does something quite similar, but its purpose is to mask the padding positions, which clearly have no semantic meaning, i.e. “please do not pay attention (attend) to the bunch of pad values that were appended to the right of the shorter sentences.” Note that, unlike src_mask, it has a batch dimension: its shape is (N, S) and it is boolean, with True marking the positions to ignore.
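For instance, a sketch of building such a mask from a padded batch (the token ids and the choice of 0 as the pad index are my own assumptions for illustration):

```python
import torch

PAD_IDX = 0  # hypothetical padding index

# A batch of 2 token-id sequences, right-padded to length 5.
src = torch.tensor([
    [5, 7, 2, PAD_IDX, PAD_IDX],   # real length 3
    [4, 9, 6, 3, 1],               # real length 5
])

# src_key_padding_mask has shape (N, S); True means "do not attend here".
src_key_padding_mask = src == PAD_IDX
print(src_key_padding_mask)
```

You would then pass it along as encoder(src_embeddings, src_key_padding_mask=src_key_padding_mask).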
Having understood the previous ideas, the following is easier to understand.
Let’s now take the TransformerDecoder class. According to the docs, it says
forward(tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None).
Let’s take it slowly, part by part.
tgt_mask behaves identically to src_mask, but with shape (T, T), where T is simply the target sequence length. So now you know that its purpose is to mask some desired tokens of the target sentences inside the attention mechanism, and of course, in our scenario, to mask all the future target tokens.
memory_mask is interesting. Although it’s just another masking mechanism, as we now know, what’s interesting is its shape. According to the docs, its shape is (T, S). Since S is the source sequence length and T is the target sequence length, this means the masking takes place right on top of the encoder’s outputs (the memory), which are fed into the decoder’s encoder-decoder attention. This mask need not be square, for example when we use different sequence lengths for source and target sentences (in other words, when S != T). Why would we want to use this one? I have always regarded it as a memory bandwidth knob: if you add a lot of -inf, the decoder’s encoder-decoder attention will effectively not pay attention to a lot of the memory tokens!
Thirdly, and somewhat easier to understand, we once again have padding-mask arguments: tgt_key_padding_mask and memory_key_padding_mask. These make the transformer’s life (and grads) easier by telling it not to attend to the pad values appended to the right of the target and memory sequences.
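Putting the pieces together, here is a minimal end-to-end sketch (toy sizes and random inputs are my own assumptions; in practice the inputs would be embedded token sequences) showing where each mask and its shape goes:

```python
import torch
import torch.nn as nn

d_model, nhead, S, T, N = 16, 4, 5, 3, 2  # small toy sizes

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead), num_layers=2)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model, nhead), num_layers=2)

src = torch.randn(S, N, d_model)  # (S, N, E): sequence-first layout
tgt = torch.randn(T, N, d_model)  # (T, N, E)

tgt_mask = nn.Transformer.generate_square_subsequent_mask(T)  # (T, T), causal
memory_mask = torch.zeros(T, S, dtype=torch.bool)  # (T, S): False = attend

# Padding masks are boolean and batch-first: (N, S) and (N, T).
src_key_padding_mask = torch.tensor([[False, False, False, True, True],
                                     [False, False, False, False, False]])
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)

memory = encoder(src, src_key_padding_mask=src_key_padding_mask)
out = decoder(tgt, memory,
              tgt_mask=tgt_mask,
              memory_mask=memory_mask,
              tgt_key_padding_mask=tgt_key_padding_mask,
              memory_key_padding_mask=src_key_padding_mask)
print(out.shape)
```

Note that the source padding mask is reused as memory_key_padding_mask, since the memory has one output per source position.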
Lastly, as per your question, the official example is not the best in terms of self-explanation. I had to read countless other resources to reach this level of understanding, and I still confuse myself sometimes! According to the docs, the src, tgt and memory masks have no batch dimension; only the padding masks do, which can be confusing. About this, all I can offer is my opinion: I think PyTorch is designed to disregard leading batch dimensions, so for example the result of some operation over a tensor of shape (10, 10) will be the same as over (1, 1, 1, 1, 10, 10), since it works only with the dimensions it needs, i.e. in this example the last two. I can only say this after much experimentation and practice.
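That intuition can be checked in two lines; here is a quick sketch using softmax over the last dimension as the example operation:

```python
import torch

x = torch.randn(10, 10)

# The same data with extra leading singleton "batch" dimensions.
x_batched = x.reshape(1, 1, 1, 1, 10, 10)

# An op that works over the last dimension ignores the leading dims:
a = torch.softmax(x, dim=-1)
b = torch.softmax(x_batched, dim=-1)
print(torch.allclose(a, b.reshape(10, 10)))  # True
```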
Hope this helps a bit with your question.