Optimising bmm with a triangular mask

When building autoregressive transformer models, it’s common to have an attn_mask that’s a simple triangular matrix. The core of the work comes down to this:

import torch

# Masked values are -inf, so adding the mask zeroes out those positions after softmax
attn_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
# (B, sz, d) @ (B, d, sz) -> (B, sz, sz); the (sz, sz) mask broadcasts over the batch
attn = torch.bmm(q, k.transpose(-2, -1)) + attn_mask
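
For reference, with sz = 4 the mask comes out as:

tensor([[0., -inf, -inf, -inf],
        [0.,   0., -inf, -inf],
        [0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.]])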

Is there a better way to optimise this masked bmm for the specific case of a triangular mask? Nearly half of the bmm output (sz*(sz-1)/2 of the sz² entries) is computed only to be masked out. Compared to an arbitrary mask, it seems the triangular structure should allow something more efficient, but I don’t know whether any existing kernels are capable of exploiting it.
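
One related path I’m aware of is torch.nn.functional.scaled_dot_product_attention with is_causal=True, which lets fused backends such as FlashAttention skip the fully masked blocks, though it fuses the whole attention computation rather than exposing just the masked bmm. A minimal sketch (shapes are hypothetical) checking that it matches the explicit mask:

import math
import torch
import torch.nn.functional as F

# Hypothetical shapes, just for illustration
B, sz, d = 8, 128, 64
q, k, v = (torch.randn(B, sz, d) for _ in range(3))

# Explicit route: materialise the full (sz, sz) score matrix, then mask
attn_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
scores = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(d) + attn_mask
out_explicit = torch.softmax(scores, dim=-1) @ v

# Fused route: is_causal=True declares the triangular structure up front,
# so the backend is free to avoid computing the masked half
out_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.testing.assert_close(out_explicit, out_fused, rtol=1e-4, atol=1e-4)

Even with that, the question stands for the raw masked bmm itself, e.g. when the scores are needed on their own without the softmax and value multiply.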