Flex_attention: BlockMask conversion?

I am trying to implement an inference variant of scaled dot product attention where the key-value cache slots can be in any order (this is needed for sparse attention, where earlier slots can be overwritten by later ones). To get the right causal attention mask, I either need to (1) reorder the slots explicitly, or (2) define the causal attention mask implicitly via the token position vector. The latter would be faster and not need another copy of the KV cache.

I tried to implement (2) with flex_attention. Here is the code:

from functools import partial
import math
from typing import Optional, Callable

import torch
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
)
import torch.nn.functional as F


def causal_mask_for_chunk_1d(
    batch: torch.Tensor,
    head: torch.Tensor,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor,
    input_pos: torch.Tensor,
    token_positions: torch.Tensor,
) -> torch.Tensor:
    left_arg = q_idx + input_pos
    right_arg = token_positions[kv_idx]
    return left_arg >= right_arg


def causal_mask_for_chunk_notp(
    batch: torch.Tensor,
    head: torch.Tensor,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor,
    offset: torch.Tensor,
) -> torch.Tensor:
    left_arg = q_idx + offset
    return left_arg >= kv_idx


class AttnFunctionForChunk:
    def __init__(
        self,
        q_len: int,
        kv_len: int,
        device: torch.device,
        use_tp: bool,
    ):
        kwargs = dict(device=device, dtype=torch.int32)
        self.input_pos = torch.tensor(kv_len - q_len, **kwargs)
        if use_tp:
            self.input_pos = torch.tensor(0, **kwargs)
            self.token_positions = torch.zeros((kv_len,), **kwargs)
            mask_mod = partial(
                causal_mask_for_chunk_1d,
                input_pos=self.input_pos,
                token_positions=self.token_positions,
            )
        else:
            self.input_pos = None
            self.token_positions = None
            mask_mod = partial(
                causal_mask_for_chunk_notp,
                offset=kv_len - q_len,
            )
        self.block_mask = create_block_mask(
            mask_mod,
            B=None,
            H=None,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=device,
        )
        self.attn_fn_compiled = torch.compile(flex_attention)

    def __call__(
        self,
        input_pos: int,
        token_positions: Optional[torch.Tensor],
    ) -> Callable:
        if self.input_pos is not None:
            self.input_pos.copy_(input_pos)
            self.token_positions[:] = token_positions
        return partial(
            self.attn_fn_compiled,
            block_mask=self.block_mask,
        )


def scaled_dot_product_attention_flexatt(
    flexatt_args: AttnFunctionForChunk,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale_factor: float,
    input_pos: int,
    token_positions: torch.Tensor,
) -> torch.Tensor:
    n_head = query.shape[1]
    n_query_groups = key.shape[1]
    attn_fn = flexatt_args(
        input_pos=input_pos,
        token_positions=token_positions,
    )
    if flexatt_args.input_pos is None:
        # Sort to obtain the right attention mask
        sort_index = torch.argsort(token_positions)
        key = key[:, :, sort_index, :]
        value = value[:, :, sort_index, :]
    return attn_fn(
        query=query,
        key=key,
        value=value,
        scale=scale_factor,
        enable_gqa=n_query_groups < n_head,
    )


def attention_compute_scores(
    query: torch.Tensor,
    key: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    scale_factor: float = 1.0,
) -> torch.Tensor:
    """
    Compute inner product scores (without masking). Here,
    `nh_q = q_per_kv * nh_k` with `q_per_kv >= 1`.

    Args:
        query: Query tensor, `(bs, nh_q, q_len, hs)`
        key: Key tensor, `(bs, nh_k, kv_len, hs)`
        out: Result written here, if given
        scale_factor: Scale factor for inner product scores

    Returns:
        Inner product scores, `(bs, nh_q, q_len, kv_len)`. This is `out` if given

    """
    assert query.ndim == key.ndim == 4
    assert query.shape[0] == key.shape[0] and query.shape[3] == key.shape[3]
    nh_q = query.shape[1]
    nh_k = key.shape[1]
    assert nh_q % nh_k == 0
    # - query, arg1: (bs, nh_q, q_len, hs)
    # - key: (bs, nh_k, kv_len, hs)
    # - key_transposed: (bs, nh_k, hs, kv_len)
    q_per_kv = nh_q // nh_k
    if scale_factor == 1.0:
        key_transposed = key.mT
        arg1 = query
    elif query.numel() <= key.numel():
        key_transposed = key.mT
        arg1 = query * scale_factor
    else:
        key_transposed = key.mT * scale_factor
        arg1 = query
    if q_per_kv == 1:
        out = torch.matmul(arg1, key_transposed, out=out)
    else:
        assert q_per_kv > 1
        q_shape = query.shape[:1] + (nh_k, q_per_kv) + query.shape[2:]
        _query = arg1.view(*q_shape)
        key_transposed = key_transposed.unsqueeze(2)
        # At this point:
        # - _query: (bs, nh_k, q_per_kv, q_len, hs)
        # - key_transposed: (bs, nh_k, 1, hs, kv_len)
        # - scores: (bs, nh_k, q_per_kv, q_len, kv_len)
        if out is not None:
            out = out.view(_query.shape[:-1] + (key.shape[2],))
        out = torch.matmul(_query, key_transposed, out=out)
        s_shape = query.shape[:-1] + (key.shape[2],)
        out = out.view(*s_shape)
    return out


def attention_compute_weighted_values(
    scores: torch.Tensor,
    value: torch.Tensor,
) -> torch.Tensor:
    """
     Args:
        scores: Attention weights, `(bs, nh_q, q_len, kv_len)`
        value: Value tensor, `(bs, nh_k, kv_len, hs)`

    Returns:
        Attention outputs, `(bs, nh_q, q_len, hs)`

    """
    assert scores.ndim == value.ndim == 4
    assert scores.shape[0] == scores.shape[0] and scores.shape[3] == value.shape[2]
    nh_q = scores.shape[1]
    nh_k = value.shape[1]
    assert nh_q % nh_k == 0
    # - scores: (bs, nh_q, q_len, kv_len)
    # - value: (bs, nh_k, kv_len, hs)
    q_per_kv = nh_q // nh_k
    if q_per_kv == 1:
        return torch.matmul(scores, value)
    else:
        s_shape = scores.shape[:1] + (nh_k, q_per_kv) + scores.shape[2:]
        _scores = scores.view(*s_shape)
        _value = value.unsqueeze(2)
        # At this point:
        # - _scores: (bs, nh_k, q_per_kv, q_len, kv_len)
        # - _value: (bs, nh_k, 1, kv_len, hs)
        # - result: (bs, nh_k, q_per_kv, q_len, hs)
        result = torch.matmul(_scores, _value)
        r_shape = scores.shape[:-1] + (value.shape[-1],)
        return result.view(*r_shape)


def minus_infinity(dtype: torch.dtype) -> float:
    return torch.finfo(dtype).min


def mask_slice_bool(
    input_pos: int,
    num: int,
    token_positions: torch.Tensor,
) -> torch.Tensor:
    assert token_positions.ndim == 1
    kwargs = dict(device=token_positions.device, dtype=token_positions.dtype)
    return torch.arange(
        input_pos,
        input_pos + num,
        **kwargs,
    ).view(
        -1, 1
    ) < token_positions.view(1, -1)


def build_mask_slice(
    input_pos: int,
    num: int,
    token_positions: torch.Tensor,
    batch_size: int,
    n_head: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    bool_mask = mask_slice_bool(
        input_pos,
        num,
        token_positions,
    )
    mask = torch.zeros(
        bool_mask.shape,
        dtype=dtype,
        device=token_positions.device,
    )
    mask.masked_fill_(bool_mask, minus_infinity(dtype))
    return mask[None, None, :, :].expand(batch_size, n_head, -1, -1)


def eager_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale_factor: float,
    input_pos: int,
    token_positions: torch.Tensor,
) -> torch.Tensor:
    assert input_pos > 0
    dtype = torch.float32
    batch_size, n_head, q_len, _ = query.shape
    _, n_query_groups, kv_len, _ = key.shape
    assert token_positions.shape == (kv_len,)
    query32 = query.to(dtype)
    key32 = key.to(dtype)
    attn_weights = attention_compute_scores(query32, key32) * scale_factor
    # Attention masking
    mask = build_mask_slice(
        input_pos,
        q_len,
        token_positions,
        batch_size,
        n_head,
        dtype,
    )
    attn_weights = attn_weights + mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    del query32, key32
    value32 = value.to(torch.float32)
    return attention_compute_weighted_values(attn_weights, value32).to(query.dtype)


def main(
    cache_length: int,
    chunk_size: int,
    batch_size: int,
    n_head: int,
    n_query_groups: int,
    head_size: int,
    device: torch.device,
    dtype: torch.dtype,
):
    seed = 31415927
    torch.manual_seed(seed)
    scale_factor = 1.0 / math.sqrt(head_size)

    attn_outputs = []
    q_shape = (batch_size, n_head, chunk_size, head_size)
    query = torch.randn(*q_shape, device=device, dtype=dtype)
    kv_shape = (batch_size, n_query_groups, cache_length, head_size)
    key = torch.randn(*kv_shape, device=device, dtype=dtype)
    value = torch.randn(*kv_shape, device=device, dtype=dtype)
    input_pos = cache_length
    start = input_pos + chunk_size - cache_length
    token_positions = (
        torch.randperm(cache_length, device=device, dtype=torch.int64) + start
    )
    # Eager attention
    attn_outputs.append(
        eager_scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            scale_factor=scale_factor,
            input_pos=cache_length,
            token_positions=token_positions,
        )
    )
    # FlexAttention (use_tp=False)
    flexatt_args = AttnFunctionForChunk(
        q_len=chunk_size,
        kv_len=cache_length,
        device=device,
        use_tp=False,
    )
    attn_outputs.append(
        scaled_dot_product_attention_flexatt(
            flexatt_args=flexatt_args,
            query=query,
            key=key,
            value=value,
            scale_factor=scale_factor,
            input_pos=cache_length,
            token_positions=token_positions,
        )
    )
    # FlexAttention (use_tp=True)
    flexatt_args = AttnFunctionForChunk(
        q_len=chunk_size,
        kv_len=cache_length,
        device=device,
        use_tp=True,
    )
    attn_outputs.append(
        scaled_dot_product_attention_flexatt(
            flexatt_args=flexatt_args,
            query=query,
            key=key,
            value=value,
            scale_factor=scale_factor,
            input_pos=cache_length,
            token_positions=token_positions,
        )
    )
    # Comparison
    print("Compare eager vs flex_attention (use_tp=False)")
    torch.testing.assert_close(attn_outputs[0], attn_outputs[1])
    print("Compare eager vs flex_attention (use_tp=True)")
    torch.testing.assert_close(attn_outputs[0], attn_outputs[2])


if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise AssertionError("CUDA not available")
    batch_size = 2
    n_head = 32
    n_query_groups = 32
    cache_length = 4096
    head_dim = 128
    chunk_size = 512
    device = torch.device("cuda", 0)
    dtype = torch.float32
    main(
        cache_length=cache_length,
        chunk_size=chunk_size,
        batch_size=batch_size,
        n_head=n_head,
        n_query_groups=n_query_groups,
        head_size=head_dim,
        device=device,
        dtype=dtype,
    )

In essence, I need mask_mod of the form q_idx >= index[kv_idx], where index is a captured tensor, where index.ndim == 1 (and not scalar, like the normal inputs).

The code above does not work for (2)

Compare eager vs flex_attention (use_tp=False)
Compare eager vs flex_attention (use_tp=True)
Traceback (most recent call last):
  File "/home/ubuntu/sync/keys_values/keys_values/scripts/debug_flex_attention_simple.py", line 366, in <module>
    main(
  File "/home/ubuntu/sync/keys_values/keys_values/scripts/debug_flex_attention_simple.py", line 352, in main
    torch.testing.assert_close(attn_outputs[0], attn_outputs[2])
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1600, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 4177459 / 4194304 (99.6%)
Greatest absolute difference: 0.10383622348308563 at index (0, 31, 65, 38) (up to 1e-05 allowed)
Greatest relative difference: 744068.1875 at index (1, 30, 449, 34) (up to 1.3e-06 allowed)

In this blog post (and their paper), the flex_attention authors present an advanced use case (paged attention), using a concept they call BlockMask conversion: https://pytorch.org/blog/flexattention-for-inference/

I have the feeling this would solve my problem, but I do not understand how it would apply.

If anybody understands flex_attention, could you give me a hint?

In fact, this is a simple version of what I really need, namely using a mask_mod of the form q_idx >= index[b, h, kv_idx] for a captured tensor with index.ndim == 3. Just I wanted to keep it simple.