How do I obtain multiple heads via `multi_head_attention_forward` when using `nn.TransformerEncoder`?

Maintainership of the CrabNet repository is being transferred to me, and I'd like to remove the hacky workaround described in the repository (GitHub - anthony-wang/CrabNet: Predict materials properties using only the composition information!) so that I can keep the package user-friendly and include plotting functions as class methods without asking people to edit PyTorch source code.

To properly export the attention heads from the PyTorch nn.MultiheadAttention implementation within the transformer encoder layer, you will need to manually modify some of the source code of the PyTorch library.
This applies to PyTorch v1.6.0, v1.7.0, and v1.7.1 (and potentially to other, untested versions as well).

For this, open the file:
C:\Users\{USERNAME}\Anaconda3\envs\{ENVIRONMENT}\Lib\site-packages\torch\nn\functional.py
(where USERNAME is your Windows user name and ENVIRONMENT is your conda environment name; if you followed the installation steps above, it should be crabnet)
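If the path differs on your machine (non-Windows, or a different environment manager), Python can report the exact file to edit; this is just standard module introspection, nothing CrabNet-specific:

import torch
import torch.nn.functional as F

# Print the installed PyTorch version and the location of functional.py.
print(torch.__version__)  # should be 1.6.0, 1.7.0, or 1.7.1 for the edit below
print(F.__file__)         # full path to ...\torch\nn\functional.py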

At the end of the function definition of multi_head_attention_forward (line numbers may differ slightly):

L4011 def multi_head_attention_forward(
# ...
# ... [some lines omitted]
# ...
L4291    if need_weights:
L4292        # average attention weights over heads
L4293        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
L4294        return attn_output, attn_output_weights.sum(dim=1) / num_heads
L4295    else:
L4296        return attn_output, None

Change the specific line

return attn_output, attn_output_weights.sum(dim=1) / num_heads

to:

return attn_output, attn_output_weights

This stops the attention weights from being averaged over all heads and instead returns each head's attention matrix individually.
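As an aside: if requiring a newer PyTorch is an option, releases from roughly v1.11 onward add an average_attn_weights argument to nn.MultiheadAttention.forward, which returns the per-head matrices without any source edit. A minimal sketch on a standalone attention module, not the CrabNet encoder:

import torch
import torch.nn as nn

# Sketch assuming PyTorch >= 1.11, where average_attn_weights was added.
embed_dim, num_heads, seq_len, bsz = 32, 4, 10, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)  # expects (seq_len, batch, embed_dim) by default

x = torch.randn(seq_len, bsz, embed_dim)
out, attn = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(attn.shape)  # torch.Size([2, 4, 10, 10]), i.e. one attention matrix per head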

The relevant code in the CrabNet repository is as follows:

Encoder __init__.py

Encoder forward

Plotting using hook
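For reference, the hook-based capture in that last link amounts to roughly the sketch below (placeholder dimensions and a toy encoder, not the actual CrabNet code). On v1.6/v1.7 each TransformerEncoderLayer calls its self_attn with need_weights left at True, so the second element of the hooked output is the attention matrix: head-averaged on stock PyTorch, per-head once the edit above is applied. (Newer releases may pass need_weights=False or take a fused fast path, so this is version-specific.)

import torch
import torch.nn as nn

attn_maps = []  # one tensor per encoder layer per forward pass

def save_attn(module, inputs, output):
    # nn.MultiheadAttention returns (attn_output, attn_output_weights)
    attn_maps.append(output[1].detach().cpu())

# Toy stand-in for the real encoder; hooks go on each layer's self_attn module.
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=32, nhead=4), num_layers=3)
hooks = [layer.self_attn.register_forward_hook(save_attn) for layer in encoder.layers]

_ = encoder(torch.randn(10, 2, 32))  # (seq_len, batch, d_model)
print(len(attn_maps), attn_maps[0].shape)  # 3 layers; (2, 10, 10) averaged, or (2, 4, 10, 10) per-head

for h in hooks:
    h.remove()  # detach the hooks when done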

The hacky workaround that I'm trying to remove is the editing of the PyTorch source code. As long as users can still install via pip or conda and don't have to worry about the change mentioned above (i.e. editing the source code or some other hacky workaround), I'm game; some of the intended audience will have limited programming experience. If no easy solution exists as-is, then I can move this into a GitHub issue as a feature request. See [FYI] MultiheadAttention / Transformer · Issue #32590 · pytorch/pytorch · GitHub
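One direction that would keep pip/conda installs untouched, assuming it's acceptable to require a newer PyTorch (>= 1.11 or so, for the average_attn_weights argument mentioned above), would be a thin TransformerEncoderLayer subclass that stores the per-head weights itself. A rough sketch with a hypothetical class name, assuming the standard post-norm layer:

import torch
import torch.nn as nn

class AttnSavingEncoderLayer(nn.TransformerEncoderLayer):
    # Re-implements the standard post-norm forward so the per-head attention
    # can be kept on the module; assumes PyTorch >= 1.11 for average_attn_weights.
    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        # is_causal is accepted only so newer nn.TransformerEncoder versions can call this layer; it is ignored here.
        attn_out, attn_weights = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=True,
            average_attn_weights=False,  # keep one matrix per head
        )
        self.last_attn = attn_weights.detach()  # (bsz, nhead, seq_len, seq_len)
        src = self.norm1(src + self.dropout1(attn_out))
        src = self.norm2(src + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(src))))))
        return src

layer = AttnSavingEncoderLayer(d_model=32, nhead=4)
encoder = nn.TransformerEncoder(layer, num_layers=3)  # nn.TransformerEncoder deep-copies the layer
_ = encoder(torch.randn(10, 2, 32))
per_layer_attn = [l.last_attn for l in encoder.layers]  # one (2, 4, 10, 10) tensor per layer

Because the layers are deep copies, each one keeps its own last_attn, so plotting class methods could read them directly rather than going through hooks or the source edit.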

Happy to hear any other general suggestions or criticisms (e.g. you’re going about this the wrong way). My experience with transformers is mostly limited to CrabNet, and I haven’t had much success with PyTorch hooks in the past. I’d be happy to use a different approach/module etc. as long as the basic structure and functionality (i.e. transformer architecture) can be retained.


I just wanna say your post has helped me a lot! Thanks!