Let’s say I have some composition of skinny matrix multiplications and pointwise operations:
```python
import torch

A = torch.rand(1024**2, 5)
B = torch.rand(5, 256**2)
C = torch.rand(256**2, 2)
result = torch.exp(A @ B) @ C
```
Computed naively, the intermediate matrix A @ B has 1024**2 × 256**2 = 2**36 entries (256 GiB in float32). Memory use can be reduced dramatically by computing the required matrix-vector products on the fly instead:
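```python
result = torch.zeros(1024**2, 2)
for i in range(256**2):
    # Column i of exp(A @ B), times row i of C, is this column's contribution to the result.
    result += torch.outer(torch.exp(A @ B[:, i]), C[i, :])
```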
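The loop is exact, not an approximation. exp is applied elementwise, so column i of exp(A @ B) is exp(A @ B[:, i]). Any matrix product X @ C equals the sum over i of the outer products of column i of X with row i of C. The sketch below checks this numerically. It assumes small stand-in shapes (N, M) in place of (1024**2, 256**2) so that the naive version fits in memory, and it uses float64 to keep rounding noise out of the comparison.

```python
import torch

torch.manual_seed(0)

# Small stand-ins for the question's shapes (assumption: any N, M show the identity).
N, K, M, P = 1000, 5, 300, 2
A = torch.rand(N, K, dtype=torch.float64)
B = torch.rand(K, M, dtype=torch.float64)
C = torch.rand(M, P, dtype=torch.float64)

# Naive version: materializes the full (N, M) intermediate exp(A @ B).
naive = torch.exp(A @ B) @ C

# Column-by-column version: only one length-N column is alive at a time.
looped = torch.zeros(N, P, dtype=torch.float64)
for i in range(M):
    looped += torch.outer(torch.exp(A @ B[:, i]), C[i])

print(torch.allclose(naive, looped))
```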
Does anyone know if there is a generic way to achieve this kind of memory optimization with skinny matrix multiplications in PyTorch, perhaps either by reformulating the computation, or by using an optimizing compiler / other external tool?
You could try using vmap from functorch (which comes packaged with the latest version of PyTorch). Its experimental namespace has a chunked version of vmap, which computes mini-batches (chunks) of the mapped calculation in parallel (see the vmap example in the pytorch/functorch GitHub repository, "functorch is JAX-like composable function transforms for PyTorch"). Processing one chunk of columns at a time, with the chunk size chosen to fit your memory constraints, should speed up the for loop above.
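A minimal sketch of the chunked-vmap suggestion follows. It assumes PyTorch 2.x, where vmap is exposed as `torch.vmap` and accepts a `chunk_size` keyword, rather than the older functorch experimental namespace mentioned above. The helper `column_fn` and the small shapes are illustrative only. Note that `chunk_size` only limits how many columns are vectorized at once: the stacked output `expAB` is still the full (N, M) matrix. This is the limitation raised in the replies below.

```python
import torch

torch.manual_seed(0)

# Small illustrative shapes (assumption), float64 for a clean comparison.
N, K, M, P = 1000, 5, 300, 2
A = torch.rand(N, K, dtype=torch.float64)
B = torch.rand(K, M, dtype=torch.float64)
C = torch.rand(M, P, dtype=torch.float64)

def column_fn(b):
    # b: one column of B, shape (K,). Returns one column of exp(A @ B), shape (N,).
    return torch.exp(A @ b)

# Map over the columns of B (in_dims=1) and stack the results as columns (out_dims=1).
# chunk_size=64 bounds how many columns are processed per vectorized call,
# but all chunks are concatenated into one (N, M) tensor afterwards.
expAB = torch.vmap(column_fn, in_dims=1, out_dims=1, chunk_size=64)(B)
result = expAB @ C

print(expAB.shape)  # torch.Size([1000, 300])
```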
Unfortunately, vmap alone doesn't seem to solve the memory problem, since the memory savings come from the sequential reduction in the for loop. I think I would need a (chunked) map-reduce operation to achieve what I want. It seems that JAX can perform the required optimization through JIT compilation (see "Reduce functionality for vmap", Discussion #9505 in google/jax on GitHub), but if I understand correctly, functorch is not yet supported by TorchScript.
The problem is that even the chunked version of vmap needs a large buffer to hold the result of A @ B (1024**2 × 256**2 in my example), since the entire output of vmap is kept in memory. Instead, I need to map a function over one chunk of data, reduce the intermediate result into an accumulator, and then move on to the next chunk.
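The chunked map-reduce described here can be written by hand in plain PyTorch. The sketch below uses a hypothetical helper, `chunked_exp_matmul` (not a PyTorch API). It maps `exp(A @ B_block)` over a block of columns of B, then reduces by multiplying with the matching rows of C and accumulating in place with `Tensor.addmm_`. Peak extra memory is one (N, chunk_size) block instead of the full (N, M) matrix. The shapes and `chunk_size` values in the usage lines are arbitrary assumptions.

```python
import torch

def chunked_exp_matmul(A, B, C, chunk_size=1024):
    """Hypothetical helper: compute torch.exp(A @ B) @ C block by block.

    Map:    block = exp(A @ B[:, start:stop])   -> shape (N, stop - start)
    Reduce: out += block @ C[start:stop]        -> shape (N, P), accumulated in place
    Only one (N, chunk_size) block is alive at a time.
    """
    N, M, P = A.shape[0], B.shape[1], C.shape[1]
    out = torch.zeros(N, P, dtype=A.dtype, device=A.device)
    for start in range(0, M, chunk_size):
        stop = min(start + chunk_size, M)
        block = torch.exp(A @ B[:, start:stop])
        out.addmm_(block, C[start:stop])  # out = out + block @ C[start:stop]
    return out

# Usage on small illustrative shapes, checked against the naive version.
torch.manual_seed(0)
A = torch.rand(1000, 5, dtype=torch.float64)
B = torch.rand(5, 300, dtype=torch.float64)
C = torch.rand(300, 2, dtype=torch.float64)
with torch.no_grad():
    result = chunked_exp_matmul(A, B, C, chunk_size=64)  # 300 is not a multiple of 64
print(torch.allclose(result, torch.exp(A @ B) @ C))
```

The savings apply to the forward computation, for example under `torch.no_grad()` as above. If the inputs require gradients, autograd keeps each block alive for the backward pass, so the full intermediate would again be held in memory unless something like activation checkpointing is added.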