# Memory optimization for compositions of skinny matrix multiplications with pointwise operations

Let’s say I have some composition of skinny matrix multiplications and pointwise operations:

``````import torch

A = torch.rand(1024**2, 5)
B = torch.rand(5, 256**2)
C = torch.rand(256**2, 2)
result = torch.exp(A @ B) @ C
``````

When calculated naively, the intermediate matrix `A @ B` is huge (`1024**2 x 256**2` entries), and memory utilization can be dramatically reduced by instead calculating the required matrix-vector products on the fly and accumulating the result:

``````result = torch.zeros(1024**2, 2)
for i in range(256**2):
    # outer product of the (1024**2,) column result with the (2,) row of C
    result += torch.outer(torch.exp(A @ B[:, i]), C[i, :])
``````
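The same reduction can also be done over chunks of columns rather than single columns, which keeps memory bounded while still using efficient matrix-matrix products. A minimal sketch, with smaller sizes than the example above for illustration (the chunk size is an arbitrary knob to tune against available memory):

```python
import torch

# smaller sizes than the original example, for illustration
A = torch.rand(4096, 5)
B = torch.rand(5, 1024)
C = torch.rand(1024, 2)

chunk = 128  # tunable: trades peak memory against matmul efficiency
result = torch.zeros(A.shape[0], C.shape[1])
for start in range(0, B.shape[1], chunk):
    stop = start + chunk
    # only a (4096, chunk)-sized slice of exp(A @ B) is ever materialized
    result += torch.exp(A @ B[:, start:stop]) @ C[start:stop, :]
```

This still materializes one slice of `exp(A @ B)` at a time, so peak memory scales with the chunk size rather than with the full product.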

Does anyone know if there is a generic way to achieve this kind of memory-targeting optimization with skinny matrix multiplications in PyTorch, perhaps either by reformulating the computation, or by using an optimizing compiler / other external tool?

Many thanks for any help.

You could try using `vmap` from functorch (which comes packaged with the latest version of PyTorch). In the experimental namespace, there’s a ‘chunked’ version of `vmap`, which allows mini-batches of the calculation to be computed in parallel (vmap example here: GitHub - pytorch/functorch: functorch is JAX-like composable function transforms for PyTorch.). Using the chunked version of vmap will let you compute mini-batches (subject to memory constraints) in parallel, which should speed up the for-loop example above.
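For reference, newer PyTorch releases expose this directly as a `chunk_size` argument on `torch.vmap`, so no separate functorch install is needed. A sketch of the suggestion above (smaller sizes, and `chunk_size=128` is an arbitrary choice):

```python
import torch

A = torch.rand(4096, 5)
B = torch.rand(5, 1024)
C = torch.rand(1024, 2)

# map over columns of B and rows of C; chunk_size bounds how many
# per-column results are computed at once (assumes a PyTorch version
# where torch.vmap accepts chunk_size)
per_column = torch.vmap(
    lambda b, c: torch.exp(A @ b).unsqueeze(-1) * c,
    in_dims=(1, 0),
    chunk_size=128,
)(B, C)                      # shape (1024, 4096, 2)
result = per_column.sum(dim=0)
```

Note that the stacked output `per_column` still holds the full `(num_columns, rows, 2)` array in memory, which is the limitation discussed below.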
Unfortunately, just `vmap` doesn’t seem to solve the memory problem, since the main memory optimization comes from the sequential reduction in the for loop. I guess I would need a (chunked) map-reduce operation to achieve what I want. It seems that JAX can do the required optimization using JIT compilation (Reduce functionality for vmap · Discussion #9505 · google/jax · GitHub), but if I understood correctly, `functorch` is not supported by TorchScript yet.
Functorch does have a chunked vmap; it’s not in the docs (I shared the standard vmap docs above), but it’s in there. Instead of `functorch.vmap`, you use `functorch.experimental.chunk_vmap`. More info in this issue: vmap, jacrev should have an option to vectorize "chunks" instead of the entire batch at a time · Issue #680 · pytorch/functorch · GitHub
The problem is that even the chunked version of `vmap` needs a large buffer to store the result of `A @ B` (`1024**2 x 256**2` in my example), since the entire output of `vmap` is stored in memory. Instead, I need to map a function to a chunk of data, reduce the intermediate result, and then go to the next chunk.
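In the meantime, the map-reduce can be hand-rolled as a small helper (the name `chunked_map_reduce` is my own, not a library function); each chunk's intermediate is reduced into the accumulator before the next chunk is computed:

```python
import torch

def chunked_map_reduce(f, n, chunk, init):
    """Apply f to index ranges [start, stop) covering 0..n and sum the
    results sequentially, so only one chunk's output is resident at a time."""
    acc = init
    for start in range(0, n, chunk):
        acc = acc + f(start, min(start + chunk, n))
    return acc

# reduced sizes for illustration
A = torch.rand(4096, 5)
B = torch.rand(5, 1024)
C = torch.rand(1024, 2)

result = chunked_map_reduce(
    lambda s, e: torch.exp(A @ B[:, s:e]) @ C[s:e, :],
    n=B.shape[1],
    chunk=128,
    init=torch.zeros(A.shape[0], C.shape[1]),
)
```

This gives the bounded-memory behavior of the for loop while still using chunk-sized matmuls, at the cost of the map step no longer being fully parallel across chunks.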