I need to apply MultiheadAttention to batches of sequences whose lengths are very non-uniform, so I have to use a lot of padding plus attention masks. As I understand it, masking tokens out does not decrease the amount of computation being done, nor the amount of data saved for the backward pass.
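For concreteness, my current setup looks roughly like the sketch below: pad every sequence to the batch maximum and pass a boolean `key_padding_mask` so the padded positions are ignored in the attention weights (the dimensions and tensor names are just illustrative). The padded positions still go through all the matmuls, which is the overhead I'm asking about.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

embed_dim, num_heads = 64, 4
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Three sequences with very different lengths.
seqs = [torch.randn(n, embed_dim) for n in (5, 37, 112)]
lengths = torch.tensor([s.size(0) for s in seqs])

# Pad to the longest sequence in the batch: (batch, max_len, embed_dim).
x = pad_sequence(seqs, batch_first=True)

# True at padded positions: those keys are masked out of the attention
# weights, but they are still materialized and still computed over.
key_padding_mask = torch.arange(x.size(1))[None, :] >= lengths[:, None]

out, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
```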
Other than trying to construct batches by grouping examples of similar length together (roughly the sampler sketched below), are there any other tricks people have come up with to deal with this problem?
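For reference, the length-grouping trick I already know about is something like the following batch sampler (names and details are illustrative, not a polished implementation); I'm hoping there is something beyond this.

```python
import torch
from torch.utils.data import Sampler

class LengthBucketBatchSampler(Sampler):
    """Sort indices by sequence length, then cut them into consecutive
    batches so each batch needs little padding. Batches are shuffled so
    training order isn't strictly shortest-to-longest."""

    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        order = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.shuffle:
            batches = [batches[i] for i in torch.randperm(len(batches))]
        yield from batches

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size

# Usage: pass it as batch_sampler to a DataLoader, e.g.
# loader = DataLoader(dataset, batch_sampler=LengthBucketBatchSampler(lengths, 32))
```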