Answering my own question in case someone else needs it in the future:
I did a dummy experiment and I believe the attention mask must be provided as mentioned in the first bullet point. So, for my example with a batch size of 2 and 3 attention heads, the attention mask must be structured as follows:
- 1st batch example, 1st attention head
- 1st batch example, 2nd attention head
- 1st batch example, 3rd attention head
- 2nd batch example, 1st attention head
- 2nd batch example, 2nd attention head
- 2nd batch example, 3rd attention head
and so on for larger batch sizes (the head index varies fastest within each batch example); a short sketch of how to build a mask in this layout is shown right below.
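For reference, here is a minimal sketch (not part of the experiment below; per_example_mask and the sizes are just illustrative) of how to build a mask in that layout, assuming you start from one (L, S) boolean mask per batch example:
import torch
N, num_heads, L, S = 2, 3, 4, 4  # batch size, number of heads, target length, source length (illustrative)
per_example_mask = torch.zeros(N, L, S, dtype=torch.bool)  # one (L, S) mask per batch example
attn_mask = torch.repeat_interleave(per_example_mask, num_heads, dim=0)
assert attn_mask.shape == (N * num_heads, L, S)  # head index varies fastest within each example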
I figured this out by experimenting; here is some minimal code to replicate the behavior:
import torch
import torch.nn as nn
N = 2          # batch size
L = 3          # target sequence length
embd = 4       # embedding dimension
num_heads = 2  # number of attention heads
S = 3          # source sequence length
multihead_attn = nn.MultiheadAttention(embd, num_heads, batch_first=False)
query = torch.ones(L, N, embd)  # (seq_len, batch, embd) because batch_first=False
key = torch.ones(S, N, embd)
value = torch.ones(S, N, embd)
### Case 1: nothing is masked ###
same_attn_mask_zeros = torch.zeros(L, S, dtype=torch.bool)
print("=" * 80)
attn_output, _ = multihead_attn(query, key, value, attn_mask=same_attn_mask_zeros)
print(f"same mask all false: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")
### Case 2: everything is masked ###
same_attn_mask_ones = torch.ones(L, S, dtype=torch.bool)
attn_output, _ = multihead_attn(query, key, value, attn_mask=same_attn_mask_ones)
print("=" * 80)
print(f"same mask all true: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")
# Goal: the first batch example has nothing masked (False everywhere) -> its output should be ordinary finite values,
# and the second batch example has everything masked (True everywhere) -> its output should be nan
### First alternative -> repeat: for each head, stack the masks of all batch examples ###
# 1st batch example, 1st head
# 2nd batch example, 1st head
# 1st batch example, 2nd head
# 2nd batch example, 2nd head
combined = torch.stack((same_attn_mask_zeros, same_attn_mask_ones), dim=0)
assert combined.shape == (2, 3, 3)  # N, L, S
repeat = combined.repeat(num_heads, 1, 1)
assert repeat.shape == (4, 3, 3) # N x num_heads, L, S
attn_output, _ = multihead_attn(query, key, value, attn_mask=repeat)
print("=" * 80)
print(f"iterate first over heads: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")
### Second alternative -> repeat_interleave: for each batch example, stack the masks of all heads ###
# 1st batch example, 1st head
# 1st batch example, 2nd head
# 2nd batch example, 1st head
# 2nd batch example, 2nd head
interleave = torch.repeat_interleave(combined, num_heads, dim=0)
assert interleave.shape == (4, 3, 3) # N x num_heads, L, S
attn_output, _ = multihead_attn(query, key, value, attn_mask=interleave)
print("=" * 80)
print(f"iterate first over batch examples: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")