Flash attention with padding mask or nested tensors

Hi everyone,
I’m trying to figure out how to use flash attention for long sequences of variable length during training. Flash attention currently doesn’t support (padding) masks.
People have suggested nested tensors, but with flash attention those seem to work only in evaluation (the backward pass isn’t supported). Another possibility is to manually set the padded key/query/value elements to -inf or 0 to imitate masking. Are there any other options for using flash attention with variable-length sequences?
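For context, here is roughly what the nested tensor attempt looks like. This is a minimal sketch (made-up sizes, assuming PyTorch 2.x on CUDA with the `torch.backends.cuda.sdp_kernel` context manager): the forward pass runs, but backward through nested tensors is what seems to be unsupported.

```python
import torch
import torch.nn.functional as F

num_heads, head_dim = 8, 64
lengths = [5, 7, 2]  # three sequences of different lengths, no padding anywhere

# One (seq_len, num_heads, head_dim) tensor per sequence; dim 0 is the jagged dim.
def make_nested():
    return torch.nested.nested_tensor(
        [torch.randn(L, num_heads, head_dim, dtype=torch.float16) for L in lengths],
        device="cuda",  # the flash backend is CUDA-only and needs fp16/bf16
    )

# scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim),
# so swap the sequence and head dims.
query, key, value = (make_nested().transpose(1, 2) for _ in range(3))

# No attn_mask needed: with nested tensors the padding simply doesn't exist.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(query, key, value)
```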

Were you able to figure it out? I’m also facing the same issue.

Could you share the code snippet where you’re calling scaled_dot_product_attention? Or are you using the default nn.MultiheadAttention?
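I ask because the two entry points take masks differently, so the answer depends on which one you use. A rough sketch with made-up sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, L, H, E = 4, 128, 8, 64

# Functional call: attn_mask broadcasts to (B, H, L, L); passing a non-None
# mask generally routes SDPA away from the flash backend.
q = k = v = torch.randn(B, H, L, E)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# Module call: key_padding_mask is (B, L), where True marks padded positions.
mha = nn.MultiheadAttention(embed_dim=H * E, num_heads=H, batch_first=True)
x = torch.randn(B, L, H * E)
pad = torch.zeros(B, L, dtype=torch.bool)
out2, _ = mha(x, x, x, key_padding_mask=pad)
```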