Can MultiheadAttention be optimized for sparse attention masks?

I need to apply MultiheadAttention to batches of sequences whose lengths are very non-uniform, so I end up using a lot of padding plus attention masks. As I understand it, masking tokens out does not decrease the amount of computation being done, nor the amount of data saved for the backward pass. :frowning:
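For reference, here is a minimal sketch of the kind of setup I mean (dimensions and lengths are just placeholders):

```python
import torch
import torch.nn as nn

# Toy setup: a batch of sequences with very different lengths, padded to the max.
embed_dim, num_heads = 64, 4
lengths = [5, 37, 120]                      # very non-uniform lengths
max_len, batch = max(lengths), len(lengths)

x = torch.zeros(batch, max_len, embed_dim)  # padded input (batch_first=True)
for i, L in enumerate(lengths):
    x[i, :L] = torch.randn(L, embed_dim)

# True where a position is padding and should be ignored by attention.
key_padding_mask = torch.arange(max_len)[None, :] >= torch.tensor(lengths)[:, None]

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
out, _ = mha(x, x, x, key_padding_mask=key_padding_mask)

# The padded positions are excluded from the softmax, but the matmuls still
# run over the full (batch, max_len, max_len) attention matrix.
```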

Other than trying to construct batches by grouping examples of similar length together, are there any other tricks that people have come up with to deal with this problem?
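The length-grouping idea I have in mind is roughly this (a sketch, assuming I precompute per-example lengths; `length_grouped_batches` is just a name I made up):

```python
import random

def length_grouped_batches(lengths, batch_size, shuffle=True):
    """Group example indices by similar length so each batch needs less padding."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        random.shuffle(batches)  # shuffle batch order, keep similar lengths together
    return batches

# Each batch holds indices of similar-length examples, so max_len per batch shrinks.
batches = length_grouped_batches([5, 37, 120, 7, 110, 40], batch_size=2)
```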