Discrepancy Between key_padding_mask and attn_mask in MultiheadAttention Layer

Hello,

I am working with the MultiheadAttention layer in PyTorch and encountered a discrepancy between using key_padding_mask and attn_mask for handling variable-length sequences with padding. My goal is to ensure that padded positions do not influence the attention scores. However, the results differ when using key_padding_mask versus what I intended to be an equivalent attn_mask.

Here’s a minimal reproducible example:

import torch
import torch.nn as nn

# Function to create key padding mask from padded sequences
def create_key_padding_mask(embeddings, padding_value=0):
    return (embeddings == padding_value).all(dim=-1)

# Function to create an attention mask from key padding mask
def create_attn_mask_from_key_padding_mask(key_padding_mask, num_heads):
    batch_size, seq_length = key_padding_mask.shape
    attn_mask = torch.ones((batch_size, seq_length, seq_length), dtype=torch.bool)
    for i in range(batch_size):
        valid_len = seq_length - key_padding_mask[i].sum().item()
        attn_mask[i, :valid_len, :valid_len] = False
    
    attn_mask = attn_mask.unsqueeze(1).expand(batch_size, num_heads, seq_length, seq_length)
    attn_mask = attn_mask.reshape(batch_size * num_heads, seq_length, seq_length)
    return attn_mask

# Create a batch of embeddings with variable lengths (padded with 0)
batch_size = 5
seq_length = 6
embed_dim = 100

# Simulate embeddings with padding (random example)
embeddings = torch.rand(batch_size, seq_length, embed_dim)
embeddings[0, 4:] = 0  # First sequence is of length 4
embeddings[1, 3:] = 0  # Second sequence is of length 3
embeddings[2, 5:] = 0  # Third sequence is of length 5
embeddings[3, 2:] = 0  # Fourth sequence is of length 2
embeddings[4, 1:] = 0  # Fifth sequence is of length 1

# Create the key padding mask
key_padding_mask = create_key_padding_mask(embeddings)

# Create the attention mask from key padding mask
num_heads = 10
attn_mask = create_attn_mask_from_key_padding_mask(key_padding_mask, num_heads)

# Define the MultiheadAttention layer
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Perform multi-head attention using the key padding mask
attn_output_key_mask, attn_output_weights_key_mask = multihead_attn(
    embeddings, embeddings, embeddings, key_padding_mask=key_padding_mask)

# Perform multi-head attention using the attention mask
attn_output_attn_mask, attn_output_weights_attn_mask = multihead_attn(
    embeddings, embeddings, embeddings, attn_mask=attn_mask)

# Compare the outputs
print("Attention Output using Key Padding Mask:")
print(attn_output_key_mask)

print("Attention Output using Attention Mask:")
print(attn_output_attn_mask)

# Check if the outputs are similar
print("Are the outputs similar?")
print(torch.allclose(attn_output_key_mask, attn_output_attn_mask, atol=1e-6))

I suspect that when using key_padding_mask, the last p (padded) columns of the final mask are correctly set to -inf, but the last p rows are left at zero instead of -inf. So when the softmax is then applied along the key dimension, the padded columns correctly get weight 0, but each padded row ends up with uniform weights of 1/(non_padded_items), because it contains non_padded_items zeros followed by p values of -inf.
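
To illustrate that suspicion with a standalone snippet (plain softmax on a hand-built score row, not the actual library internals):

import torch

# A query row whose key (column) positions 3 and 4 are masked with -inf
row = torch.tensor([0.0, 0.0, 0.0, float("-inf"), float("-inf")])
print(torch.softmax(row, dim=-1))
# tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])
# i.e. a padded query row still receives uniform weights of 1/(non_padded_items) over the valid keys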

Any insights or guidance on resolving this discrepancy would be greatly appreciated.

Thank you!

I think(!) that it actually does not matter what the values for the padding entries are since any follow-up layer should ignore those values anyway. At least this was my take-away message when I asked about this as well a while ago.

But I would be curious if anyone can confirm or clarify this.

Hello Chris,

Thank you for your answer. I see in the multi_head_attention_forward implementation that the attention output goes straight into the output linear layer without any handling of padding. So the final p rows of the result (which, due to the uniform attention, are just the average of the first, non-padded rows of the input) pass through the network.

Can someone confirm whether something happens under the hood that I might be missing and that resolves this? Also, in TransformerEncoderLayer, the output of _sa_block is passed to _ff_block without any re-padding (I had to create a hook to re-zero the padded positions manually). Is that right?
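
For reference, such a hook could look roughly like this (a simplified sketch; it assumes batch_first=True, an existing encoder_layer module, and a key_padding_mask that is True at padded positions):

def make_zero_pad_hook(key_padding_mask):
    # Zero out the padded positions of the layer's output after each forward pass
    def hook(module, inputs, output):
        return output.masked_fill(key_padding_mask.unsqueeze(-1), 0.0)
    return hook

handle = encoder_layer.register_forward_hook(make_zero_pad_hook(key_padding_mask))
# handle.remove() when done; note that the captured mask has to match the current batch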

Hello Mel,

Thanks for your response. The correct mask is the one I was creating, which also has the last two rows set to -inf instead of 0. If you look closely at your output, you'll see that the last 2 (= number of padded items) rows are identical and equal to the mean of the first (non-padded) rows. I was asking whether those last two rows are handled internally, because they shouldn't pass through the linear layer without being zeroed first. But if you look at the multi_head_attention_forward code, they pass through without being zeroed.

Ahh yeah, nice catch. It seems like a bug to me: the padded tokens should be set back to zero after the forward pass so that they don't affect the downstream parameters, especially if you have multiple self-attention layers or a feed-forward layer. The current workaround is simply to set the padding tokens to zero after the forward pass. Good observation @evaggelos :slight_smile: Regards
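
In code, that workaround could look something like this (a sketch reusing the names from the original example; key_padding_mask is True at padded positions and attn_output_key_mask has shape (batch, seq, embed_dim) because of batch_first=True):

# Re-zero the output rows that correspond to padded positions
attn_output_key_mask = attn_output_key_mask.masked_fill(
    key_padding_mask.unsqueeze(-1), 0.0)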

Looking at the code, the output is then projected through another set of parameters (the output projection).

Exactly! That shouldn't happen without zeroing those rows out first, right?

Hello, do we have any updates on this topic? @ptrblck, I apologize for the random tag, but I’ve noticed that you are very active and helpful. I was wondering if you might have any information on this subject.

Hi @evaggelos,
The function create_attn_mask_from_key_padding_mask is not producing the expected attn_mask.
You can compare with how it is done in https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L6179,
where the key_padding_mask is first viewed to (bsz, 1, 1, src_len), expanded over the heads, reshaped to (bsz * num_heads, 1, src_len), and then broadcast when it is merged into attn_mask:

    # merge key padding and attention masks
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (
            bsz,
            src_len,
        ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask
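
In other words (a shapes-only illustration, not the actual internals of that function), the merged mask has shape (bsz * num_heads, 1, src_len) and broadcasts over the query dimension when it is applied to the (bsz * num_heads, tgt_len, src_len) attention scores:

import torch

B, H, L, S = 1, 2, 5, 5
scores = torch.zeros(B * H, L, S)
kpm = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.bool)                    # (B, S)
merged = kpm.view(B, 1, 1, S).expand(-1, H, -1, -1).reshape(B * H, 1, S)   # (B*H, 1, S)
print(scores.masked_fill(merged, float("-inf"))[0])
# every query row of every head gets -inf in columns 3 and 4, and only there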

So, you can come up with a function like:

def pad_to_attn_mask(pad_mask, num_heads):
    B, L, H = *pad_mask.shape, num_heads
    return pad_mask.view(B, 1, 1, L).expand(-1, H, L, -1).reshape(B * H, L, L)

And compare attention masks with a small example:

num_heads = 2
key_padding_mask_small = torch.tensor([[0, 0, 0, 1, 1], [0, 0, 0, 0, 1]], dtype=bool)

attn_mask = pad_to_attn_mask(key_padding_mask_small, num_heads)
attn_mask_ = create_attn_mask_from_key_padding_mask(key_padding_mask_small, num_heads)

which gives:

attn_mask
tensor([[[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

and

attn_mask_
tensor([[[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [ True,  True,  True,  True,  True]]])

which are different.
The former is what we expect as explained here: https://gmongaras.medium.com/how-do-self-attention-masks-work-72ed9382510f
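
And to close the loop on the original question, a quick check (reusing embeddings, multihead_attn, key_padding_mask, and num_heads from the first snippet) should show that the key_padding_mask call and the broadcast-style attn_mask now give matching outputs:

attn_mask_bcast = pad_to_attn_mask(key_padding_mask, num_heads)   # (B*H, L, L)
out_kpm, _ = multihead_attn(embeddings, embeddings, embeddings,
                            key_padding_mask=key_padding_mask)
out_bcast, _ = multihead_attn(embeddings, embeddings, embeddings,
                              attn_mask=attn_mask_bcast)
print(torch.allclose(out_kpm, out_bcast, atol=1e-6))  # expected: True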

I just came across this issue this morning while trying to use SDPA in a custom transformer, and dug into it.
Thanks for posting!
Sincerely,
Gabriel

Hello Gabriel,

Thank you for your response, and I apologize for the delayed reply. The article was great, and I finally understood why it's called key_padding_mask. The purpose is to prevent the non-<PAD> tokens from attending to the <PAD> tokens, while still allowing the <PAD> tokens to have a representation: each <PAD> token attends to the preceding (non-padded) tokens only within its own row, i.e. only for its own representation.
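
A tiny sketch of that reading (the module and input here are made up just to show the weight pattern; key_padding_mask_small is the mask from the example above, and torch / nn are imported as in the first snippet):

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.rand(2, 5, 8)
_, weights = mha(x, x, x, key_padding_mask=key_padding_mask_small)
print(weights[0])
# columns 3 and 4 (the <PAD> keys) get weight 0 in every row,
# while rows 3 and 4 (the <PAD> queries) still spread their weights over the valid keys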