When building autoregressive transformer models, it’s common to use an attn_mask
that’s a simple upper-triangular matrix. The core work comes down to this:
import torch

# Masked positions are -inf; adding them to the scores sends those
# entries to zero after the softmax
attn_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
attn = torch.bmm(q, k.transpose(-2, -1)) + attn_mask
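For context, here is a self-contained version of the snippet above (batch size, sequence length, and head dimension chosen arbitrarily), with a check that the additive -inf mask really does zero out the upper triangle after softmax:

```python
import torch

# Toy sizes (arbitrary): batch of 2, sequence length 4, head dim 8
b, sz, d = 2, 4, 8
q = torch.randn(b, sz, d)
k = torch.randn(b, sz, d)

# Upper-triangular -inf mask: position i may not attend to any j > i
attn_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
attn = torch.bmm(q, k.transpose(-2, -1)) + attn_mask

# exp(-inf) == 0, so the masked entries vanish after softmax
probs = attn.softmax(dim=-1)
assert torch.all(probs.triu(diagonal=1) == 0)
```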
Is there a better way to optimise this masked bmm for the specific case of a
triangular mask? Nearly half of the bmm output is computed and then immediately
discarded. Compared to an arbitrary mask, it seems there should be something
more optimal that exploits the structure, but I don’t know whether any existing
kernels are capable of this.
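To make the “nearly half” claim concrete: for an sz × sz score matrix, the causal mask discards sz*(sz-1)/2 of the sz*sz entries, a fraction that approaches 1/2 as sz grows:

```python
# Fraction of score entries masked out by a causal sz x sz mask
sz = 1024
masked = sz * (sz - 1) // 2
print(masked / (sz * sz))  # 0.49951171875
```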