CUDA Memory Use in Batched Matrix Multiplication

I’m trying to perform a batched matrix multiplication on a GPU and I get a CUDA out of memory error that doesn’t make a lot of sense to me. I’m hoping someone can shed some light on it.

Specifically, I have a tensor A of size [128, 1024, 1, 1, 27] and another tensor B of size [1, 1, 160, 27, 27].

When I try to multiply them, I get a CUDA out of memory error that says “Tried to allocate 56.95 GiB”.
I don’t understand why it would require so much space. It can’t just be the size of the output tensor, because that would be [128, 1024, 160, 1, 27], which is on the order of 2 GB.
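For reference, here is a minimal repro of the failing call (assuming float32 tensors on a single GPU):

import torch

A = torch.randn(128, 1024, 1, 1, 27, device="cuda")
B = torch.randn(1, 1, 160, 27, 27, device="cuda")

# Broadcasts the batch dims to [128, 1024, 160] and multiplies
# the trailing (1, 27) x (27, 27) matrices.
out = torch.matmul(A, B)  # expected shape: [128, 1024, 160, 1, 27]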

My main question is whether there’s any method I can use to compute this batched matrix multiply memory-efficiently. Any advice is greatly appreciated.

Is this the only operation you’re doing, or is it part of several steps?

Try the following to compute the actual memory footprint (in bytes) of your tensors:
a.element_size() * a.nelement()

Since the tensors are themselves 5D, I’m guessing a significant amount of memory is also spent storing intermediate values.

Hi, thanks for the quick response!

To answer your question: the same error occurs when the operation is done in isolation, with no gradient tracking.

The memory footprint of A is 0.01416 GB,
B is 0.00047 GB,
and a tensor with the output shape has a memory footprint of 2.2649 GB.
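Here’s how I got those numbers, using your element_size() suggestion (dividing by 1e9, so the figures are in GB):

import torch

def footprint_gb(t):
    # bytes per element times number of elements, in GB
    return t.element_size() * t.nelement() / 1e9

A = torch.empty(128, 1024, 1, 1, 27)       # float32
B = torch.empty(1, 1, 160, 27, 27)
out = torch.empty(128, 1024, 160, 1, 27)

print(footprint_gb(A))    # ~0.01416
print(footprint_gb(B))    # ~0.00047
print(footprint_gb(out))  # ~2.2649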

I can understand that some extra memory is used for intermediate values during the computation, but I can’t imagine why an extra ~54 GB would be needed.

Thanks again for your response, I hope this is helpful context.

Unfortunately, when doing a batched matrix multiply where the batch dimensions cannot be flattened (as in your case), the full matrices are materialized. In your case, B will be expanded to a (128, 1024, 160, 27, 27) tensor, which requires ~61 GB of memory — the 56.95 GiB in your error. This is a limitation of the underlying cuBLAS function we use, which assumes batch elements sit at regular strides in memory.
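You can verify this against your error message with a quick back-of-the-envelope check (assuming 4-byte float32 elements):

# B broadcast to the full batch shape (128, 1024, 160, 27, 27)
elements = 128 * 1024 * 160 * 27 * 27
print(elements * 4 / 2**30)  # ~56.95 GiB, exactly the failed allocation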

Thanks for the explanation! Am I right in thinking that the only way to solve this on my end is to use a for loop or avoid the multiplication entirely?

You can try representing your op as (A * B.transpose(-1, -2)).sum(-1) and then use KeOps (“Kernel Operations on the GPU, with autodiff, without memory overflows”) to avoid materializing the large intermediate tensor. In future versions of PyTorch you will also be able to codegen a kernel for this operation without materializing the large intermediate, but in the current version your easiest options are a loop or KeOps.
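If you go the loop route, here is one possible sketch (my own example, not an official API) that iterates over the first batch dimension so only a small slice of the expanded B exists at a time:

import torch

def looped_matmul(A, B):
    # A: [128, 1024, 1, 1, 27], B: [1, 1, 160, 27, 27]
    out = torch.empty(A.shape[0], A.shape[1], B.shape[2], 1, B.shape[-1],
                      device=A.device, dtype=A.dtype)
    for i in range(A.shape[0]):
        # Each step only materializes a [1024, 160, 27, 27] copy of B
        # (~0.5 GB) instead of the full [128, 1024, 160, 27, 27] tensor.
        out[i] = A[i] @ B[0]
    return out

The peak intermediate is roughly 1/128th of the full ~57 GiB, at the cost of 128 smaller kernel launches.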

I’ll take a look at that library. Thanks a lot for your help.