Hello,
I am working with the MultiheadAttention layer in PyTorch and have run into a discrepancy between key_padding_mask and attn_mask when handling variable-length sequences with padding. My goal is to ensure that padded positions do not influence the attention scores. However, the outputs differ when I use key_padding_mask versus what I believe is an equivalent attn_mask.
Here’s a minimal reproducible example:
import torch
import torch.nn as nn
# Function to create key padding mask from padded sequences
def create_key_padding_mask(embeddings, padding_value=0):
    return (embeddings == padding_value).all(dim=-1)
# Function to create an attention mask from key padding mask
def create_attn_mask_from_key_padding_mask(key_padding_mask, num_heads):
    batch_size, seq_length = key_padding_mask.shape
    attn_mask = torch.ones((batch_size, seq_length, seq_length), dtype=torch.bool)
    for i in range(batch_size):
        valid_len = seq_length - key_padding_mask[i].sum().item()
        attn_mask[i, :valid_len, :valid_len] = False
    attn_mask = attn_mask.unsqueeze(1).expand(batch_size, num_heads, seq_length, seq_length)
    attn_mask = attn_mask.reshape(batch_size * num_heads, seq_length, seq_length)
    return attn_mask
# Create a batch of embeddings with variable lengths (padded with 0)
batch_size = 5
seq_length = 6
embed_dim = 100
# Simulate embeddings with padding (random example)
embeddings = torch.rand(batch_size, seq_length, embed_dim)
embeddings[0, 4:] = 0 # First sequence is of length 4
embeddings[1, 3:] = 0 # Second sequence is of length 3
embeddings[2, 5:] = 0 # Third sequence is of length 5
embeddings[3, 2:] = 0 # Fourth sequence is of length 2
embeddings[4, 1:] = 0 # Fifth sequence is of length 1
# Create the key padding mask
key_padding_mask = create_key_padding_mask(embeddings)
# Create the attention mask from key padding mask
num_heads = 10
attn_mask = create_attn_mask_from_key_padding_mask(key_padding_mask, num_heads)
# Define the MultiheadAttention layer
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
# Perform multi-head attention using the key padding mask
attn_output_key_mask, attn_output_weights_key_mask = multihead_attn(
    embeddings, embeddings, embeddings, key_padding_mask=key_padding_mask)
# Perform multi-head attention using the attention mask
attn_output_attn_mask, attn_output_weights_attn_mask = multihead_attn(
    embeddings, embeddings, embeddings, attn_mask=attn_mask)
# Compare the outputs
print("Attention Output using Key Padding Mask:")
print(attn_output_key_mask)
print("Attention Output using Attention Mask:")
print(attn_output_attn_mask)
# Check if the outputs are similar
print("Are the outputs similar?")
print(torch.allclose(attn_output_key_mask, attn_output_attn_mask, atol=1e-6))
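To narrow down where the two outputs disagree, here is a sketch of a follow-up check. It rebuilds the same setup in a self-contained form and compares three variants: key_padding_mask, a square attn_mask that masks padded rows and columns (which should match what my helper above produces when the padding is at the end), and an attn_mask that masks only the padded key columns. The names square_mask, cols_only_mask, and valid are introduced here for illustration only, and the check assumes the layer is in its default training mode with the default need_weights=True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_length, embed_dim, num_heads = 5, 6, 100, 10
lengths = [4, 3, 5, 2, 1]  # same valid lengths as in the example above

# Random embeddings, zero-padded at the end of each sequence
embeddings = torch.rand(batch_size, seq_length, embed_dim)
for i, n in enumerate(lengths):
    embeddings[i, n:] = 0

# True marks padded positions, shape (batch_size, seq_length)
key_padding_mask = (embeddings == 0).all(dim=-1)

pad_q = key_padding_mask.unsqueeze(2)  # padded query rows,  (B, L, 1)
pad_k = key_padding_mask.unsqueeze(1)  # padded key columns, (B, 1, L)

# Square mask: padded rows AND padded columns are masked
square_mask = (pad_q | pad_k).repeat_interleave(num_heads, dim=0)
# Column-only mask: only padded key columns are masked
cols_only_mask = pad_k.expand(batch_size, seq_length, seq_length)
cols_only_mask = cols_only_mask.repeat_interleave(num_heads, dim=0)

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

out_kpm, _ = mha(embeddings, embeddings, embeddings,
                 key_padding_mask=key_padding_mask)
out_sq, _ = mha(embeddings, embeddings, embeddings, attn_mask=square_mask)
out_cols, _ = mha(embeddings, embeddings, embeddings, attn_mask=cols_only_mask)

valid = ~key_padding_mask  # True for real (non-padded) tokens

print("NaNs with key_padding_mask:", torch.isnan(out_kpm).any().item())
print("NaNs with square attn_mask:", torch.isnan(out_sq).any().item())
print("Square mask matches on non-padded positions:",
      torch.allclose(out_kpm[valid], out_sq[valid], atol=1e-6))
print("Column-only mask matches everywhere:",
      torch.allclose(out_kpm, out_cols, atol=1e-6))
```

If the two comparisons at the end both print True, the difference is confined to the padded query positions, which can be ignored or zeroed out downstream.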
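I suspect that with key_padding_mask, the last p columns of the final mask (the padded key positions) are correctly set to -inf, but the last p rows (the padded query positions) are left at zero rather than -inf. When the softmax is then applied over the key dimension, the padded columns get weight 0, which is correct, but each padded row puts weight 1/(non_padded_items) on every valid position, because that row of the mask is non_padded_items zeros followed by p -inf’s. My attn_mask, by contrast, sets the padded rows entirely to -inf.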
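The softmax behaviour I am describing can be reproduced on a single mask row in isolation. This is a minimal sketch: the two row tensors are made-up illustrations (3 valid keys, 2 padded keys), not values read out of the layer, and the scores are taken to be zero so that only the mask matters.

```python
import torch

neg_inf = float("-inf")

# Padded query row when only the key columns are masked:
# zeros for the 3 valid keys, -inf for the 2 padded keys
row_cols_masked = torch.tensor([0.0, 0.0, 0.0, neg_inf, neg_inf])

# The same row when the whole padded query row is masked as well
row_fully_masked = torch.full((5,), neg_inf)

w_cols = torch.softmax(row_cols_masked, dim=-1)
w_full = torch.softmax(row_fully_masked, dim=-1)

print(w_cols)  # uniform weight on the valid keys, zero on the padded keys
print(w_full)  # every entry is NaN, since all inputs are -inf
```

A row that is entirely -inf has no finite entry to normalise against, so a plain softmax returns NaN for it, and that NaN then carries through to the corresponding output row.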

Any insights or guidance on resolving this discrepancy would be greatly appreciated.
Thank you!
