Hi. Does any backend of SDPA or FlexAttention support both features at once:
- varlen packed inputs/outputs
- custom block-diagonal mask: some blocks full, some blocks causal
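For concreteness, here is a plain-Python sketch of the mask I mean, written as a predicate in the style of FlexAttention's `mask_mod` (the `doc_id` / `is_causal_doc` arrays are hypothetical bookkeeping for a packed batch, not an existing API):

```python
def make_mask_mod(doc_id, is_causal_doc):
    """Return a predicate: may query token q_idx attend to kv token kv_idx?

    doc_id[i]        -- which packed document token i belongs to
    is_causal_doc[d] -- whether document d gets a causal (lower-triangular)
                        block or a full block on the diagonal
    """
    def mask_mod(q_idx, kv_idx):
        # Block-diagonal: tokens only attend within their own document.
        if doc_id[q_idx] != doc_id[kv_idx]:
            return False
        # On-diagonal block: causal docs are lower-triangular, others full.
        if is_causal_doc[doc_id[q_idx]]:
            return kv_idx <= q_idx
        return True
    return mask_mod

# Packed sequence of two documents: doc 0 (len 2, full), doc 1 (len 3, causal).
doc_id = [0, 0, 1, 1, 1]
is_causal_doc = [False, True]
mask = make_mask_mod(doc_id, is_causal_doc)
```

The question is whether any fused backend evaluates a mask like this without materializing it and without computing the all-False off-diagonal blocks.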
I know that SDPA has some support for explicit varlen (possibly broken — see pytorch/pytorch issue #169146, “bad numerics for variable length attention with cudnn”), and maybe also via NJT. And that there is some support for custom block-diagonal masks via `attn_mask` — but does that path actually skip computation of the empty off-diagonal blocks? What I’m not sure about is whether any single SDPA backend supports both features at once.
If not, is the best option manual bindings to FlashAttention’s varlen kernels? Or FlexAttention?
Thanks!