Unexpected huge memory cost of matmul()

I have some questions about the memory cost of the matmul() function. Specifically, I have a matrix A of size [4096, 4096] and a tensor v of size [192, 4096, 1]. I want to multiply A with the last two dimensions of v and get a result of size [192, 4096, 1]. So I wrote

w = torch.matmul(A, v)

But I got an out-of-memory error saying PyTorch needs to allocate 12 GB of memory while I only have 11 GB. I checked that both A and v have requires_grad=False.

My question is: is this memory cost normal, or did I misuse the matmul() function? If this is the way it should be used, is there any way to avoid the out-of-memory error?
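For reference, a minimal sketch of the setup described above (the float32 dtype and GPU placement are my assumptions):

import torch

# Small reproduction of the shapes in question (float32 on GPU assumed).
A = torch.randn(4096, 4096, device="cuda")      # ~64 MB
v = torch.randn(192, 4096, 1, device="cuda")    # ~3 MB

# matmul broadcasts A across the batch dimension of v,
# which is where the unexpected allocation comes from.
w = torch.matmul(A, v)                          # expected shape: [192, 4096, 1]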

Hi,

The input tensor, once the batch dimension is added, will be 192 x 4096 x 4096, which adds up to ~12 GB of memory (192 × 4096 × 4096 × 4 bytes for float32).
If you want to handle the batch dimension in a less memory-hungry manner, I would suggest: w = torch.bmm(A.unsqueeze(0).expand(v.size(0), -1, -1), v).
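A sketch of that suggestion, assuming the same float32 GPU tensors as in the question; the memory saving relies on bmm consuming the expanded view without copying it:

import torch

A = torch.randn(4096, 4096, device="cuda")
v = torch.randn(192, 4096, 1, device="cuda")

# expand() only creates a view (stride 0 along the batch dimension),
# so no 192 x 4096 x 4096 buffer is allocated here.
A_batched = A.unsqueeze(0).expand(v.size(0), -1, -1)

# bmm consumes the expanded view directly.
w = torch.bmm(A_batched, v)   # shape: [192, 4096, 1]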


So broadcasting will expand a tensor “physically”?

I didn’t check the code, but looks like it does.
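You can check this on a small scale: expand() by itself only creates a view, which the strides and data pointer make visible (CPU example, same shapes as in this thread):

import torch

A = torch.randn(4096, 4096)

# expand() returns a view: the batch dimension gets stride 0,
# so no new memory is allocated for the 192 "copies" of A.
A_exp = A.unsqueeze(0).expand(192, -1, -1)
print(A_exp.shape)                         # torch.Size([192, 4096, 4096])
print(A_exp.stride())                      # (0, 4096, 1)
print(A_exp.data_ptr() == A.data_ptr())    # True: same underlying storage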

Ahh I got the point. And torch.bmm works great for me. Thank you so much for your help!

This is a performance bug and should be fixed. I’ve written up a bug report here:

@liangbright, in general broadcasting does not expand a tensor “physically” (in memory), but in this case matmul unnecessarily expands the tensor by calling contiguous() on the expanded tensor.
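For illustration (using a small stand-in matrix so it actually runs), it is that contiguous() call which turns the free view into a real allocation:

import torch

A = torch.randn(8, 8)                        # small stand-in for the 4096 x 4096 matrix
A_exp = A.unsqueeze(0).expand(192, -1, -1)   # still a view, no extra memory

# contiguous() materializes all 192 copies into one dense buffer.
A_big = A_exp.contiguous()
print(A_big.numel() * A_big.element_size())  # 192 * 8 * 8 * 4 bytes actually allocated

# With the original shapes this is 192 * 4096 * 4096 * 4 bytes, roughly 12 GB,
# which matches the allocation the matmul call fails on.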
