Greetings,
during some testing with MultiheadAttention I needed to compute gradients with respect to the attention weights (or scores), but I ran into a problem.
My goal was to obtain the gradients of the attention weights used during the attention operation.
I have a MultiheadAttention layer, and I perform the forward pass with need_weights=True and average_attn_weights=True. The expected result is that this returns both the output of the attention operation and the attention scores (weights) used during the calculation, averaged over all heads.
The problem is that, when I try to compute the gradient of the final output with respect to the attention scores, I get an error stating that one of the differentiated Tensors appears to not have been used in the graph.
Unless I am mistaken, this should not be the case since the attention scores are used to weight the Value matrix, which is used for computing the output.
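As an illustration of what I expected, in a bare-bones scaled dot-product attention written with plain tensors (just a sketch, not my actual code) this gradient is well-defined:
import torch

d = 16
q = torch.randn(4, d, requires_grad=True)
k = torch.randn(4, d, requires_grad=True)
v = torch.randn(4, d, requires_grad=True)

scores = torch.softmax(q @ k.T / d ** 0.5, dim=-1)  # attention weights
out = scores @ v                                    # the weights multiply the values

grad = torch.autograd.grad(out.sum(), [scores])[0]  # works, shape (4, 4)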
The following is a dummy example replicating my problem:
import torch
import torch.nn as nn

a = torch.randn(1, 4, 100)
mh_layer = nn.MultiheadAttention(100, 5, batch_first=True)
out, weights = mh_layer(a, a, a, need_weights=True, average_attn_weights=True)
# dummy loss
loss = torch.sum(out)
mh_layer.zero_grad()
# gradient calculation
grad = torch.autograd.grad(loss, [weights], retain_graph=True)[0].detach()
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
I looked at the source code and tracked the problem down to a .view operation in the torch.nn.functional.multi_head_attention_forward function.
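If I understand the autograd semantics correctly, the issue is not the view itself but that attn_output is computed from the pre-view weights, so the reshaped (and averaged) copy that gets returned is never an ancestor of the output. A tiny analogue of the failure (the names here are just for illustration):
import torch

x = torch.randn(3, requires_grad=True)
scores = torch.softmax(x, dim=0)
out = (scores * 2).sum()        # out depends on `scores` ...
returned = scores.view(1, 3)    # ... but not on this reshaped copy

# raises the same RuntimeError: `returned` was never used to compute `out`
torch.autograd.grad(out, [returned])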
I managed to “fix” my problem by changing the function's behaviour to return the attention weights from before the view operation. After this change, the code above executes without errors and I get the gradient of the attention weights. Below is the snippet I modified (I only changed 3 lines).
"""
torch.nn.functional.multi_head_attention_forward
"""
# [...]
# optionally average attention weights over heads
original_attn_output_weights = attn_output_weights  # modified code
# the following view operation results in attn_output_weights disconnected from the gradient graph
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
    attn_output_weights = attn_output_weights.mean(dim=1)

if not is_batched:
    # squeeze the output if input was unbatched
    attn_output = attn_output.squeeze(1)
    attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, original_attn_output_weights  # modified code
# return attn_output, attn_output_weights  # original code
Of course, this fix returns the weights without the final reshape and averaging, so it’s not ideal.
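As a possible alternative that avoids touching the library source, one could also redo the forward pass manually with the layer's own parameters, so that the weights being differentiated are exactly the ones multiplied with V. Below is only a rough sketch for my setting (self-attention, no masks, no dropout, batch_first=True); I have not verified that it matches the built-in forward exactly:
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, num_heads = 100, 5
head_dim = embed_dim // num_heads

a = torch.randn(1, 4, embed_dim)
mh_layer = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
bsz, seq_len, _ = a.shape

# in-projection using the layer's packed weights (self-attention case)
q, k, v = F.linear(a, mh_layer.in_proj_weight, mh_layer.in_proj_bias).chunk(3, dim=-1)

# split heads: (bsz, seq_len, embed_dim) -> (bsz * num_heads, seq_len, head_dim)
def split_heads(t):
    return t.reshape(bsz, seq_len, num_heads, head_dim).transpose(1, 2).reshape(bsz * num_heads, seq_len, head_dim)

q, k, v = split_heads(q), split_heads(k), split_heads(v)

# these weights stay on the path to the output, so they can be differentiated
weights = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)

attn = weights @ v
attn = attn.reshape(bsz, num_heads, seq_len, head_dim).transpose(1, 2).reshape(bsz, seq_len, embed_dim)
out = mh_layer.out_proj(attn)

loss = torch.sum(out)
grad = torch.autograd.grad(loss, [weights], retain_graph=True)[0]  # (bsz * num_heads, seq_len, seq_len)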
Finally, my question is: was I using the gradient operations and/or the MultiheadAttention parameters incorrectly? Is there an easier way that does not require changing the source code?
Thank you in advance!