Extracting transformer attention output weights

Hi there,
first of all, thanks for the nice modular implementation of the transformer architecture in PyTorch!

I recently tried to extract the attention output weights for some layers in a TransformerEncoder for an input sample and have some suggestions on how to improve this.

  • The option to get the attention weights is currently only available in MultiheadAttention, via the need_weights argument. It already returns the attention weights averaged over heads rather than the attention weights per head.
    First, it would be better to return the attention weights per head and let the user average over heads if needed (some users want to inspect the attention weights of each head).
    Second, it would be good to add the same argument to the forward methods of TransformerEncoderLayer and TransformerDecoderLayer and to pass it on when evaluating self.self_attn and self.multihead_attn. Currently the default value True is used, so the attention weights are computed and then discarded, which is wasted work.

  • There should be a way to get the forward keyword arguments in forward hooks! Then it would be simple to install a forward hook on each TransformerEncoderLayer, evaluate with need_weights=True and store the results somewhere. For me, this is currently not possible because the attention mask is passed as a keyword argument. This is a severe shortcoming of hooks! They should have access to all the arguments that forward has access to. Keyword arguments could be passed as a separate dict.
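
To illustrate the first point, here is a minimal sketch of what need_weights gives me. The sizes (embed_dim, num_heads, seq_len, batch_size) are arbitrary values picked for the example, and I am assuming the default argument values of MultiheadAttention otherwise. The returned weights have no head dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes chosen only for this illustration.
embed_dim, num_heads = 8, 2
seq_len, batch_size = 5, 3

mha = nn.MultiheadAttention(embed_dim, num_heads)

# Default input layout is (seq_len, batch, embed_dim).
x = torch.randn(seq_len, batch_size, embed_dim)

with torch.no_grad():
    # Self-attention: query, key and value are the same tensor.
    attn_output, attn_weights = mha(x, x, x, need_weights=True)
    _, no_weights = mha(x, x, x, need_weights=False)

# The weights come back as (batch, target_len, source_len).
# There is no axis of size num_heads, so per-head weights are not recoverable.
print(tuple(attn_weights.shape))
print(no_weights)
```

With need_weights=False the second return value is None, which is what the encoder and decoder layers could request when nobody asks for the weights.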
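
The hook limitation from the second point can be reproduced with a small sketch. MaskedLayer is a hypothetical stand-in I wrote for this example (it is not a PyTorch class); it only mimics a layer whose forward takes a mask as a keyword argument. The hook is registered with register_forward_hook using its default arguments.

```python
import torch
import torch.nn as nn


class MaskedLayer(nn.Module):
    """Hypothetical stand-in for a layer whose forward takes a keyword mask."""

    def forward(self, x, mask=None):
        if mask is not None:
            # Zero out the masked positions.
            x = x.masked_fill(mask, 0.0)
        return x


captured = {}


def record_inputs(module, inputs, output):
    # `inputs` is the tuple of positional arguments passed to forward.
    captured["inputs"] = inputs


layer = MaskedLayer()
handle = layer.register_forward_hook(record_inputs)

x = torch.ones(2, 3)
mask = torch.tensor([[True, False, False], [False, False, True]])

# The mask is passed by keyword, as the transformer layers do internally.
out = layer(x, mask=mask)
handle.remove()

# Only `x` reaches the hook; the mask is not visible to it.
num_hook_inputs = len(captured["inputs"])
print(num_hook_inputs)
```

The forward call itself applies the mask, but the hook only sees the positional input, so it cannot re-run the attention with the same mask.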

What is your opinion on this?