Flash Attention with variable-length sequences

I would like to use the flash implementation of attention on sequences of variable length. However, I'm not sure how to achieve this.

For example, I attempted to perform self-attention on padded sequences together with the padding mask as follows:

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

n_features = 8
batch_size = 2
lengths = torch.tensor([1, 2])

model = nn.TransformerEncoderLayer(d_model=n_features, nhead=1, batch_first=True, dtype=torch.float16, device='cuda')

# Define data
data = [torch.rand((lengths[i], n_features), dtype=torch.float16, device='cuda') for i in range(batch_size)]
data_padded = pad_sequence(data, batch_first=True)
pad_mask = (torch.arange(max(lengths))[None] >= lengths[:, None]).cuda()

with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True):
    model(data_padded, src_key_padding_mask=pad_mask)

This results in the following error:

RuntimeError: No available kernel. Aborting execution.

The error is preceded by the following warning:

UserWarning: Both fused kernels do not support non-null attn_mask.
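The warning indicates that the padding mask reaches SDPA as an `attn_mask`, which the fused kernels reject. One simple but slow way around this, which is a different technique from the nested tensors mentioned in the answer, is to attend over each sequence on its own so that no padding and no mask exist at all. This is a minimal sketch, assuming q = k = v = x in place of the layer's learned projections; `per_sequence_sdpa` is a hypothetical helper, not a PyTorch API.

```python
import torch
import torch.nn.functional as F


def per_sequence_sdpa(seqs, num_heads):
    """Self-attention over each (L_i, E) tensor in `seqs` separately.

    No padding is ever created, so attn_mask stays None and every call is
    eligible for the fused kernels. Hypothetical helper: it uses q = k = v = x
    instead of learned projections.
    """
    outputs = []
    for x in seqs:
        length, embed_dim = x.shape
        head_dim = embed_dim // num_heads
        # (L, E) -> (1, H, L, D): add a batch dim of 1 and split the heads
        qkv = x.reshape(1, length, num_heads, head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(qkv, qkv, qkv, attn_mask=None)
        # (1, H, L, D) -> (L, E): merge the heads back
        outputs.append(out.transpose(1, 2).reshape(length, embed_dim))
    return outputs
```

On a GPU, the calls can be wrapped in the same `torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False, enable_flash=True)` context as above, with float16 CUDA inputs. The cost is one kernel launch per sequence, so this mainly serves to confirm that the mask is the only obstacle.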

Is there another way to run flash SDPA on variable-length sequences?

This can be done with Nested Tensors, as shown in the SDPA tutorial.
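A minimal sketch of the nested-tensor approach, assuming a PyTorch 2.x release in which `torch.nested.nested_tensor` and `F.scaled_dot_product_attention` accept nested inputs. The helper name `nested_self_attention` is hypothetical, and it skips the learned Q/K/V projections (q = k = v = x) to keep the focus on packing: the sequences are never padded, so no `attn_mask` is passed.

```python
import torch
import torch.nn.functional as F


def nested_self_attention(seqs, num_heads):
    """Self-attention over variable-length sequences, without padding or masks.

    seqs: list of (L_i, E) tensors that share E, dtype and device.
    Returns a list of (L_i, E) tensors, one per input sequence.
    Hypothetical helper: q = k = v = x stands in for learned projections.
    """
    embed_dim = seqs[0].shape[-1]
    head_dim = embed_dim // num_heads
    # Split heads per sequence: (L_i, E) -> (H, L_i, D). Contiguous components
    # are required if SDPA falls back to the math path (e.g. on CPU).
    per_head = [
        x.reshape(x.shape[0], num_heads, head_dim).transpose(0, 1).contiguous()
        for x in seqs
    ]
    # Nested tensor of logical shape (B, H, {L_i}, D): the sequence dim is ragged.
    qkv = torch.nested.nested_tensor(per_head)
    # No attn_mask is needed, because no padding was ever added.
    out = F.scaled_dot_product_attention(qkv, qkv, qkv)
    # Unpack back to one (L_i, E) tensor per sequence.
    return [o.transpose(0, 1).reshape(-1, embed_dim) for o in out.unbind()]
```

To target flash on a GPU, call it with float16 CUDA tensors inside the same flash-only `sdp_kernel` context as in the question, under `torch.no_grad()` (the tutorial notes that the fused kernels do not support nested tensors for training). As far as I know, some releases also refuse to send nested inputs containing a length-1 sequence to the fused kernels, which would affect the `lengths = [1, 2]` example above, so lengths of at least 2 are the safer choice.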