How can I retrieve the output from each head of nn.MultiheadAttention?

I learned the multi-head attention mechanism from this article.
If I understand the article correctly, increasing num_heads should give me more outputs, one per head.

However, I’ve observed that no matter what num_heads I set, the output shape of nn.MultiheadAttention stays the same, which contradicts what I learned from the article.

My code snippet:

import torch
import torch.nn as nn

# Two modules that differ only in the number of heads.
attn_2_heads = nn.MultiheadAttention(embed_dim=20, num_heads=2)
attn_5_heads = nn.MultiheadAttention(embed_dim=20, num_heads=5)

# L = target length, S = source length, N = batch size, E = embedding dim.
L, S, N, E = 2, 3, 4, 20
query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

attn_output_2_heads, _ = attn_2_heads(query, key, value)
attn_output_5_heads, _ = attn_5_heads(query, key, value)

# Both print torch.Size([2, 4, 20]): the output shape is (L, N, embed_dim)
# regardless of num_heads.
print(attn_output_2_heads.shape, attn_output_5_heads.shape)

The default behavior of nn.MultiheadAttention is to average the attention weights over the heads.

If you need the per-head weights returned, pass average_attn_weights=False to the forward call.
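
For example, a minimal check using the modules from the snippet above (average_attn_weights has been part of the forward signature since PyTorch 1.11, so verify it against your installed version). Note that attn_output itself always keeps shape (L, N, embed_dim), because the per-head outputs are concatenated and projected back to embed_dim by out_proj:

attn_output, attn_weights = attn_2_heads(
    query, key, value, average_attn_weights=False
)
# Per-head weights: (N, num_heads, L, S) instead of the averaged (N, L, S).
print(attn_weights.shape)  # torch.Size([4, 2, 2, 3])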

Oh, I thought attn_output_weights was the attention score.
So should I use attn_output and attn_output_weights to retrieve the output of each head?
How can I do that?

The actual function can be found here:

The exact path through that function depends on the arguments you pass, but it is a good starting point for working out how the per-head weights are combined into attn_output in your specific case.
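
If it helps, here is a minimal sketch of how you could reproduce the per-head outputs yourself, assuming the module was built with the defaults (kdim == vdim == embed_dim, bias=True, batch_first=False, dropout=0). The helper name per_head_outputs is just illustrative; it mirrors the reshaping that multi_head_attention_forward does internally before the heads are concatenated and sent through out_proj:

import torch
import torch.nn.functional as F

def per_head_outputs(attn, query, key, value):
    # Illustrative helper (not part of PyTorch): returns the per-head
    # outputs of an nn.MultiheadAttention module before out_proj.
    E, h = attn.embed_dim, attn.num_heads
    head_dim = E // h
    # The module stores the q/k/v projections packed into one weight matrix.
    w_q, w_k, w_v = attn.in_proj_weight.chunk(3)
    b_q, b_k, b_v = attn.in_proj_bias.chunk(3)
    q = F.linear(query, w_q, b_q)  # (L, N, E)
    k = F.linear(key, w_k, b_k)    # (S, N, E)
    v = F.linear(value, w_v, b_v)  # (S, N, E)
    L, N, _ = q.shape
    S = k.shape[0]
    # Split E into h heads, folding the heads into the batch dimension,
    # the same way multi_head_attention_forward does internally.
    q = q.reshape(L, N * h, head_dim).transpose(0, 1)  # (N*h, L, head_dim)
    k = k.reshape(S, N * h, head_dim).transpose(0, 1)  # (N*h, S, head_dim)
    v = v.reshape(S, N * h, head_dim).transpose(0, 1)  # (N*h, S, head_dim)
    # Scaled dot-product attention, computed per head.
    weights = (q @ k.transpose(1, 2) / head_dim ** 0.5).softmax(dim=-1)
    heads = weights @ v                      # (N*h, L, head_dim)
    return heads.reshape(N, h, L, head_dim)  # one output per head

heads = per_head_outputs(attn_2_heads, query, key, value)
print(heads.shape)  # torch.Size([4, 2, 2, 10]) -> (N, num_heads, L, head_dim)

Concatenating these head outputs back along the last dimension and passing them through attn.out_proj should reproduce attn_output, which is why the output shape never changes with num_heads.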
