Memory optimization for compositions of skinny matrix multiplications with pointwise operations

Let’s say I have some composition of skinny matrix multiplications and pointwise operations:

import torch

A = torch.rand(1024**2, 5)     # (1048576, 5)  tall and skinny
B = torch.rand(5, 256**2)      # (5, 65536)    short and wide
C = torch.rand(256**2, 2)      # (65536, 2)    tall and skinny
result = torch.exp(A @ B) @ C  # A @ B materializes a (1048576, 65536) matrix

When calculated naively, the intermediate matrix A @ B is huge (1024**2 x 256**2 entries, i.e. about 256 GiB in float32), and memory utilization can be dramatically reduced by instead calculating the required matrix-vector products on the fly:

result = torch.zeros(1024**2, 2)
for i in range(256**2):
    # Accumulate one rank-1 (outer-product) term per column of B / row of C
    result += torch.outer(torch.exp(A @ B[:, i]), C[i, :])

Does anyone know of a generic way to achieve this kind of memory-targeted optimization for skinny matrix multiplications in PyTorch, either by reformulating the computation or by using an optimizing compiler or other external tool?

Many thanks for any help.

Hi @askorikov,

You could try vmap from functorch (which comes packaged with the latest version of PyTorch). There is also a 'chunked' version of vmap in the experimental namespace, which splits the batch into mini-batches that are computed in parallel (vmap example here: GitHub - pytorch/functorch: functorch is JAX-like composable function transforms for PyTorch.). Using the chunked version of vmap lets you compute mini-batches in parallel (subject to memory constraints), which should speed up the for-loop example above.
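
For instance, your loop could be expressed with vmap along these lines (an untested sketch; torch.outer forms the per-column outer product that your loop accumulates):

import torch
from functorch import vmap

A = torch.rand(1024**2, 5)
B = torch.rand(5, 256**2)
C = torch.rand(256**2, 2)

def term(b_col, c_row):
    # One rank-1 term of the sum: b_col is (5,), c_row is (2,)
    return torch.outer(torch.exp(A @ b_col), c_row)

# Map over the 256**2 columns of B and rows of C in one vectorized call,
# then reduce. Note that the stacked output has shape (256**2, 1024**2, 2).
result = vmap(term, in_dims=(1, 0))(B, C).sum(dim=0)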

Thanks for your suggestion @AlphaBetaGamma96!

Unfortunately, vmap alone doesn't seem to solve the memory problem, since the main memory saving comes from the sequential reduction in the for loop. I guess I would need a (chunked) map-reduce operation to achieve what I want. It seems that JAX can do the required optimization via JIT compilation (Reduce functionality for vmap · Discussion #9505 · google/jax · GitHub), but if I understand correctly, functorch is not supported by TorchScript yet.

Hi @askorikov,

Functorch does have a chunked vmap; it isn't in the docs (I shared the standard vmap docs above), but it's there. Instead of functorch.vmap, use functorch.experimental.chunk_vmap. More info in this issue: vmap, jacrev should have an option to vectorize "chunks" instead of the entire batch at a time · Issue #680 · pytorch/functorch · GitHub
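
With the same term function as in my earlier sketch, the chunked call would look roughly like this (untested; I'm assuming chunk_vmap takes vmap's in_dims plus a chunks argument, as described in the issue, and the chunk count here is just illustrative):

from functorch.experimental import chunk_vmap

# Split the mapped dimension into 64 mini-batches instead of one big batch
result = chunk_vmap(term, in_dims=(1, 0), chunks=64)(B, C).sum(dim=0)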

The problem is that even the chunked version of vmap needs a large buffer to store the result of A @ B (1024**2 x 256**2 in my example), since the entire output of vmap is kept in memory. Instead, I need to map a function over one chunk of data, reduce the intermediate result, and only then move on to the next chunk.
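
For reference, this hand-rolled chunked map-reduce is the pattern I have in mind (the chunk size is just an illustrative value, to be tuned to available memory):

import torch

A = torch.rand(1024**2, 5)
B = torch.rand(5, 256**2)
C = torch.rand(256**2, 2)

chunk = 256  # columns of B / rows of C processed per step
result = torch.zeros(A.shape[0], C.shape[1])
for start in range(0, B.shape[1], chunk):
    stop = start + chunk
    # Materialize only a (1024**2, chunk) slice of exp(A @ B),
    # reduce it against the matching rows of C, then discard it
    result += torch.exp(A @ B[:, start:stop]) @ C[start:stop, :]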