Masking queries in torch.nn.MultiheadAttention

Hi, I am trying to use torch.nn.MultiheadAttention for the following use case:

I have documents of |Q| queries and sentences of length |K| (here, K == V). I would like each query to attend to all keys, and ultimately I will combine the |Q| context vectors. If I am batching these inputs, I understand that I can pass key_padding_mask of shape |B| x |K|, where |B| is my batch size, to mask the appropriate keys in each batch. I also understand that I can use attn_mask of shape |Q| x |K| to mask which keys a given output (in my case I want |Q| outputs) attends to.
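For concreteness, here is a minimal sketch of the key-padding case described above (all sizes are made up for illustration; nn.MultiheadAttention without batch_first expects inputs shaped (seq_len, batch, embed_dim)):

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration
B, Q, K, E, H = 2, 4, 6, 8, 2  # batch, queries, keys, embed dim, heads

mha = nn.MultiheadAttention(embed_dim=E, num_heads=H)

query = torch.randn(Q, B, E)        # (|Q|, |B|, E)
key = value = torch.randn(K, B, E)  # (|K|, |B|, E)

# key_padding_mask: (|B|, |K|); True marks key positions to ignore
key_padding_mask = torch.zeros(B, K, dtype=torch.bool)
key_padding_mask[0, 4:] = True  # e.g. first sentence has only 4 real tokens

out, attn_weights = mha(query, key, value, key_padding_mask=key_padding_mask)
# out:          (|Q|, |B|, E)
# attn_weights: (|B|, |Q|, |K|), averaged over heads;
# masked key positions receive zero attention weight
```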

However, because |Q| varies per document, how can I mask it batchwise? Can I simply pass attn_mask of shape |B| x |Q| x |K| and get the desired effect? Essentially, certain entire distributions over |K| are invalid because that query is masked. I guess I can just zero out those vectors after the computation? Is my understanding correct, or is the implementation of MultiheadAttention not valid for my use case?

It’s actually pretty surprising to me that you cannot mask the query with this function. Am I not using it as intended?

Never mind, this does exist; I was just getting an error because the first dimension of attn_mask needs to be batch_size * num_heads, whereas I was passing only batch_size. This is explained in the docs for F.multi_head_attention_forward, but not in the MultiheadAttention forward method.
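Putting the two points together, a sketch of the batchwise solution might look like the following (sizes and the q_valid mask are made up for illustration; note that a fully masked attention row would produce NaNs from the softmax, hence zeroing the padded-query outputs afterwards, as suggested above):

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration
B, Q, K, E, H = 2, 4, 6, 8, 2  # batch, queries, keys, embed dim, heads

mha = nn.MultiheadAttention(embed_dim=E, num_heads=H)
query = torch.randn(Q, B, E)
key = value = torch.randn(K, B, E)

# Per-document query validity: e.g. the second document has only 3 real queries
q_valid = torch.tensor([[True, True, True, True],
                        [True, True, True, False]])  # (|B|, |Q|)

# A 3D attn_mask must have first dimension batch_size * num_heads,
# i.e. shape (B * H, |Q|, |K|). Here no key positions are masked per key,
# since fully masking a query's row would yield NaN attention weights.
attn_mask = torch.zeros(B * H, Q, K, dtype=torch.bool)

out, _ = mha(query, key, value, attn_mask=attn_mask)  # out: (|Q|, |B|, E)

# Zero the context vectors of padded queries before combining them
out = out * q_valid.t().unsqueeze(-1)  # broadcast (|Q|, |B|, 1) over E
```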

Note: this feature was merged on Jan 23 and is not in the pip distribution of PyTorch.

EDIT: Note that this is available in the most recent pip release, 1.5.