Different attention mask for each example in a batch

Hi all,

I would like to apply a different attention mask to each example in a batch. I’ve seen in the documentation that I can achieve this by specifying a 3D attention mask of shape (N * num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length, but I’m not sure about the following:

Assume for simplicity that I’m working with a batch size of 2 and 3 attention heads. Should the attention mask be:

  • first repeated for all attention heads of batch example 1 and then for all attention heads of batch example 2, or
  • first repeated across all batch examples for attention head 1, then across all batch examples for attention head 2, then across all batch examples for attention head 3 (both layouts are sketched right below)?
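Concretely, the two candidate layouts I have in mind look like this (a small sketch only; mask_ex1 and mask_ex2 are placeholder per-example masks, not names from the docs):

import torch

L, S, num_heads = 3, 3, 3
mask_ex1 = torch.zeros(L, S, dtype=torch.bool)  # placeholder mask for batch example 1
mask_ex2 = torch.ones(L, S, dtype=torch.bool)   # placeholder mask for batch example 2

# Layout 1: all heads of example 1 first, then all heads of example 2
layout1 = torch.stack([mask_ex1] * num_heads + [mask_ex2] * num_heads)  # shape (2 * num_heads, L, S)

# Layout 2: all examples for head 1, then for head 2, then for head 3
layout2 = torch.stack([mask_ex1, mask_ex2] * num_heads)                 # shape (2 * num_heads, L, S)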

Answering my question in case someone else needs it in the future:
I did a dummy experiment and I believe that the attention mask must be provided as described in the first bullet point. For my example with a batch size of 2 and 3 attention heads, the mask must therefore be stacked as follows (see the sketch right after this list):

  • 1st batch example, 1st attention head
  • 1st batch example, 2nd attention head
  • 1st batch example, 3rd attention head
  • 2nd batch example, 1st attention head
  • 2nd batch example, 2nd attention head
  • 2nd batch example, 3rd attention head
    and so on for larger batch sizes.
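In practice this layout can be produced with torch.repeat_interleave, which keeps each example’s masks contiguous. A minimal sketch, assuming one boolean (L, S) mask per example stacked into per_example_mask of shape (N, L, S) (the variable names here are mine, not from the docs):

import torch

N, num_heads, L, S = 2, 3, 4, 4
per_example_mask = torch.zeros(N, L, S, dtype=torch.bool)  # True = attention is blocked

# Rows come out as (example 1, head 1), (example 1, head 2), (example 1, head 3),
# (example 2, head 1), ... which matches the ordering in the list above.
attn_mask = torch.repeat_interleave(per_example_mask, num_heads, dim=0)
assert attn_mask.shape == (N * num_heads, L, S)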

I figured this out experimentally; here is some minimal code to replicate the behavior:

import torch
import torch.nn as nn
N = 2          # batch size
L = 3          # target sequence length
embd = 4       # embedding dimension
num_heads = 2  # number of attention heads
S = 3          # source sequence length

multihead_attn = nn.MultiheadAttention(embd, num_heads, batch_first=False)  # inputs are (seq_len, batch, embed_dim)
query = torch.ones(L, N, embd)
key = torch.ones(S, N, embd)
value = torch.ones(S, N, embd)

### Case 1: nothing is masked ###
same_attn_mask_zeros = torch.zeros(L, S, dtype=torch.bool)  # False everywhere -> nothing is masked
print("=" * 80)
attn_output, _ = multihead_attn(query, key, value, attn_mask=same_attn_mask_zeros)
print(f"same mask all false: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")

### Case 2: everything is masked ###
same_attn_mask_ones = torch.ones(L, S, dtype=torch.bool)  # True everywhere -> every position is masked
attn_output, _ = multihead_attn(query, key, value, attn_mask=same_attn_mask_ones)
print("=" * 80)
print(f"same mask all true: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")

# Assume the first batch example has nothing masked (False everywhere) -> its output should contain ordinary finite values,
# and the second batch example has everything masked (True everywhere) -> its output should be all nan.

### First alternative -> repeat (batch examples iterate fastest, i.e. the second bullet point of the question) ###
# Intended row order under this hypothesis:
# 1st batch example, 1st head
# 2nd batch example, 1st head
# 1st batch example, 2nd head
# 2nd batch example, 2nd head
combined = torch.stack((same_attn_mask_zeros, same_attn_mask_ones), dim=0)
assert combined.shape == (2, 3, 3)  # N, L, S
repeat = combined.repeat(num_heads, 1, 1)  # rows: [ex1, ex2, ex1, ex2]
assert repeat.shape == (4, 3, 3)  # N * num_heads, L, S
attn_output, _ = multihead_attn(query, key, value, attn_mask=repeat)
print("=" * 80)
print(f"iterate first over heads: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")

### Second alternative -> repeat_interleave (heads iterate fastest, i.e. the first bullet point of the question) ###
# Intended row order under this hypothesis:
# 1st batch example, 1st head
# 1st batch example, 2nd head
# 2nd batch example, 1st head
# 2nd batch example, 2nd head

interleave = torch.repeat_interleave(combined, num_heads, dim=0)  # rows: [ex1, ex1, ex2, ex2]
assert interleave.shape == (4, 3, 3)  # N * num_heads, L, S
attn_output, _ = multihead_attn(query, key, value, attn_mask=interleave)
print("=" * 80)
print(f"iterate first over batch examples: \nfirst example: \n{attn_output[:, 0, :]} \nsecond example: \n{attn_output[:, 1, :]}")