Flex attention raises an unnecessary assertion error

Hello, I’m trying to accelerate window-based self-attention with flex attention, but I hit an unexpected assertion error whenever the batched window size changes.
Specifically, I don’t know why, but PyTorch’s compiler assigns different symbolic size variables to tensors that have exactly the same shape.

Here’s the error message:

 File "/home/dslisleedh/miniconda3/envs/esc/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 769, in flex_attention
    assert Bq == Bkv, "Batch dimension must match"
torch._inductor.exc.LoweringException: AssertionError: Batch dimension must match
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.float32, size=[s5, 4, s3, 16], stride=[64*s3, 16*s3, 16, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.float32, size=[s6, 4, s3, 16], stride=[64*s3, 16*s3, 16, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='arg8_1', layout=FixedLayout('cuda', torch.float32, size=[s8, 4, s3, 16], stride=[64*s3, 16*s3, 16, 1]))
  ))

That is, args[0], args[1], and args[2] (q, k, and v) are assigned different symbolic batch sizes, s5, s6, and s8, respectively.

However, the error still occurs even though the forward pass asserts q.shape[0] == k.shape[0] == v.shape[0] to make sure all three batch sizes match.
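For reference, this is what that check looks like in forward (it also appears in the full code below). As far as I know, torch._check is the helper PyTorch provides for asserting a condition in a way the compiler can see, but I have not verified that it unifies the batch symbols in this case, so the commented lines are only a sketch:

q, k, v = qkv[0], qkv[1], qkv[2]
assert q.shape[0] == k.shape[0] == v.shape[0]   # batch size is always the same
# torch._check(q.shape[0] == k.shape[0])        # untested alternative hint
# torch._check(k.shape[0] == v.shape[0])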

Here’s my code.

import torch 
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange 

from torch.nn.attention.flex_attention import flex_attention
from typing import Optional, Sequence


def apply_rpe(table: torch.Tensor, window_size: int):
    def bias_mod(score: torch.Tensor, b: int, h: int, q_idx: int, kv_idx: int):
        q_h = q_idx // window_size
        q_w = q_idx % window_size
        k_h = kv_idx // window_size
        k_w = kv_idx % window_size
        rel_h = k_h - q_h + window_size - 1
        rel_w = k_w - q_w + window_size - 1
        rel_idx = rel_h * (2 * window_size - 1) + rel_w
        return score + table[h, rel_idx]
    return bias_mod
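# Usage note (illustrative only): bias_mod above matches the score_mod signature
# flex_attention expects, score_mod(score, b, h, q_idx, kv_idx) -> score, so a
# standalone call would look like
#   flex_attention(q, k, v, score_mod=apply_rpe(table, window_size))
# where table has shape (heads, (2 * window_size - 1) ** 2).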


def feat_to_win(x: torch.Tensor, window_size: Sequence[int], heads: int):
    return rearrange(
        x, 'b (qkv heads c) (h wh) (w ww) -> qkv (b h w) heads (wh ww) c',
        heads=heads, wh=window_size[0], ww=window_size[1], qkv=3
    )


def win_to_feat(x, window_size: Sequence[int], h_div: int, w_div: int):
    return rearrange(
        x, '(b h w) heads (wh ww) c -> b (heads c) (h wh) (w ww)',
        h=h_div, w=w_div, wh=window_size[0], ww=window_size[1]
    )
    

class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, attn_func=None, deployment=False):
        super().__init__()
        self.dim = dim
        window_size = (window_size, window_size) if isinstance(window_size, int) else window_size
        self.window_size = window_size
        self.num_heads = num_heads
        self.to_qkv = nn.Conv2d(dim, dim*3, 1, 1, 0)
        self.to_out = nn.Conv2d(dim, dim, 1, 1, 0)
        
        self.attn_func = attn_func
        self.is_deployment = deployment
        self.relative_position_bias = nn.Parameter(
            torch.randn(num_heads, (2*window_size[0]-1)*(2*window_size[1]-1)).to(torch.float32) * 0.001
        )
        if self.is_deployment:
            self.relative_position_bias = self.relative_position_bias.requires_grad_(False)
            self.get_rpe = apply_rpe(self.relative_position_bias, window_size[0])
        else:
            self.rpe_idxs = self.create_table_idxs(window_size[0], num_heads)

    @staticmethod
    def create_table_idxs(window_size: int, heads: int):
        # Transposed indices compared to the original Swin Transformer,
        # but much easier to implement and the relative position distances are the same anyway
        idxs_window = []
        for head in range(heads):
            for h in range(window_size**2):
                for w in range(window_size**2):
                    q_h = h // window_size
                    q_w = h % window_size
                    k_h = w // window_size
                    k_w = w % window_size
                    rel_h = k_h - q_h + window_size - 1
                    rel_w = k_w - q_w + window_size - 1
                    rel_idx = rel_h * (2 * window_size - 1) + rel_w
                    idxs_window.append((head, rel_idx))
        idxs = torch.tensor(idxs_window, dtype=torch.long, requires_grad=False)
        return idxs
        
    def pad_to_win(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        pad_h = (self.window_size[0] - h % self.window_size[0]) % self.window_size[0]
        pad_w = (self.window_size[1] - w % self.window_size[1]) % self.window_size[1]
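        # e.g. window_size = (16, 16) and h = w = 60 gives
        # pad_h = pad_w = (16 - 60 % 16) % 16 = 4, so the padded map is 64x64
        # and splits evenly into 16x16 windows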
        x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        return x
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input features with shape of (B, C, H, W)
        """
        _, _, h, w = x.shape
        x = self.pad_to_win(x, h, w)
        h_div, w_div = x.shape[2] // self.window_size[0], x.shape[3] // self.window_size[1]
        
        qkv = self.to_qkv(x)
        dtype = qkv.dtype
        qkv = feat_to_win(qkv, self.window_size, self.num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]
        assert q.shape[0] == k.shape[0] == v.shape[0]  # Batch size always the same
        
        if self.is_deployment:
            out = self.attn_func(q, k, v, score_mod=self.get_rpe)
        else:
            bias = self.relative_position_bias[self.rpe_idxs[:, 0], self.rpe_idxs[:, 1]]
            bias = bias.reshape(1, self.num_heads, self.window_size[0]*self.window_size[1], self.window_size[0]*self.window_size[1])
            out = self.attn_func(q, k, v, bias)
        
        out = win_to_feat(out, self.window_size, h_div, w_div)
        out = self.to_out(out.to(dtype)[:, :, :h, :w])
        return out   
    
attn_func = torch.compile(flex_attention, dynamic=True)
module = WindowAttention(64, 16, 4, attn_func, deployment=True)


with torch.no_grad():
    x = torch.randn(1, 64, 64, 64)

    module = module.cuda()
    x = x.cuda()

    out = module(x)
    print(out.shape)

    # Change the spatial size to change the batched window size
    x = torch.randn(1, 64, 256, 256).cuda()

    out = module(x)       # <<<< Error !!
    print(out.shape)
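    # For the 256x256 input: h_div = w_div = 256 // 16 = 16, so q, k and v each
    # have shape (1*16*16, 4, 16*16, 16) = (256, 4, 256, 16). Their batch
    # dimensions are identical, yet inductor gives them the separate symbols
    # s5, s6 and s8.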

So, is there a way to prevent the compiler from assigning different size symbols to tensors that have the same shape?
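One workaround that might avoid this (I have not verified it) is to compile without dynamic shapes, so the batch dimensions stay concrete integers instead of separate symbols, at the cost of one recompilation per distinct window-batch size:

attn_func = torch.compile(flex_attention, dynamic=False)  # recompiles per shape, no symbolic batch sizes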

I’m using pytorch==2.5.1, torchvision==0.20.1, torchaudio==2.5.1, and pytorch-cuda=12.1.

Thank you.

Never mind, I think this is fixed in PyTorch 2.6.0.