MultiheadAttention: attention weights are not being used in gradient graph

Greetings,
during some testing with MultiheadAttention, I needed to compute gradients with respect to the attention weights (or scores), but I ran into a problem.

My goal was to obtain the gradients of the attention weights used during the attention operation.
I have a MultiheadAttention layer, and I run the forward pass with need_weights=True and average_attn_weights=True. The expected result is that this returns both the output of the attention operation and the attention scores (weights) used during the calculation, averaged over all heads.
The problem is that, when I try to compute the gradient of the final output with respect to the attention scores, I get an error stating that One of the differentiated Tensors appears to not have been used in the graph.
Unless I am mistaken, this should not happen, since the attention scores are used to weight the Value matrix, which in turn produces the output.
The following is a dummy example replicating my problem:

import torch
import torch.nn as nn

# batch size 1, sequence length 4, embedding dimension 100
a = torch.randn(1, 4, 100)

mh_layer = nn.MultiheadAttention(100, 5, batch_first=True)

out, weights = mh_layer(a, a, a, need_weights=True, average_attn_weights=True)

# dummy loss
loss = torch.sum(out)
mh_layer.zero_grad()

# gradient calculation
grad = torch.autograd.grad(loss, [weights], retain_graph=True)[0].detach()

The last line fails with:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
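
For completeness, setting allow_unused=True does not raise, but it simply returns None for the weights, which confirms that the returned tensor is not part of the graph that produces out:

grad = torch.autograd.grad(loss, [weights], retain_graph=True, allow_unused=True)
print(grad)  # (None,)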

I looked at the source code and tracked the problem down to a .view operation in torch.nn.functional.multi_head_attention_forward.
I managed to “fix” my problem by changing the function to return the attention weights from before the view operation. After this change, the code above runs without errors and I get the gradient of the attention weights. Below is the snippet I modified (only three lines changed).

"""
torch.nn.functional.multi_head_attention_forward
"""
# [...]
# optionally average attention weights over heads
original_attn_output_weights = attn_output_weights    # modified code

# the following view creates a new tensor that is not on the path to attn_output,
# so gradients of the output never reach the returned attn_output_weights
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
    attn_output_weights = attn_output_weights.mean(dim=1)

if not is_batched:
    # squeeze the output if input was unbatched
    attn_output = attn_output.squeeze(1)
    attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, original_attn_output_weights    # modified code
# return attn_output, attn_output_weights    # original code

Of course, this fix skips the reshaping (and optional averaging) of the returned weights, so it’s not ideal.
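For reference, with the modified function in place, the missing reshape can be done on the caller side. A quick sketch, reusing the shapes from the dummy example above (bsz=1, num_heads=5, tgt_len=src_len=4); the weights now come back un-averaged:

out, weights = mh_layer(a, a, a, need_weights=True, average_attn_weights=True)
loss = torch.sum(out)

# gradient w.r.t. the per-head attention weights, shape (bsz * num_heads, tgt_len, src_len)
grad = torch.autograd.grad(loss, [weights], retain_graph=True)[0].detach()

# reshape to (bsz, num_heads, tgt_len, src_len), mirroring the view the library skips
grad_per_head = grad.view(1, 5, 4, 4)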
Finally, my question is: was I using the gradient operations and/or the MultiheadAttention parameters incorrectly? Is there an easier way that does not require changing the source code?

Thank you in advance!


Your explanation makes sense, and I also think the .view() operation “diverges” from the computation graph: the returned weights are the new viewed tensor, while the original tensor is the one actually used downstream, which is why the error is raised.
CC @albanD, are you aware of a proper approach that does not involve manipulating the source code?

This is indeed expected, I’m afraid: the result of the view is not part of the computation in the “main branch”, and thus you don’t get any gradient for it.
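
Here is a minimal toy example of the same behaviour, outside of MultiheadAttention:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2              # y feeds the "main branch"
y_view = y.view(1, 3)  # new tensor returned to the caller, never used downstream
out = y.sum()          # out depends on y, not on y_view

# raises: One of the differentiated Tensors appears to not have been used in the graph.
torch.autograd.grad(out, y_view)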

I’m afraid there aren’t any great options here:

  • You can indeed use a modified version of MHA.
  • You can propose the change upstream so that MHA itself returns the weights in a way that keeps them in the graph.
  • You could “walk” the autograd graph up from the view back to the node that is still in the main graph and get the gradient for that. But this will be more brittle and a bit more involved (see the sketch below).
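
A rough, version-dependent sketch of the third option: it assumes the Python slow path runs (so a SoftmaxBackward0 node shows up in the history of the returned weights) and relies on Node.register_hook, which is available in recent PyTorch releases.

import torch
import torch.nn as nn

mh_layer = nn.MultiheadAttention(100, 5, batch_first=True)
a = torch.randn(1, 4, 100)
out, weights = mh_layer(a, a, a, need_weights=True, average_attn_weights=True)

# walk back from the (disconnected) viewed weights to a node that is shared
# with the main branch -- here, the softmax that produced the raw weights
node = weights.grad_fn
while node is not None and "Softmax" not in type(node).__name__:
    node = node.next_functions[0][0] if node.next_functions else None

captured = {}

def hook(grad_inputs, grad_outputs):
    # grad_outputs[0] is the gradient of the loss w.r.t. the softmax output,
    # i.e. w.r.t. the per-head attention weights
    captured["attn_weights_grad"] = grad_outputs[0]

if node is not None:
    node.register_hook(hook)

loss = out.sum()
loss.backward()
print(captured["attn_weights_grad"].shape)  # (bsz * num_heads, tgt_len, src_len)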

Thank you both for the help in assessing the situation and for the insight into possible solutions. I think the most practical solution for me is to override the MultiheadAttention forward function so that it calls my custom method, which avoids changing the installed source code.
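
In case it is useful to anyone else, this is roughly the shape of the override I have in mind. patched_multi_head_attention_forward is a hypothetical local copy of torch.nn.functional.multi_head_attention_forward with the three-line change from above; the sketch assumes the default packed in_proj_weight case (same embed_dim for query, key and value) and handles batch_first itself, since the stock forward normally does that before calling the functional.

import torch.nn as nn

# hypothetical local copy of F.multi_head_attention_forward that returns
# the pre-view attention weights (the three-line change shown above)
from patched_functional import patched_multi_head_attention_forward


class MHAWithWeightGrads(nn.MultiheadAttention):
    """Same parameters/state as nn.MultiheadAttention, but forward() routes
    through the locally patched functional."""

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None, average_attn_weights=True):
        # the stock forward transposes to (seq, batch, embed) for batch_first
        # before calling the functional, so do the same here
        if self.batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        attn_output, attn_weights = patched_multi_head_attention_forward(
            query, key, value,
            self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
        )

        if self.batch_first:
            attn_output = attn_output.transpose(0, 1)
        return attn_output, attn_weights

An instance created as MHAWithWeightGrads(100, 5, batch_first=True) should then work as a drop-in replacement for mh_layer in the dummy example.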