Our attention module is essentially standard multi-head attention, just with arbitrarily many bias tensors. It’s approximately as follows:
a = torch.matmul(q, k.transpose(-1, -2))
for b in biases:
    a += b
a = torch.softmax(a, dim=-1)
o = torch.matmul(a, v)
where q, k, and v are [batch_dim, heads, Q/K/V, C] query, key, and value tensors, respectively, and biases is a list of up to 2 bias tensors (e.g. a mask), all of which broadcast along at least one dimension such that their sizes are insignificant (they're much smaller than the full [batch_dim, heads, Q, K] attention logit matrix a, in any case). During the backwards pass, since batch_dim, Q, and K are >> C, memory consumption is dominated by the first matrix multiplication and the softmax, which together contribute (2 * batch_dim * heads * Q * K * (2 bytes)) to the module total. Equivalent JAX code requires just (batch_dim * heads * Q * K * (2 bytes)), half as much.
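For concreteness, here's a minimal runnable sketch of the module (the shapes and the single mask bias are illustrative stand-ins, not the real ones):

```python
import torch

batch_dim, heads, Q, K, C = 2, 4, 64, 64, 16

q = torch.randn(batch_dim, heads, Q, C)
k = torch.randn(batch_dim, heads, K, C)
v = torch.randn(batch_dim, heads, K, C)

# e.g. a mask that broadcasts over the heads and Q dimensions
mask = torch.zeros(batch_dim, 1, 1, K)
biases = [mask]

a = torch.matmul(q, k.transpose(-1, -2))  # [batch_dim, heads, Q, K] logits
for b in biases:
    a = a + b
a = torch.softmax(a, dim=-1)
o = torch.matmul(a, v)                    # [batch_dim, heads, Q, C]

# both the logits and the softmax output have batch_dim * heads * Q * K
# elements, which is what dominates the backward-pass memory
print(o.shape)  # torch.Size([2, 4, 64, 16])
```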
torch doesn’t have anything like an in-place softmax, and I’ve been unable to shrink the memory requirements either with KeOps (apparently incompatible with the bias tensors) or by building an in-place softmax manually (it breaks auto-differentiation).
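To illustrate that last failure mode, a softmax composed of in-place ops breaks along these lines: the in-place division overwrites the exp result that autograd saved for the backward pass (a minimal sketch of the problem, not a full attention module):

```python
import torch

a = torch.randn(8, 8, requires_grad=True)

# softmax built from in-place ops, reusing one tensor's storage throughout
x = a - a.amax(dim=-1, keepdim=True)   # out-of-place; x is a non-leaf tensor
x.exp_()                               # autograd saves this result for exp's backward...
x.div_(x.sum(dim=-1, keepdim=True))    # ...which this then overwrites in place

try:
    x.sum().backward()
    err = None
except RuntimeError as e:
    err = e
print(err)  # RuntimeError about a variable modified by an inplace operation
```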
Is there a (batch_dim * heads * Q * K * (2 bytes)) solution in torch?