I'm confused about how `attn_output_weights` is specified to have shape `(N, L, S)` regardless of the number of heads. Wouldn't there be a unique set of weights for each head?
https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
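For example, a minimal repro (the sizes below are arbitrary, chosen just for illustration) shows the head dimension is missing from the returned weights:

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes:
# E = embed_dim, N = batch size, L = target length, S = source length
E, num_heads, N, L, S = 16, 4, 2, 5, 7

mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)  # inputs are (L, N, E) by default
query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

attn_output, attn_output_weights = mha(query, key, value, need_weights=True)
print(attn_output_weights.shape)  # torch.Size([2, 5, 7]) == (N, L, S) -- no head dimension
```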
eqy
June 13, 2021, 9:09pm
This seems to be because the attention weights are averaged across all of the heads:
GitHub issue (opened 10 Mar 2020; labels: enhancement, module: nn, oncall: transformer/mha, triaged):
## 🚀 Feature
## Motivation
Currently, when using the `nn.MultiHeadAttention` layer, `attn_output_weights` consists of an average of the attention weights of each head, so the original per-head weights are inaccessible. That makes analyses like the one made in this [paper](https://arxiv.org/abs/1906.04341v1) very difficult.
## Pitch
When the `nn.MultiHeadAttention` forward is called with `need_weights=True` (and maybe a second parameter like `need_attn_heads=True`), `attn_output_weights` should be a tensor of size `[N, num_heads, L, S]` containing the weights of each head, instead of the average of size `[N, L, S]` (following the notation in the [docs](https://pytorch.org/docs/stable/nn.html#multiheadattention)).
## Additional context
A small discussion about this subject with a potential solution was made [here](https://discuss.pytorch.org/t/getting-nn-multiheadattention-attention-weights-for-each-head/72195)
If you guys agree, I'll gladly make a PR.
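For what it's worth, newer PyTorch releases have since added an `average_attn_weights` flag to `forward` (from 1.11 onward, if I recall correctly). Assuming a version that has it, the per-head weights can be recovered directly; a minimal sketch:

```python
import torch
import torch.nn as nn

# Same arbitrary illustrative sizes as above
E, num_heads, N, L, S = 16, 4, 2, 5, 7
mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)
query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)

# average_attn_weights=False returns the raw weights of every head
# instead of their mean over the head dimension.
# (Requires a PyTorch version where this flag exists.)
attn_output, per_head_weights = mha(
    query, key, value, need_weights=True, average_attn_weights=False
)
print(per_head_weights.shape)  # torch.Size([2, 4, 5, 7]) == (N, num_heads, L, S)
```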