Getting nn.MultiheadAttention attention weights for each head

I’m using the nn.MultiheadAttention layer (PyTorch v1.1.0) with num_heads=19 and an input tensor of size [model_size, batch_size, embed_size].

Based on the original Attention Is All You Need paper, I understand that there should be a matrix of attention weights for each head (19 in my case), but I can’t find a way of accessing them. When doing a forward pass, the returned weights have size [batch_size, model_size, model_size] instead of something like [batch_size, 19, model_size, model_size]. I’m guessing the returned weights are an average over all the heads, but that isn’t specified in the docs.

Is there another way of accessing the full attention weights?


I think we will have to modify F.multi_head_attention_forward for that. At the end it has:

if need_weights:
    # average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
    return attn_output, None

If we want the full attention weights without averaging over heads, then we will have to change it to:

if need_weights:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    return attn_output, attn_output_weights
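
For illustration, here is a standalone sketch of what that reshape and averaging step does, with random numbers standing in for the real attention scores (the sizes below are made up, not from the original post):

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 10, 10
# stand-in for the [bsz * num_heads, tgt_len, src_len] tensor inside the function
attn_output_weights = torch.softmax(torch.rand(bsz * num_heads, tgt_len, src_len), dim=-1)

per_head = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)  # what the modified code returns
averaged = per_head.sum(dim=1) / num_heads                             # what stock PyTorch returns

print(per_head.shape)   # torch.Size([2, 4, 10, 10])
print(averaged.shape)   # torch.Size([2, 10, 10])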

I suspected that was going to be the answer :worried:
Should I make an issue on the repo with the feature request?

Thanks for your suggestion I will give it a try on a local fork.

There is an active issue about breaking multi-head attention into smaller parts, since the current implementation is a bit too long.
I think they will update it soon.

Is there any progress on this?


For posterity: a flag to disable averaging of attention weights across heads was added in #70055. You can now pass average_attn_weights=False to get attention weights per head.
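
A quick sketch of the new flag on a recent PyTorch build that includes that change (the dimensions here are illustrative):

import torch
import torch.nn as nn

embed_dim, num_heads = 64, 4   # illustrative sizes
seq_len, batch_size = 10, 2

mha = nn.MultiheadAttention(embed_dim, num_heads)
x = torch.rand(seq_len, batch_size, embed_dim)  # [seq_len, batch, embed_dim]

# default: weights averaged over heads -> (batch_size, seq_len, seq_len)
_, avg_weights = mha(x, x, x, need_weights=True)

# per-head weights -> (batch_size, num_heads, seq_len, seq_len)
_, per_head_weights = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(avg_weights.shape, per_head_weights.shape)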