When building autoregressive transformer models, it’s common to have an
attn_mask that’s a simple triangular matrix. The core work comes down to this:
```python
# Masked values are -inf and we add them on
attn_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
attn = torch.bmm(q, k.transpose(-2, -1)) + attn_mask
```
Is there a better way to optimise this masked bmm for the specific case of a
triangular mask? Nearly half of the bmm output is masked out and therefore
wasted. Compared with an arbitrary mask, it seems the triangular structure
should allow something more optimal, but I don't know whether any existing
kernels can exploit it.
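For reference, here is a minimal self-contained repro of the pattern above (batch, sequence, and head sizes are arbitrary placeholders), showing that every strictly upper-triangular score ends up contributing zero after the softmax:

```python
import torch

# Arbitrary example sizes (assumptions, not from the original snippet).
batch, sz, d = 2, 8, 16
q = torch.randn(batch, sz, d)
k = torch.randn(batch, sz, d)

# Causal mask: -inf above the diagonal, 0 on and below it.
attn_mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
attn = torch.bmm(q, k.transpose(-2, -1)) + attn_mask

# After softmax, the -inf entries become exactly 0, so the bmm work that
# produced those sz*(sz-1)/2 scores per batch element is discarded.
weights = torch.softmax(attn, dim=-1)
assert torch.all(weights.triu(diagonal=1) == 0)
```

So roughly half the score matrix is computed only to be zeroed out, which is the wasted work the question is about.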