Understanding the padding mask for Transformers

For purely educational purposes, my goal is to implement the basic Transformer architecture from scratch. So far I have focused on the encoder for classification tasks and assumed that all samples in a batch have the same length, which meant I didn't need any masking.

However, now I want to support masking. I'd like to think that I understand the purpose of, e.g., the target mask, which ensures the decoder cannot “peek into the future”. I generate this mask as follows:

import torch

source_batch = torch.LongTensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0]
])

batch_size, seq_len = source_batch.shape

def generate_tgt_mask(size):
    # -inf strictly above the diagonal (future positions), 0 everywhere else
    return torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)

print(generate_tgt_mask(seq_len))

yielding:

tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0.,   0., -inf, -inf, -inf, -inf],
        [0.,   0.,   0., -inf, -inf, -inf],
        [0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.]])

which is the outcome I expect from the PyTorch docs. This mask has shape (L, L), where L is the length of the source or target sequence, which again matches the docs.
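As a cross-check, the same causal mask can also be built as a boolean matrix and then converted into the additive -inf form. This is a minimal sketch; the helper name `causal_mask` is hypothetical. Adding the mask to uniform scores and applying softmax shows that each query i spreads its weight only over keys 0..i.

```python
import torch

def causal_mask(size):
    # Hypothetical helper: True strictly above the diagonal marks "future" keys.
    blocked = torch.triu(torch.ones(size, size), diagonal=1).bool()
    # Additive form: 0 where attention is allowed, -inf where it is blocked.
    return torch.zeros(size, size).masked_fill(blocked, float('-inf'))

mask = causal_mask(6)

# Uniform (all-zero) scores plus the mask, then softmax over the keys.
weights = torch.softmax(torch.zeros(6, 6) + mask, dim=-1)
print(weights[2])  # about [1/3, 1/3, 1/3, 0, 0, 0]
```

Because every row keeps at least the diagonal entry at 0, no row is entirely -inf, so softmax never produces NaN here.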

I use this mask in my implementation of the Scaled Dot Product Attention as follows:

import torch
import torch.nn as nn
import torch.nn.functional as f

class Attention(nn.Module):
    """Implements Scaled Dot Product Attention."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None, dropout=None):
        # All shapes: (batch_size, seq_len, hidden_size)

        # Compute Q*K^T as a batched matrix product
        # (torch.matmul handles the leading batch dimension)
        out = torch.matmul(Q, K.transpose(1, 2))  # => shape: (B, L, L)

        # Divide by the scaling factor sqrt(d_k)
        out = out / (Q.shape[-1] ** 0.5)

        # Optional: src_mask/tgt_mask (shape: (L, L); masked positions hold -inf)
        if mask is not None:
            out += mask.unsqueeze(0)  # Broadcast: same mask for all samples in the batch

        # Push through the softmax layer
        out = f.softmax(out, dim=-1)

        # Optional: dropout on the attention weights (dropout is a probability)
        if dropout is not None:
            out = f.dropout(out, p=dropout, training=self.training)

        # Multiply with values V
        out = torch.matmul(out, V)

        return out
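A quick way to sanity-check attention code with the causal mask is a compact functional version of the same steps (dropout omitted). This is a sketch only: the function name `attention` and the random Q, K, V tensors are assumptions for illustration. With the causal mask, the first query may attend only to the first key, so its output must equal V[:, 0].

```python
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Same steps as Attention.forward: scaled scores, additive mask, softmax, weighted sum.
    scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5  # (B, L, L)
    if mask is not None:
        scores = scores + mask  # an (L, L) mask broadcasts over the batch dimension
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
B, L, H = 3, 6, 8
Q, K, V = (torch.randn(B, L, H) for _ in range(3))
causal = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)

out, weights = attention(Q, K, V, causal)
print(out.shape)                            # torch.Size([3, 6, 8])
print(torch.allclose(out[:, 0], V[:, 0]))   # True
```

Returning the weights as well makes it easy to assert that masked entries are exactly zero in a unit test.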

So far so good… at least I'd like to think so. However, my problem is now the mask for handling padding (e.g., src_key_padding_mask). In various tutorials that use nn.Transformer, this mask is generated as follows:

pad_token_index = 0

src_key_padding_mask = (source_batch != pad_token_index)

print(src_key_padding_mask)

yielding:

tensor([[ True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False]])

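which has shape (N, L) and again matches the docs. (One caveat: with `!=`, True marks real tokens, whereas nn.Transformer treats True in a boolean src_key_padding_mask as a padded position to be ignored.)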
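The shapes line up if the (N, L) mask gets a singleton query dimension, so that it broadcasts against the (N, L, L) score matrix and blocks the same padded keys for every query. The sketch below shows one common approach, not necessarily what nn.Transformer does internally. It assumes, as in the snippet above, that True marks real tokens, and the helper name `padding_to_additive` is hypothetical. The resulting additive mask can simply be summed with the causal mask.

```python
import torch

pad_token_index = 0
source_batch = torch.LongTensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0],
])
N, L = source_batch.shape

def padding_to_additive(keep):
    # Hypothetical helper. keep: (N, L) bool, True for real tokens.
    # Returns (N, 1, L): 0 for real keys, -inf for padded keys; the size-1
    # query dimension broadcasts over every row of the (N, L, L) scores.
    additive = torch.zeros(keep.shape).masked_fill(~keep, float('-inf'))
    return additive.unsqueeze(1)

keep = source_batch != pad_token_index                               # (N, L)
pad_mask = padding_to_additive(keep)                                 # (N, 1, L)
causal = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)  # (L, L)
combined = pad_mask + causal                                         # broadcasts to (N, L, L)

torch.manual_seed(0)
scores = torch.randn(N, L, L)
weights = torch.softmax(scores + combined, dim=-1)
print(combined.shape)  # torch.Size([3, 6, 6])
print(weights[0, 5])   # only the first three keys receive non-zero weight
```

Since key 0 is a real token in every sample, no row ends up entirely -inf, so softmax produces no NaN. Rows belonging to padded queries still get (meaningless) outputs, which are usually ignored downstream.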

What I'm now missing is: how do I incorporate this matrix into my implementation of Attention?
