Creating the FlexAttention BlockMask from a mask

Is it possible to use an existing mask to create the BlockMask? Our masks are quite complex, so it's a bother to have to write a mask_mod() function when we already have a function that builds the boolean mask tensor (I know the outputs are not the same, but still).

Looking into the code of torch/nn/attention/flex_attention.py, I see _create_block_mask_inner(), which calls _convert_mask_to_block_mask(); that seems perfect.

But I then see that _create_sparse_block_from_block_mask() is called with mask_mod as an argument, and it's not clear to me whether mask_mod is actually used there.

So my question is: can one bypass having to create a mask_mod() function, or not?

@ezyang


I think this code may show a path forward, though I don't fully understand it yet: modded-nanogpt/train_gpt.py at master · KellerJordan/modded-nanogpt · GitHub

The easiest way to do this is to make a mask_mod that loads from an existing mask :slight_smile:

For example,

from torch import Tensor

# Precomputed boolean mask of shape [Q_LEN, KV_LEN]; True means "attend".
existing_mask_tensor: Tensor
def custom_mask_mod(b, h, q_idx, kv_idx):
    return existing_mask_tensor[q_idx, kv_idx]

This’ll allow it to take advantage of block sparsity in the existing mask, although you’ll still have to:

  1. Have the existing mask tensor in memory
  2. Do loads from the existing mask tensor in cases where you have a partial mask.
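For completeness, here's a rough end-to-end sketch of how this can be wired up. The sizes and names are just illustrative (a causal mask stands in for your complex mask), and the sequence length is chosen as a multiple of 128:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

BATCH, HEADS, SEQ_LEN, HEAD_DIM = 2, 4, 256, 64   # illustrative sizes; SEQ_LEN is a multiple of 128
device = "cuda"

# Precomputed boolean mask (here just causal, standing in for a complex mask).
existing_mask_tensor = torch.tril(
    torch.ones(SEQ_LEN, SEQ_LEN, dtype=torch.bool, device=device)
)

def custom_mask_mod(b, h, q_idx, kv_idx):
    # The mask_mod only reads the decision out of the precomputed mask.
    return existing_mask_tensor[q_idx, kv_idx]

# B=None / H=None broadcasts the same mask over batch and heads.
block_mask = create_block_mask(custom_mask_mod, B=None, H=None,
                               Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device=device)

q = torch.randn(BATCH, HEADS, SEQ_LEN, HEAD_DIM, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)

In practice you'd usually wrap flex_attention in torch.compile for performance, but the eager call above is enough to check that the mask behaves as expected.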

Hm, I thought I tried this and got some low-level error that led me to think it was not viable. I'll try again. Thank you for the reply!

Realized my low-level error was because my sequence length was not a multiple of 128. But now I run into

torch._dynamo.exc.ObservedException: raised exception ExceptionVariable()

:confused:
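For what it's worth, one way to satisfy the length requirement is to pad the inputs and the mask up to the next multiple of 128 and mask out the padded positions. A rough sketch (my own helper, not a PyTorch API):

import torch
import torch.nn.functional as F

BLOCK = 128  # default block size used by flex_attention

def pad_to_block(x, mask, block=BLOCK):
    # x:    [B, H, S, D] query/key/value tensor
    # mask: [S, S] boolean mask (True = attend)
    s = x.shape[-2]
    pad = (-s) % block  # amount needed to reach the next multiple of `block`
    if pad == 0:
        return x, mask
    x = F.pad(x, (0, 0, 0, pad))                       # zero-pad the sequence dimension
    mask = F.pad(mask, (0, pad, 0, pad), value=False)  # padded tokens are masked out
    # Let padded query rows attend to themselves so they don't end up with an empty softmax row.
    idx = torch.arange(s, s + pad, device=mask.device)
    mask[idx, idx] = True
    return x, mask

Then build the BlockMask from the padded mask and slice the attention output back to the original sequence length.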

That was something silly. Now I've run into this issue; it looks like one needs a newer PyTorch: [FlexAttention] Using FlexAttention with DDP complains about a "higher order optimizer" · Issue #137481 · pytorch/pytorch · GitHub