Maintainership of the CrabNet repository is being transferred to me, and I'm very interested in removing the hacky workaround described on the repository page (GitHub - anthony-wang/CrabNet: Predict materials properties using only the composition information!). That way I can keep the package user-friendly and include plotting functions as class methods without asking people to edit the PyTorch source code.
To properly export the attention heads from the PyTorch `nn.MultiheadAttention` implementation within the transformer encoder layer, you will need to manually modify some of the source code of the PyTorch library.
This applies to PyTorch v1.6.0, v1.7.0, and v1.7.1 (and potentially to other, untested versions as well). For this, open the file:

`C:\Users\{USERNAME}\Anaconda3\envs\{ENVIRONMENT}\Lib\site-packages\torch\nn\functional.py`

(where `USERNAME` is your Windows user name and `ENVIRONMENT` is your conda environment name; if you followed the steps above, it should be `crabnet`). At the end of the function definition of
`multi_head_attention_forward` (line numbers may differ slightly):

```python
def multi_head_attention_forward(  # line 4011
    # ...
    # ... [some lines omitted]
    # ...
    if need_weights:  # line 4291
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads  # line 4294
    else:
        return attn_output, None  # line 4296
```
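The Windows/Anaconda path given above is only one possible location. A small addition (not part of the CrabNet instructions) for finding the `functional.py` that the currently active environment actually imports is to ask Python for the module's `__file__`:

```python
import os

import torch.nn.functional as F

# Location of the functional.py module imported by the active environment.
functional_path = F.__file__
print(functional_path)
print(os.path.isfile(functional_path))  # True
```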
Change the line `return attn_output, attn_output_weights.sum(dim=1) / num_heads` to `return attn_output, attn_output_weights`.
This prevents the attention weights from being averaged over all heads and instead returns each head's attention matrix individually.
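For comparison: in newer PyTorch releases (1.11 and later, if I remember correctly), `nn.MultiheadAttention.forward` accepts an `average_attn_weights` argument that returns the same per-head matrices without editing `functional.py`. The sketch below assumes such a version and the default `batch_first=False` layout; the sizes are arbitrary. It checks that the default (averaged) output is just the mean of the per-head weights, which is exactly what the `.sum(dim=1) / num_heads` line computes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, seq_len, batch = 8, 2, 5, 3

# Default layout (batch_first=False): inputs are (seq_len, batch, embed_dim).
mha = nn.MultiheadAttention(embed_dim, num_heads)
mha.eval()
x = torch.randn(seq_len, batch, embed_dim)

with torch.no_grad():
    # Default: weights averaged over heads -> (batch, seq_len, seq_len)
    _, avg_w = mha(x, x, x, need_weights=True)
    # Per-head weights -> (batch, num_heads, seq_len, seq_len)
    _, per_head_w = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(avg_w.shape)       # torch.Size([3, 5, 5])
print(per_head_w.shape)  # torch.Size([3, 2, 5, 5])
# Averaging the per-head maps over the head dimension reproduces the default.
print(torch.allclose(per_head_w.mean(dim=1), avg_w))  # True
```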
For more information see:
The relevant code in the CrabNet repository is as follows:

- Encoder `__init__.py`
- Encoder forward
- Plotting using hook
The hacky workaround that I'm trying to remove is the editing of the PyTorch source code. As long as users can still install via `pip` or `conda` and don't have to worry about the change mentioned above (i.e. editing the source code or some other hacky workaround), I'm game; some of the intended audience will have limited programming experience. If no easy solution exists as-is, I can move this into a GitHub issue as a feature request. See [FYI] MultiheadAttention / Transformer · Issue #32590 · pytorch/pytorch · GitHub.
Happy to hear any other general suggestions or criticisms (e.g. you’re going about this the wrong way). My experience with transformers is mostly limited to CrabNet, and I haven’t had much success with PyTorch hooks in the past. I’d be happy to use a different approach/module etc. as long as the basic structure and functionality (i.e. transformer architecture) can be retained.
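One possible direction for getting per-head weights out of a stock `nn.TransformerEncoder` without touching the installed package, sketched under assumptions: PyTorch >= 1.11 (for `average_attn_weights`), the default `batch_first=False`, and `capture_attention` is a hypothetical helper name, not a PyTorch or CrabNet API. It wraps each layer's `self_attn.forward` to override the `need_weights=False` that the encoder layer passes internally, and a forward hook stores the resulting weights.

```python
import torch
import torch.nn as nn


def capture_attention(encoder: nn.TransformerEncoder) -> list:
    """Hypothetical helper (not part of PyTorch or CrabNet): make every layer's
    self-attention return per-head weights and collect them during forward."""
    attn_maps = []

    def save_weights(module, inputs, outputs):
        # outputs is (attn_output, attn_weights); weights are per head here.
        attn_maps.append(outputs[1].detach())

    for layer in encoder.layers:
        attn = layer.self_attn
        original_forward = attn.forward  # bound method of this instance

        def forward_with_weights(*args, _orig=original_forward, **kwargs):
            # The encoder layer passes need_weights=False; override it here
            # instead of editing torch/nn/functional.py.
            kwargs["need_weights"] = True
            kwargs["average_attn_weights"] = False  # assumes PyTorch >= 1.11
            return _orig(*args, **kwargs)

        # The instance attribute shadows the class method for this module only.
        attn.forward = forward_with_weights
        attn.register_forward_hook(save_weights)

    return attn_maps


torch.manual_seed(0)
# batch_first=False (the default): inputs are (seq_len, batch, d_model).
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()

attn_maps = capture_attention(encoder)
src = torch.randn(5, 3, 8)
with torch.no_grad():
    encoder(src)

print(len(attn_maps))      # 2 (one entry per layer)
print(attn_maps[0].shape)  # torch.Size([3, 2, 5, 5])
```

One caveat, as far as I understand it: from PyTorch 1.12 on, `TransformerEncoderLayer` can take an inference fast path (e.g. in eval mode with `batch_first=True`) that skips `self_attn` entirely, so the hook would never fire; keeping `batch_first=False`, as above, avoids that. The list also keeps growing across calls, so clear it (`attn_maps.clear()`) before each forward pass you want to plot.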