I would like to use the flash attention kernel of PyTorch's scaled dot-product attention (SDPA) on sequences of variable length, but I'm not sure how to achieve this.
For example, I tried running self-attention on padded sequences together with a padding mask:
```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

n_features = 8
batch_size = 2
lengths = torch.tensor([1, 2])
model = nn.TransformerEncoderLayer(d_model=n_features, nhead=1, batch_first=True, dtype=torch.float16, device='cuda')

# Define data: one random sequence per length, padded to the longest one
data = [torch.rand((lengths[i], n_features), dtype=torch.float16, device='cuda') for i in range(batch_size)]
data_padded = pad_sequence(data, batch_first=True)
# True marks padding positions that attention should ignore
pad_mask = (torch.arange(max(lengths))[None] >= lengths[:, None]).cuda()

# Restrict SDPA to the flash backend only
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True):
    model(data_padded, src_key_padding_mask=pad_mask)
```
This results in the following error:
RuntimeError: No available kernel. Aborting execution.
This error is preceded by the following warning:
UserWarning: Both fused kernels do not support non-null attn_mask.
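Since the warning says the fused kernels reject any non-null attn_mask, one workaround that keeps the mask out of the picture is to avoid padding entirely: group sequences of equal length and run each group as an ordinary dense batch with no mask. The sketch below works directly on `torch.nn.functional.scaled_dot_product_attention` with q = k = v (no learned projections) for brevity; `bucketed_self_attention` is a hypothetical helper, not a PyTorch API. It runs on CPU in float32 here; whether the flash kernel is actually selected on a GPU also depends on dtype (fp16/bf16) and head dimension.

```python
from collections import defaultdict

import torch
import torch.nn.functional as F


def bucketed_self_attention(seqs):
    """Self-attention over variable-length sequences without any attention mask.

    Hypothetical helper (not a PyTorch API). Sequences of equal length are
    stacked into one dense batch, so every SDPA call receives attn_mask=None.
    seqs: list of tensors of shape (length_i, n_features).
    """
    buckets = defaultdict(list)  # length -> indices of sequences with that length
    for idx, s in enumerate(seqs):
        buckets[s.shape[0]].append(idx)

    outputs = [None] * len(seqs)
    for idxs in buckets.values():
        # (bucket_size, heads=1, length, n_features): no padding, hence no mask
        batch = torch.stack([seqs[i] for i in idxs])[:, None]
        out = F.scaled_dot_product_attention(batch, batch, batch)
        for j, i in enumerate(idxs):
            outputs[i] = out[j, 0]  # put results back in the original order
    return outputs


# Usage on CPU in float32 (a flash kernel on CUDA would need fp16/bf16 inputs).
torch.manual_seed(0)
seqs = [torch.rand(n, 8) for n in (1, 2, 2, 3)]
outs = bucketed_self_attention(seqs)
for s, o in zip(seqs, outs):
    x = s[None, None]  # (1, 1, length, n_features)
    ref = F.scaled_dot_product_attention(x, x, x)[0, 0]
    assert torch.allclose(o, ref, atol=1e-6)
```

The same grouping could be applied before calling the encoder layer, passing each group without `src_key_padding_mask`. It only pays off when many sequences share a length; with widely varying lengths it degenerates into many tiny batches.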
Is there another way to run flash SDPA on variable-length sequences?
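Another option is to drop padding altogether: concatenate all tokens into one flat tensor and record where each sequence starts in a cumulative-lengths vector (commonly called `cu_seqlens`). This is the input layout of the variable-length entry point in the separate flash-attn package (`flash_attn_varlen_func`, which takes packed q/k/v of shape (total_tokens, heads, head_dim), int32 `cu_seqlens`, and the maximum sequence length). That call needs a GPU and the extra install, so it is not made here, and its exact arguments should be checked against that package's documentation. The sketch below only builds the packed layout and a slow, pure-PyTorch reference of what such a kernel computes, which could be used to validate a real varlen call; `pack_sequences` and `varlen_attention_reference` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F


def pack_sequences(seqs):
    """Hypothetical helper: concatenate variable-length sequences without padding.

    Returns the packed (total_tokens, n_features) tensor, the int32
    cumulative lengths cu_seqlens of shape (batch + 1,), and max_seqlen.
    """
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)  # stored as int32
    return torch.cat(seqs, dim=0), cu_seqlens, int(lengths.max())


def varlen_attention_reference(q, k, v, cu_seqlens):
    """Hypothetical reference: tokens in [cu_seqlens[i], cu_seqlens[i + 1])
    attend only to each other. q, k, v: (total_tokens, heads, head_dim)."""
    out = torch.empty_like(q)
    bounds = cu_seqlens.tolist()
    for start, end in zip(bounds[:-1], bounds[1:]):
        # SDPA expects (..., seq, head_dim), so move heads in front of tokens
        qs, ks, vs = (t[start:end].transpose(0, 1) for t in (q, k, v))
        out[start:end] = F.scaled_dot_product_attention(qs, ks, vs).transpose(0, 1)
    return out


torch.manual_seed(0)
seqs = [torch.rand(n, 8) for n in (1, 2)]  # same lengths as in the question
packed, cu_seqlens, max_seqlen = pack_sequences(seqs)
print(cu_seqlens.tolist(), max_seqlen)  # [0, 1, 3] 2
x = packed[:, None]  # add a heads dimension: (total_tokens=3, heads=1, 8)
out = varlen_attention_reference(x, x, x, cu_seqlens)
```

No compute is spent on padding tokens in this layout, and no mask is needed because sequence boundaries are carried by `cu_seqlens` instead.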