Masked Query Gradient Flow to Keys and Values

Hi,
I was wondering why the gradient in this example does not flow to the key and value. What am I doing wrong? How can I use padded batches with different target sequence lengths?

import torch
from torch.nn.functional import scaled_dot_product_attention

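q = torch.rand(3, 5, 8, requires_grad=True)  # (batch, target length, embed)
k = torch.rand(3, 4, 8, requires_grad=True)  # (batch, source length, embed)
v = torch.rand(3, 4, 8, requires_grad=True)  # separate tensor from k, so k.grad and v.grad stay distinct

mask = torch.ones(3, 5, 4, dtype=torch.bool)  # True = this query may attend to this key
mask[:, :, -1] = False  # last key/value position is padding
mask[:, -1, :] = False  # last query is padding (this row is now fully masked)

out = scaled_dot_product_attention(q, k, v, attn_mask=mask)
torch.mean(out[:, :-1, :]).backward()  # the padded last query is excluded from the loss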

Thank you

To take the example of v: the gradient of v is softmax(q @ k.T / sqrt(d) + bias).T @ grad_output.
softmax(q @ k.T / sqrt(d) + bias) has a row that is entirely NaN, because you masked out an entire row (a boolean False acts as a bias of -inf, and a softmax over a row of only -inf is NaN). Its transpose therefore has an entire column of NaNs.

grad_output has an entire row of zeros, because that row of out is sliced off (out[:, :-1, :]) before the loss is computed.
You may have expected the row of zeros on the right-hand side to cancel out the column of NaNs on the left-hand side, since they are the last row and the last column respectively.
However, that is not how autograd works today: zero times NaN is still NaN. Every column of the gradient of v is a linear combination of the columns of the left-hand side, so the NaN column enters each of them (with coefficient 0), and the entire gradient of v becomes NaN.
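The argument above can be reproduced by hand for a single batch element. This is a minimal sketch that assumes the default SDPA scale of 1/sqrt(d) and builds the additive -inf bias from the boolean mask explicitly; grad_out is written out manually to mirror the gradient of the mean in the question, it is not taken from autograd.

```python
import torch

torch.manual_seed(0)
L_q, L_k, d = 5, 4, 8
q = torch.rand(L_q, d)
k = torch.rand(L_k, d)

# Boolean mask as in the question: True = may attend.
mask = torch.ones(L_q, L_k, dtype=torch.bool)
mask[:, -1] = False  # last key is padding
mask[-1, :] = False  # last query is padding -> fully masked row

# Equivalent additive bias: 0 where allowed, -inf where masked.
bias = torch.zeros(L_q, L_k).masked_fill(~mask, float("-inf"))
P = torch.softmax(q @ k.T / d ** 0.5 + bias, dim=-1)  # last row: all -inf -> NaN
print(torch.isnan(P[-1]).all())  # tensor(True)

# Upstream gradient of mean(out[:-1]): the sliced-off last row only gets zeros.
grad_out = torch.full((L_q, d), 1.0 / ((L_q - 1) * d))
grad_out[-1] = 0.0

# The NaN column of P.T is multiplied by 0, which is still NaN,
# and it is summed into every entry of the result.
grad_v = P.T @ grad_out
print(torch.isnan(grad_v).all())  # tensor(True)
```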

One workaround is to pass a floating-point mask instead of a boolean one, and to use a large negative number instead of -inf for masked positions, so that softmax no longer produces NaNs.
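A minimal sketch of that workaround, reusing the shapes from the question. The value -1e9 is an arbitrary choice for "large negative number" (an assumption, not a PyTorch constant); it only needs to be negative enough that exp() of it underflows to zero next to the real scores.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

torch.manual_seed(0)
q = torch.rand(3, 5, 8, requires_grad=True)
k = torch.rand(3, 4, 8, requires_grad=True)
v = torch.rand(3, 4, 8, requires_grad=True)

keep = torch.ones(3, 5, 4, dtype=torch.bool)
keep[:, :, -1] = False  # padded key
keep[:, -1, :] = False  # padded query

# A float mask is added to the attention scores: 0 keeps a position,
# a large finite negative value suppresses it without creating NaNs.
float_mask = torch.zeros(keep.shape).masked_fill(~keep, -1e9)

out = scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
torch.mean(out[:, :-1, :]).backward()

for name, t in (("q", q), ("k", k), ("v", v)):
    print(name, torch.isfinite(t.grad).all().item())
```

The padded query now attends effectively uniformly to all keys, but since its output row is sliced off before the loss, it contributes only zeros to the gradients.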

Would it be a good idea to fill the mask with ones for the padded queries instead? The forward pass is computed for them anyway, and their outputs do not influence the result, since I exclude them in the loss function anyway.

Hm, I'd expect this to affect the softmax normalization: if you had a large outlier, normalizing it together with the rest of your row would scale your outputs down.
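One way to check this empirically for the unpadded queries is to run the question's example with both masks and compare only the rows that enter the loss. A minimal sketch; masked_rows and filled_rows are hypothetical names for the question's mask and for the variant where the padded query's row is set back to True.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

torch.manual_seed(0)
q = torch.rand(3, 5, 8)
k = torch.rand(3, 4, 8)
v = torch.rand(3, 4, 8)

# Mask from the question: padded last key, padded last query fully masked.
masked_rows = torch.ones(3, 5, 4, dtype=torch.bool)
masked_rows[:, :, -1] = False
masked_rows[:, -1, :] = False

# Variant: fill the padded query's row with ones (it attends to every key).
filled_rows = masked_rows.clone()
filled_rows[:, -1, :] = True

out_masked = scaled_dot_product_attention(q, k, v, attn_mask=masked_rows)
out_filled = scaled_dot_product_attention(q, k, v, attn_mask=filled_rows)

# Compare only the unpadded query rows, i.e. the ones used in the loss.
print(torch.allclose(out_masked[:, :-1], out_filled[:, :-1]))
```

Softmax normalizes over the keys of each query row separately, so changing the padded query's row should only change that row's output.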

Okay, I think I understand what I got wrong. In this context, the attention mask only specifies which key/value pairs each query may attend to. By masking all key/value pairs for the queries I want to ignore (the padded ones), I create a fully masked row, which turns into NaNs in the softmax and destroys the gradient: 0 times NaN is still NaN, and that NaN reaches both the key and the value gradients. Leaving those rows unmasked and instead ignoring the padded queries when computing the loss should fix this, as their outputs then do not influence the loss and therefore not the gradients. Thank you!
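Applied to padded batches with different target lengths, that approach might look like the following sketch. The per-sample lengths q_lens and k_lens are hypothetical; the attention mask only hides padded keys, and padded queries are dropped in the loss via boolean indexing.

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

torch.manual_seed(0)
B, L_q, L_k, E = 3, 5, 4, 8
q_lens = torch.tensor([5, 4, 3])  # hypothetical per-sample target lengths
k_lens = torch.tensor([4, 3, 4])  # hypothetical per-sample key/value lengths

q = torch.rand(B, L_q, E, requires_grad=True)
k = torch.rand(B, L_k, E, requires_grad=True)
v = torch.rand(B, L_k, E, requires_grad=True)

# Mask only padded *keys*; every query row keeps at least one valid key,
# so no softmax row is fully masked.
key_valid = torch.arange(L_k)[None, :] < k_lens[:, None]  # (B, L_k)
attn_mask = key_valid[:, None, :].repeat(1, L_q, 1)  # (B, L_q, L_k)

out = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Padded *queries* are dropped in the loss instead of in the attention mask.
query_valid = torch.arange(L_q)[None, :] < q_lens[:, None]  # (B, L_q)
loss = out[query_valid].mean()  # out[query_valid] has shape (num_valid, E)
loss.backward()
```

This assumes every sample has at least one valid key (k_lens >= 1); a sample without any valid key would bring back a fully masked row.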