Memory parity with JAX attention

Our attention module is essentially standard multi-head attention, just with arbitrarily many bias tensors. It’s approximately as follows:

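import torch

a = torch.matmul(q, k.transpose(-1, -2))  # [batch_dim, heads, Q, K] logits

for b in biases:
    a += b  # each bias broadcasts against a

a = torch.softmax(a, dim=-1)

o = torch.matmul(a, v)  # [batch_dim, heads, Q, C]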

where q, k, and v are the query, key, and value tensors of shape [batch_dim, heads, Q, C], [batch_dim, heads, K, C], and [batch_dim, heads, K, C], respectively, and biases is a list of up to 2 bias tensors (e.g. a mask). Each bias broadcasts along at least one dimension, so its size is insignificant; in any case it is much smaller than the full [batch_dim, heads, Q, K] attention logit matrix a.

During the backward pass, since batch_dim, Q, and K are >> C, memory consumption is dominated by the first matrix multiplication and the softmax, which together contribute 2 * batch_dim * heads * Q * K * 2 bytes to the module total (two full logit-sized tensors at 2 bytes per element). Equivalent JAX code requires just batch_dim * heads * Q * K * 2 bytes, half as much.
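To put numbers on the two estimates, here is a small sketch that evaluates them for concrete shapes. The shapes used (batch_dim=1, heads=8, Q=K=4096) are purely illustrative, and the helper names are made up for this example; 2 bytes per element corresponds to fp16/bf16.

```python
def logit_matrix_bytes(batch_dim: int, heads: int, q_len: int, k_len: int,
                       bytes_per_elem: int = 2) -> int:
    """Size in bytes of one [batch_dim, heads, Q, K] attention-logit tensor."""
    return batch_dim * heads * q_len * k_len * bytes_per_elem


def estimated_backward_bytes(batch_dim: int, heads: int, q_len: int, k_len: int,
                             copies: int, bytes_per_elem: int = 2) -> int:
    """Dominant backward-pass footprint: `copies` full-size logit tensors."""
    return copies * logit_matrix_bytes(batch_dim, heads, q_len, k_len, bytes_per_elem)


# Illustrative shapes only.
shape = dict(batch_dim=1, heads=8, q_len=4096, k_len=4096)
torch_bytes = estimated_backward_bytes(copies=2, **shape)  # matmul + softmax, as described
jax_bytes = estimated_backward_bytes(copies=1, **shape)    # the JAX figure

MiB = 1024 ** 2
print(f"torch: {torch_bytes / MiB:.0f} MiB, JAX: {jax_bytes / MiB:.0f} MiB")
# prints "torch: 512 MiB, JAX: 256 MiB"
```

Because the footprint scales with Q * K, doubling the sequence length quadruples both numbers, so the 2x gap matters more as sequences grow.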

PyTorch doesn’t have anything like an in-place softmax, and for various reasons I’ve been unable to shrink the memory requirements either with KeOps (apparently incompatible with the bias tensors) or by building an in-place softmax myself (which breaks autograd).
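One way to keep an in-place softmax differentiable, sketched here as an untested assumption rather than a known fix, is to wrap it in a custom torch.autograd.Function: the forward overwrites the logits buffer and declares that with ctx.mark_dirty, and the backward uses the standard identity that the softmax gradient depends only on the output, dx = y * (dy - sum(dy * y)). The names InplaceSoftmax and attention are hypothetical, and I have not verified that this reaches the JAX footprint.

```python
import torch


class InplaceSoftmax(torch.autograd.Function):
    """Hypothetical softmax over the last dim that overwrites its input.

    Only the softmax output is saved for backward, so the logits and the
    probabilities share a single [batch_dim, heads, Q, K] buffer.
    """

    @staticmethod
    def forward(ctx, logits):
        # Numerically stable softmax computed in place on `logits`;
        # only the small [..., Q, 1] max and sum tensors are extra allocations.
        logits.sub_(logits.amax(dim=-1, keepdim=True))
        logits.exp_()
        logits.div_(logits.sum(dim=-1, keepdim=True))
        ctx.mark_dirty(logits)         # tell autograd the input was modified in place
        ctx.save_for_backward(logits)  # this is now the softmax output y
        return logits

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # Softmax backward from the output alone:
        # dL/dx = y * (dL/dy - sum_k(dL/dy * y))
        dot = (grad_out * y).sum(dim=-1, keepdim=True)
        return y * (grad_out - dot)


def attention(q, k, v, biases):
    """Biased multi-head attention using the hypothetical InplaceSoftmax."""
    a = torch.matmul(q, k.transpose(-1, -2))  # [batch_dim, heads, Q, K]
    for b in biases:
        a += b  # in-place, broadcasting bias add
    a = InplaceSoftmax.apply(a)  # softmax written into the logits buffer
    return torch.matmul(a, v)  # [batch_dim, heads, Q, C]
```

The backward still allocates full-size temporaries (e.g. grad_out * y), so any saving would need to be measured, for instance with torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() around one forward/backward step.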

Is there a (batch_dim * heads * Q * K * (2 bytes)) solution in torch?