Question about attn_mask in torch.nn.MultiheadAttention

Hi,
I’m trying to pass a set of queries, keys, and values through a torch.nn.MultiheadAttention layer, but I want some of the queries to attend to only part of the keys. I think I need to use an attention mask, but it’s not quite clear from the documentation how I should use it.
Let’s say I have 3 queries, keys, and values, and I want:

  1. the first query to attend only to the first and third keys
  2. the second query to attend to all keys
  3. the third query to attend only to the first and third keys
Would my attention mask then need to look like this?

```python
[[False, True,  False],
 [False, False, False],
 [False, True,  False]]
```
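For context, here’s a minimal sketch of what I’m trying (the embedding size, single head, and random tensors are just placeholders; I’m assuming a boolean attn_mask of shape (L, S) where True means “this query may not attend to this key”):

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 1  # placeholder sizes, just for illustration
mha = nn.MultiheadAttention(embed_dim, num_heads)

# 3 queries/keys/values, batch size 1; default layout is (seq_len, batch, embed_dim)
q = torch.randn(3, 1, embed_dim)
k = torch.randn(3, 1, embed_dim)
v = torch.randn(3, 1, embed_dim)

# Boolean mask of shape (L, S); True = position is masked out
attn_mask = torch.tensor([
    [False, True,  False],   # query 1: keys 1 and 3 only
    [False, False, False],   # query 2: all keys
    [False, True,  False],   # query 3: keys 1 and 3 only
])

out, weights = mha(q, k, v, attn_mask=attn_mask)
print(weights)  # expecting zero weight at the masked positions
```

If that interpretation is right, the attention weights at the masked positions should come out as zero, but I’d like to confirm the convention.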
Thanks