I have a question regarding MultiheadAttention and multi_head_attention_forward: it seems the padding mask is applied along only one axis instead of along both axes.
When I attach a hook to the attention module, the attention weights I receive are masked along only one axis, whereas I would expect a single non-zero square attention map with everything else zero.
Looking at the source code:
if key_padding_mask is not None:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),
        float("-inf"),
    )
    # Expect Extra =============>
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2).transpose(-1, -2),
        float("-inf"),
    )
    # <==================== End + Handle the lower zero-row softmax.
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
Am I missing something, or is this intended? It seems odd, since attention ends up being shared between the padding and the audio features.
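To make the expectation concrete, here is a small plain-Python sketch of the two behaviours. It is not PyTorch's implementation: the helper names `masked_softmax_rows` and `zero_padded_queries` are made up for illustration, and it assumes self-attention (tgt_len == src_len) with at least one real key per sequence. Since a row filled entirely with -inf has no well-defined softmax, the sketch zeroes the padded query rows after the softmax rather than masking them before it, which is roughly what I mean by handling the zero rows.

```python
import math

NEG_INF = float("-inf")


def masked_softmax_rows(scores, key_is_pad):
    """Softmax over the key axis after setting padded key columns to -inf.

    Assumes at least one key is not padding, so every row has a finite max.
    """
    weights = []
    for row in scores:
        masked = [NEG_INF if pad else s for s, pad in zip(row, key_is_pad)]
        peak = max(masked)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


def zero_padded_queries(weights, query_is_pad):
    """Zero the whole row for every padded query position (applied after softmax)."""
    return [[0.0] * len(row) if pad else row
            for row, pad in zip(weights, query_is_pad)]


# Toy sequence of length 4 whose last two positions are padding.
is_pad = [False, False, True, True]
scores = [[0.0] * 4 for _ in range(4)]  # uniform raw attention scores

# Behaviour I observe: only the key (column) axis is masked,
# so padded query rows still carry non-zero weights.
key_only = masked_softmax_rows(scores, is_pad)

# Behaviour I expected: a non-zero square block, zeros everywhere else.
both_axes = zero_padded_queries(key_only, is_pad)

for name, matrix in (("key axis only", key_only), ("both axes", both_axes)):
    print(name)
    for row in matrix:
        print("  ", row)
```

With key-only masking, the rows for the two padded positions still attend to the real keys; with the extra row step, only the top-left 2x2 block is non-zero.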