I’m trying to perform a batched matrix multiplication on a GPU and I get a CUDA out of memory error that doesn’t make a lot of sense to me. I’m hoping someone can shed some light on it.
Specifically, I have a tensor A of size [128, 1024, 1, 1, 27] and another tensor B of size [1, 1, 160, 27, 27].
When I try to multiply them, I get a CUDA out of memory error that says “Tried to allocate 56.95 GiB”.
I don’t understand why it would require so much space. It’s not just the size of the output tensor, because that would be [128, 1024, 160, 1, 27], which is on the order of 2 GiB.
My main question is whether there’s any method I can use to compute this batched matrix multiply efficiently. Any advice is greatly appreciated.
Unfortunately, when doing a batched matrix multiply where the batch dimensions cannot be flattened (as in your case), the full matrices are materialized. In your case, B is expanded to a (128, 1024, 160, 27, 27) tensor, which requires ~61 GB (56.95 GiB assuming float32, which is the allocation in your error). This is a limitation of the underlying cuBLAS function we are using, which assumes batch elements are at regular positions in memory.
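The sizes quoted in this thread can be checked with plain arithmetic. The sketch below assumes float32 tensors (4 bytes per element), which is what makes the numbers match the 56.95 GiB in the error message; the helper name `gib` is just for illustration.

```python
from math import prod

BYTES_PER_ELEMENT = 4  # assumption: float32 tensors


def gib(shape):
    """Size in GiB of a dense tensor with the given shape."""
    return prod(shape) * BYTES_PER_ELEMENT / 2**30


output_shape = (128, 1024, 160, 1, 27)       # result of A @ B
expanded_b_shape = (128, 1024, 160, 27, 27)  # B broadcast to the full batch shape

print(f"output:     {gib(output_shape):.2f} GiB")      # 2.11 GiB
print(f"expanded B: {gib(expanded_b_shape):.2f} GiB")  # 56.95 GiB
```

The expanded copy of B is 27 times larger than the output, because every [1, 27] output row gets its own private copy of a full 27 x 27 matrix.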
Thanks for the explanation! Am I right in thinking that the only way to solve this on my end is to use a for loop or avoid the multiplication entirely?
You can try representing your op as (A * B.transpose(-1, -2)).sum(-1) and then use KeOps to avoid materializing the large intermediate tensor. In future versions of PyTorch you will also be able to codegen a kernel for this operation without materializing the large intermediate, but in the current version your easiest option is a loop (or KeOps).
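Here is a small-scale check of the reformulation above, using shrunken, hypothetical shapes so it runs on CPU. The result drops the singleton row dimension, so it matches A @ B after squeezing dim -2. In plain PyTorch the elementwise product A * B.transpose(-1, -2) still materializes the full [..., 27, 27] intermediate (the same ~57 GiB at the question's sizes). The rewrite only pays off when a backend such as KeOps computes the multiply and the sum without storing that intermediate.

```python
import torch

torch.manual_seed(0)

# Same layout as the question, but with small hypothetical sizes:
# A: [N1, N2, 1, 1, K], B: [1, 1, M, K, J] (K == J == 5 here, mirroring 27 x 27).
A = torch.randn(2, 3, 1, 1, 5)
B = torch.randn(1, 1, 4, 5, 5)

reference = A @ B  # shape [2, 3, 4, 1, 5]

# Broadcasting A against B^T gives [2, 3, 4, 5, 5]; element [..., j, k] is
# A[..., 0, k] * B[..., k, j], so summing over k reproduces the matmul.
alternative = (A * B.transpose(-1, -2)).sum(-1)  # shape [2, 3, 4, 5]

print(reference.shape, alternative.shape)
print(torch.allclose(alternative, reference.squeeze(-2), atol=1e-5))  # True
```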
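A minimal sketch of the loop workaround follows. It assumes exactly the singleton pattern from the question (A is [N1, N2, 1, R, K] and B is [1, 1, M, K, J]), and the function name `looped_matmul` is hypothetical. Looping over the M = 160 matrices of B means each step multiplies A by a plain 2-D [K, J] matrix, so B is never copied once per batch element. Preallocating the output avoids holding both a list of pieces and their concatenation at the same time.

```python
import torch


def looped_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """A @ B for A: [N1, N2, 1, R, K] and B: [1, 1, M, K, J] (hypothetical helper).

    Gives the same result as torch.matmul(A, B) for these shapes, but avoids
    broadcasting B to [N1, N2, M, K, J].
    """
    n1, n2, a_mid, rows, k = A.shape
    b0, b1, m, k_b, j = B.shape
    assert a_mid == 1 and b0 == 1 and b1 == 1, "only the question's singleton layout is handled"
    assert k == k_b, "inner dimensions must match"

    # Preallocate the result (same dtype/device as A); ~2 GiB at the question's sizes.
    out = A.new_empty(n1, n2, m, rows, j)
    a = A[:, :, 0]  # drop the broadcast dim: [N1, N2, R, K]
    for idx in range(m):
        # B[0, 0, idx] is a single [K, J] matrix shared by every batch element,
        # so this is a batched-by-2-D matmul with no per-batch copy of B.
        out[:, :, idx] = a @ B[0, 0, idx]  # [N1, N2, R, J]
    return out


# Small usage example with hypothetical sizes, checked against torch.matmul.
A = torch.randn(3, 4, 1, 1, 5)
B = torch.randn(1, 1, 6, 5, 5)
result = looped_matmul(A, B)
print(result.shape)
print(torch.allclose(result, A @ B, atol=1e-5))
```

Looping over the 160-long dimension keeps the Python overhead to 160 iterations while each iteration still processes all 128 x 1024 batch rows at once.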