Hello, I’m trying to accelerate window-based self-attention by leveraging FlexAttention, but I hit an unexpected assertion error when the batch size changes.
Specifically, I don’t understand why, but PyTorch’s compiler assigns different size symbols to tensors that have the same shape.
Here’s the error message:
File "/home/dslisleedh/miniconda3/envs/esc/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 769, in flex_attention
assert Bq == Bkv, "Batch dimension must match"
torch._inductor.exc.LoweringException: AssertionError: Batch dimension must match
target: flex_attention
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg2_1', layout=FixedLayout('cuda', torch.float32, size=[s5, 4, s3, 16], stride=[64*s3, 16*s3, 16, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg5_1', layout=FixedLayout('cuda', torch.float32, size=[s6, 4, s3, 16], stride=[64*s3, 16*s3, 16, 1]))
))
args[2]: TensorBox(StorageBox(
InputBuffer(name='arg8_1', layout=FixedLayout('cuda', torch.float32, size=[s8, 4, s3, 16], stride=[64*s3, 16*s3, 16, 1]))
))
For example, args[0], args[1], and args[2] (q, k, and v) are assigned different batch-size symbols: s5, s6, and s8, respectively.
However, even if I add assert q.shape[0] == k.shape[0] == v.shape[0]
to make sure that all the batch dimensions are the same size, the error still occurs.
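Before the full module, here is a stripped-down sketch of what I think the failing pattern reduces to: q, k, and v are views of one stacked tensor, and only their (batched-window) batch size changes between calls. This is an untested guess on my side; the real code is below.

import torch
from torch.nn.attention.flex_attention import flex_attention

attn = torch.compile(flex_attention, dynamic=True)

# (qkv, batched windows, heads, tokens per window, head dim)
qkv = torch.randn(3, 16, 4, 256, 16, device='cuda')
q, k, v = qkv[0], qkv[1], qkv[2]   # views of the same buffer
out = attn(q, k, v)                # first batched-window count: runs fine

qkv = torch.randn(3, 256, 4, 256, 16, device='cuda')  # only the batch dimension changes
q, k, v = qkv[0], qkv[1], qkv[2]
out = attn(q, k, v)                # in my full module, this second call is where the Bq == Bkv assert fires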
Here’s my code.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention.flex_attention import flex_attention
from typing import Optional, Sequence
def apply_rpe(table: torch.Tensor, window_size: int):
    def bias_mod(score: torch.Tensor, b: int, h: int, q_idx: int, kv_idx: int):
        q_h = q_idx // window_size
        q_w = q_idx % window_size
        k_h = kv_idx // window_size
        k_w = kv_idx % window_size
        rel_h = k_h - q_h + window_size - 1
        rel_w = k_w - q_w + window_size - 1
        rel_idx = rel_h * (2 * window_size - 1) + rel_w
        return score + table[h, rel_idx]
    return bias_mod
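# worked example (window_size=16): q_idx=0 -> (q_h, q_w)=(0, 0) and kv_idx=17 -> (k_h, k_w)=(1, 1),
# so rel_h = rel_w = 16 and rel_idx = 16 * 31 + 16 = 512, an index into the (2*16-1)**2 = 961-entry bias table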
def feat_to_win(x: torch.Tensor, window_size: Sequence[int], heads: int):
    return rearrange(
        x, 'b (qkv heads c) (h wh) (w ww) -> qkv (b h w) heads (wh ww) c',
        heads=heads, wh=window_size[0], ww=window_size[1], qkv=3
    )

def win_to_feat(x, window_size: Sequence[int], h_div: int, w_div: int):
    return rearrange(
        x, '(b h w) heads (wh ww) c -> b (heads c) (h wh) (w ww)',
        h=h_div, w=w_div, wh=window_size[0], ww=window_size[1]
    )
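# quick shape check (dim=64, 4 heads, 16x16 windows, 64x64 input):
#   feat_to_win(torch.randn(1, 3 * 64, 64, 64), (16, 16), 4).shape
#   -> torch.Size([3, 16, 4, 256, 16]), i.e. (qkv, batched windows, heads, tokens per window, head dim)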
class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, attn_func=None, deployment=False):
        super().__init__()
        self.dim = dim
        window_size = (window_size, window_size) if isinstance(window_size, int) else window_size
        self.window_size = window_size
        self.num_heads = num_heads
        self.to_qkv = nn.Conv2d(dim, dim*3, 1, 1, 0)
        self.to_out = nn.Conv2d(dim, dim, 1, 1, 0)
        self.attn_func = attn_func
        self.is_deployment = deployment

        self.relative_position_bias = nn.Parameter(
            torch.randn(num_heads, (2*window_size[0]-1)*(2*window_size[1]-1)).to(torch.float32) * 0.001
        )

        if self.is_deployment:
            self.relative_position_bias = self.relative_position_bias.requires_grad_(False)
            self.get_rpe = apply_rpe(self.relative_position_bias, window_size[0])
        else:
            self.rpe_idxs = self.create_table_idxs(window_size[0], num_heads)
    @staticmethod
    def create_table_idxs(window_size: int, heads: int):
        # Transposed idxs of original Swin Transformer
        # But much easier to implement and the same relative position distance anyway
        idxs_window = []
        for head in range(heads):
            for h in range(window_size**2):
                for w in range(window_size**2):
                    q_h = h // window_size
                    q_w = h % window_size
                    k_h = w // window_size
                    k_w = w % window_size
                    rel_h = k_h - q_h + window_size - 1
                    rel_w = k_w - q_w + window_size - 1
                    rel_idx = rel_h * (2 * window_size - 1) + rel_w
                    idxs_window.append((head, rel_idx))
        idxs = torch.tensor(idxs_window, dtype=torch.long, requires_grad=False)
        return idxs
    def pad_to_win(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        pad_h = (self.window_size[0] - h % self.window_size[0]) % self.window_size[0]
        pad_w = (self.window_size[1] - w % self.window_size[1]) % self.window_size[1]
        x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        return x
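    # e.g. with a 16x16 window, a 70x70 input is reflect-padded to 80x80, while a 64x64 input is left unchanged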
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input features with shape of (B, C, H, W)
        """
        _, _, h, w = x.shape
        x = self.pad_to_win(x, h, w)
        h_div, w_div = x.shape[2] // self.window_size[0], x.shape[3] // self.window_size[1]

        qkv = self.to_qkv(x)
        dtype = qkv.dtype
        qkv = feat_to_win(qkv, self.window_size, self.num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]
        assert q.shape[0] == k.shape[0] == v.shape[0]  # Batch size always the same

        if self.is_deployment:
            out = self.attn_func(q, k, v, score_mod=self.get_rpe)
        else:
            bias = self.relative_position_bias[self.rpe_idxs[:, 0], self.rpe_idxs[:, 1]]
            bias = bias.reshape(1, self.num_heads, self.window_size[0]*self.window_size[1], self.window_size[0]*self.window_size[1])
            out = self.attn_func(q, k, v, bias)

        out = win_to_feat(out, self.window_size, h_div, w_div)
        out = self.to_out(out.to(dtype)[:, :, :h, :w])
        return out
attn_func = torch.compile(flex_attention, dynamic=True)
module = WindowAttention(64, 16, 4, attn_func, deployment=True)

with torch.no_grad():
    x = torch.randn(1, 64, 64, 64)
    module = module.cuda()
    x = x.cuda()
    out = module(x)
    print(out.shape)

    # Change the spatial size to change the batched window size
    x = torch.randn(1, 64, 256, 256).cuda()
    out = module(x)  # <<<< Error !!
    print(out.shape)
So, is it possible to prevent the compiler from assigning different size symbols to tensors that are guaranteed to have the same shape?
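For example, would something like torch._dynamo.mark_dynamic be the intended way to do that? This is just a guess on my part, and I don't know whether it even applies here, e.g. inside forward:

# guess: mark the batched-window dimension of q, k, and v as dynamic before the call,
# hoping the compiler reuses a single size symbol instead of s5 / s6 / s8
torch._dynamo.mark_dynamic(q, 0)
torch._dynamo.mark_dynamic(k, 0)
torch._dynamo.mark_dynamic(v, 0)
out = self.attn_func(q, k, v, score_mod=self.get_rpe)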
I’m using pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.1
Thank you.