Our attention module is essentially standard multi-head attention, just with arbitrarily many bias tensors. It’s approximately as follows:

```
a = torch.matmul(q, k.transpose(-1, -2))  # [batch_dim, heads, Q, K] logits
for b in biases:
    a += b
a = torch.softmax(a, dim=-1)
o = torch.matmul(a, v)  # [batch_dim, heads, Q, C]
```

where `q`, `k`, and `v` are `[batch_dim, heads, Q/K/V, C]` query, key, and value tensors, respectively, and `biases` is a list of up to 2 bias tensors (e.g. a mask), all of which broadcast along at least one dimension such that their sizes are insignificant (they're much smaller than the full `[batch_dim, heads, Q, K]` attention logit matrix `a`, in any case). During the backwards pass, since `batch_dim`, Q, and K are >> C, memory consumption is dominated by the first matrix multiplication and the softmax, which together contribute (**2** * `batch_dim` * `heads` * Q * K * (2 bytes)) to the module total. Equivalent JAX code requires just (`batch_dim` * `heads` * Q * K * (2 bytes)), half as much.
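To make that accounting concrete, here's a quick back-of-the-envelope script (the sizes below are placeholder values I made up, not our real shapes):

```python
# Hypothetical sizes, purely to illustrate the accounting.
batch_dim, heads, Q, K = 1, 8, 4096, 4096
bytes_per_elem = 2  # fp16 / bf16

# torch keeps both the pre-softmax logits (for the matmul backward)
# and the post-softmax probabilities (for the softmax backward):
torch_bytes = 2 * batch_dim * heads * Q * K * bytes_per_elem
# the fused XLA version keeps a single [batch_dim, heads, Q, K] copy:
jax_bytes = batch_dim * heads * Q * K * bytes_per_elem

print(torch_bytes / 2**30, "GiB vs", jax_bytes / 2**30, "GiB")
```

At these toy sizes that's already 0.5 GiB vs 0.25 GiB for a single attention module.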

torch doesn't have anything like an in-place softmax, and for various reasons I've been unable to shrink the memory footprint using KeOps (apparently incompatible with the bias tensors) or by manually building an in-place softmax myself (it breaks auto-differentiation).
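For what it's worth, here's a minimal repro of the auto-differentiation breakage (toy shapes, not the real module): mutating the softmax intermediates in place invalidates a tensor that autograd saved for the backward pass.

```python
import torch

# Toy shapes, not the real module's.
a = torch.randn(2, 4, 16, 16, requires_grad=True).exp()  # exp saves its output for backward
a.div_(a.sum(dim=-1, keepdim=True))  # in-place normalize: a softmax without the extra copy

try:
    a.sum().backward()
    broke_autograd = False
except RuntimeError as e:  # "... has been modified by an inplace operation"
    broke_autograd = True
    print(e)
```

`exp`'s backward reuses its own output, so dividing that output in place bumps its version counter and the backward pass refuses to run.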

Is there a (`batch_dim` * `heads` * Q * K * (2 bytes)) solution in torch?