How do I obtain multiple heads via `multi_head_attention_forward` when using `nn.TransformerEncoder`?

Maintainership of the CrabNet repository is being transferred to me, and I'm very interested in removing the hacky workaround described in
GitHub - anthony-wang/CrabNet: Predict materials properties using only the composition information!, so that I can keep the package user-friendly and include plotting functions as class methods without asking people to edit the PyTorch source code.

To properly export the attention heads from the PyTorch nn.MultiheadAttention implementation within the transformer encoder layer, you will need to manually modify some of the source code of the PyTorch library.
This applies to PyTorch v1.6.0, v1.7.0, and v1.7.1 (potentially to other untested versions as well).

For this, open the file:
C:\Users\{USERNAME}\Anaconda3\envs\{ENVIRONMENT}\Lib\site-packages\torch\nn\functional.py
(where USERNAME is your Windows user name and ENVIRONMENT is your conda environment name; if you followed the steps above, it should be crabnet)
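
If you cannot find the file at that path (e.g. on Linux/macOS, or with a different environment layout), a quick way to locate it is to ask Python directly; this is just a convenience sketch, not part of the original instructions:

import torch.nn.functional
print(torch.nn.functional.__file__)  # prints the full path to the functional.py you need to edit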

At the end of the function definition of multi_head_attention_forward (line numbers may differ slightly):

L4011 def multi_head_attention_forward(
# ...
# ... [some lines omitted]
# ...
L4291    if need_weights:
L4292        # average attention weights over heads
L4293        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
L4294        return attn_output, attn_output_weights.sum(dim=1) / num_heads
L4295    else:
L4296        return attn_output, None

Change the specific line

return attn_output, attn_output_weights.sum(dim=1) / num_heads

to:

return attn_output, attn_output_weights

This prevents the attention weights from being averaged over all heads, and instead returns each head's attention matrix individually.
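
As a quick sanity check (my own sketch, assuming PyTorch 1.7 with the edit above applied), a standalone nn.MultiheadAttention call should now return weights with an explicit head dimension:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=32, num_heads=4)
x = torch.randn(5, 2, 32)  # (seq_len, batch, embed_dim)
_, attn_weights = mha(x, x, x, need_weights=True)
print(attn_weights.shape)  # (2, 4, 5, 5) after the edit; the unmodified code averages to (2, 5, 5)
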
For more information see:

The relevant code in the CrabNet repository is as follows:

Encoder __init__.py

Encoder forward

Plotting using hook

The hacky workaround that I'm trying to remove is the editing of the PyTorch source code. As long as users can still install via pip or conda, and they don't have to worry about the change mentioned above (i.e. editing the source code, or some other hacky workaround), I'm game (some of the intended audience will have limited programming experience). If no easy solution exists as-is, then I can move this into a GitHub issue as a feature request. See [FYI] MultiheadAttention / Transformer · Issue #32590 · pytorch/pytorch · GitHub

Happy to hear any other general suggestions or criticisms (e.g. you’re going about this the wrong way). My experience with transformers is mostly limited to CrabNet, and I haven’t had much success with PyTorch hooks in the past. I’d be happy to use a different approach/module etc. as long as the basic structure and functionality (i.e. transformer architecture) can be retained.


I just wanna say your post has helped me a lot! Thanks!


Hah, I was thinking through this exact problem all day and just saw your forum post! I think part of the problem lies in the way that TransformerEncoderLayer is written. In the self-attention block, the need_weights argument is hardcoded to False:

x = self.self_attn(x, x, x,
                   attn_mask=attn_mask,
                   key_padding_mask=key_padding_mask,
                   need_weights=False)[0]

For newer versions of PyTorch, the MultiheadAttention module has a flag in the forward pass that lets you turn off weight averaging (average_attn_weights: bool = True). However, this doesn't matter if you're using TransformerEncoderLayer, because the hardcoded need_weights=False stops any weights from being returned. If it weren't hardcoded, you could have overridden the defaults. For example:

# the four defaults are: 
# key_padding_mask: Optional[Tensor] = None
# need_weights: bool = True
# attn_mask: Optional[Tensor] = None
# average_attn_weights: bool = True
nn.MultiheadAttention.forward.__defaults__ = (None, True, None, False)

Two possible solutions. One is to write your own versions of the TransformerEncoderLayer and MultiheadAttention classes (I sketch what that might look like below). I think this is a painful option with the added drawback of missing out on future improvements to these two classes. Option two is to monkey patch the multi_head_attention_forward function in functional.py. This is not elegant, but it spares users with limited programming experience from having to manually edit the PyTorch source code.
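
For completeness, option one might look roughly like the sketch below. It assumes a recent PyTorch (roughly 1.12+) where TransformerEncoderLayer routes self-attention through a private _sa_block helper and MultiheadAttention accepts average_attn_weights; the class and the last_attn_weights attribute are my own invention, and newer releases also have a fused fast path that can bypass this method in inference mode, so treat it as a starting point rather than a drop-in replacement:

from typing import Optional

import torch.nn as nn
from torch import Tensor


class AttnSavingEncoderLayer(nn.TransformerEncoderLayer):
    """TransformerEncoderLayer that keeps the per-head attention weights around."""

    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor],
                  key_padding_mask: Optional[Tensor],
                  **kwargs) -> Tensor:  # **kwargs absorbs is_causal on newer versions
        x, attn_weights = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True,           # not hardcoded to False here
            average_attn_weights=False,  # keep one attention map per head
        )
        self.last_attn_weights = attn_weights  # (batch, num_heads, tgt_len, src_len)
        return self.dropout1(x)

You would then build the encoder from this layer, e.g. nn.TransformerEncoder(AttnSavingEncoderLayer(d_model=512, nhead=4), num_layers=3), and read each layer's last_attn_weights after a forward pass.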

The monkey patch solution. In the imports for train_crabnet.py, import the inspect package.

import inspect

Then add these four lines anywhere after importing Model, but before registering your first hook:

source = inspect.getsource(torch.nn.functional.multi_head_attention_forward)
new_source = source.replace('attn_output_weights.sum(dim=1) / num_heads', 'attn_output_weights')
new_source = new_source.replace('return attn_output, None', 'return attn_output, attn_output_weights')
exec(new_source, torch.nn.functional.__dict__)
  • Line 1 grabs the source code for the multi_head_attention_forward function in functional.py.
  • Line 2 finds the line where attention head averaging occurs and replaces the averaged weights with the raw per-head weights.
  • Line 3 finds the return line in the need_weights=False branch (which is a separate return statement in later versions of PyTorch) and replaces it with a return that includes the attention weights. Note: this line might not be necessary in earlier versions of PyTorch, but it is necessary in the version I'm currently using (1.12.0). It should be harmless if it isn't needed.
  • Line 4 executes the modified source in torch.nn.functional's namespace, overwriting the original function.
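
For reference, the kind of forward hook being assumed below might look something like this sketch (the SaveOutput class and the toy model are my own guesses based on how save_output is used; CrabNet's actual helper may differ):

import torch
import torch.nn as nn


class SaveOutput:
    """Forward hook that stores whatever each hooked module returns."""

    def __init__(self):
        self.output = []

    def __call__(self, module, module_in, module_out):
        # For nn.MultiheadAttention, module_out is the (attn_output, attn_output_weights) tuple
        self.output.append(module_out)

    def clear(self):
        self.output = []


# Hook every MultiheadAttention module; here a toy encoder stands in for
# the CrabNet model you would normally build from Model.
model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=32, nhead=4), num_layers=2)
save_output = SaveOutput()
hook_handles = [m.register_forward_hook(save_output)
                for m in model.modules() if isinstance(m, nn.MultiheadAttention)]

model(torch.randn(5, 2, 32))  # each hooked attention layer appends one tuple per forward pass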

After making these changes, your save_output.output will be a list with length equal to the number of batches. Each entry in the list (e.g. save_output.output[0]) is a tuple of length 2, whose second element contains your non-averaged attention weights. For example, I currently have 4 heads, a batch size of 1024, and 5 features:

[screenshot of the tensor shape: torch.Size([4096, 5, 5])]

You can see that the first dimension is 4096 (batch_size * num_heads), while the second and third dimensions form the square attention weight map (5 x 5).
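
Because the patched need_weights=False branch returns the weights before PyTorch's own reshape, the head dimension is folded into the batch dimension; here is a small sketch (assuming the shapes quoted above) to split it back out:

batch_size, num_heads, seq_len = 1024, 4, 5
attn = save_output.output[0][1]                               # (batch_size * num_heads, seq_len, seq_len) = (4096, 5, 5)
attn = attn.reshape(batch_size, num_heads, seq_len, seq_len)  # same unfolding PyTorch itself uses
head_0_maps = attn[:, 0]                                      # (1024, 5, 5): attention maps for the first head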

With the monkey patch commented out, the TransformerEncoderLayer ensures that None is returned in the second entry of the tuple.
