Hi, I am trying to use torch.nn.MultiheadAttention for the following use case:
I have documents of |Q| queries and sentences of length |K| (here, K == V). I would like each query to attend to all keys, and ultimately I will combine the |Q| context vectors. When batching these inputs, I understand that I can pass key_padding_mask of shape |B| x |K|, where |B| is my batch size, to mask the appropriate keys in each batch. I also understand that I can use attn_mask of shape |Q| x |K| to mask which keys a given output (in my case, I want |Q| outputs) attends to.
However, because |Q| varies across the batch, how can I mask it batchwise? Can I simply pass attn_mask of shape |B| x |Q| x |K| and get the desired effect? Essentially, for a masked q, the entire attention distribution over |K| is invalid. I guess I can just zero out those output vectors after the computation? Is my understanding even correct, or is the implementation of MultiheadAttention not valid for my use case?
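For reference, here is a minimal sketch of what I have in mind. It assumes the documented behavior that a 3D attn_mask must have shape (B * num_heads, Q, K) rather than (B, Q, K), so a per-batch mask is repeated across heads; it also leaves the rows belonging to padded queries fully unmasked (an all-masked row would make the softmax produce NaN) and instead zeroes those outputs afterwards. All shapes and the query_padding tensor are illustrative, not from any library API:

```python
import torch
import torch.nn as nn

B, Q, K, E, H = 2, 4, 5, 8, 2  # batch, queries, keys, embed dim, heads
mha = nn.MultiheadAttention(embed_dim=E, num_heads=H, batch_first=True)

q = torch.randn(B, Q, E)
kv = torch.randn(B, K, E)  # keys == values in this use case

# key_padding_mask: (B, K); True = ignore this key for that batch element
key_padding_mask = torch.zeros(B, K, dtype=torch.bool)
key_padding_mask[0, -2:] = True  # e.g. last two keys of sample 0 are padding

# Which queries are padding: (B, Q); True = padded query (illustrative)
query_padding = torch.zeros(B, Q, dtype=torch.bool)
query_padding[1, -1] = True  # e.g. last query of sample 1 is padding

# A batched attn_mask must be (B * num_heads, Q, K), so repeat the
# per-batch (B, Q, K) mask across heads. Rows for padded queries are
# left all-False here to avoid NaN from a fully masked softmax row.
attn_mask = torch.zeros(B, Q, K, dtype=torch.bool)
attn_mask = attn_mask.repeat_interleave(H, dim=0)  # (B*H, Q, K)

out, _ = mha(q, kv, kv,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask)

# Zero out the context vectors of padded queries after the fact
out = out.masked_fill(query_padding.unsqueeze(-1), 0.0)  # (B, Q, E)
```

Is zeroing after the computation, as above, the intended way to handle variable |Q|, or is there a supported way to mask whole query rows directly?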