Large memory cost of grouped matrix product for ndim > 3

I’ve been seeing some strange behaviour on the GPU with torch.matmul for arguments with more than three dimensions. The following throws an out-of-memory error for an attempted allocation of 372 GB (!).

import torch

# multi-head inputs: y has a "head" dimension of 5 that x shares
y = torch.randn(1000000, 5, 1, 100).to(torch.device('cuda'))
x = torch.randn(5, 100, 200).to(torch.device('cuda'))
# out-of-memory
z = torch.matmul(y, x)
# reducing the number of dimensions avoids the problem
z2 = torch.matmul(y.reshape((1000000, 1, 500)), x.reshape((500, 200)))

Note I am on quite a large GPU (48 GB), so z2 might still throw OOM for some people - if so, try reducing the first dimension of y.

Is this expected behaviour? I am trying to build a multi-head network, so having that extra “head” dimension of 5 is helpful.
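
For reference, here is a scaled-down, CPU-only sketch (with an arbitrary small batch size standing in for the 1000000) that shows the broadcast output shape I would expect from the matmul:

import torch

# shrunken stand-ins for the real tensors, just to inspect the broadcast shape
B, H, K, N = 8, 5, 100, 200              # batch, heads, inner dim, output dim
y_small = torch.randn(B, H, 1, K)        # stands in for the [1000000, 5, 1, 100] tensor
x_small = torch.randn(H, K, N)           # stands in for the [5, 100, 200] tensor
z_small = torch.matmul(y_small, x_small)
print(z_small.shape)                     # torch.Size([8, 5, 1, 200])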

Hi Mgreenig!

“372 GB” does seem to be unreasonably large.

But leaving that aside …

The shape of z will be [1000000, 5, 1, 200]. So you do need gigabytes
of memory just for z. Depending on what else is already on your GPU,
this could reasonably push you over the edge and cause an out-of-memory
error, even with 48 GB.

Furthermore, the shape of z2 will be [1000000, 1, 200], so one-fifth
the size of z. That reduction could conceivably be enough to avoid the
out-of-memory error.
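
For concreteness, here is the rough arithmetic for the two outputs,
assuming the default float32 dtype (4 bytes per element):

# approximate sizes of the two outputs in float32
z_bytes  = 1000000 * 5 * 1 * 200 * 4    # shape [1000000, 5, 1, 200]
z2_bytes = 1000000 * 1 * 200 * 4        # shape [1000000, 1, 200]
print(z_bytes  / 1024**3)               # ~3.73 GiB
print(z2_bytes / 1024**3)               # ~0.75 GiB, one-fifth of z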

Best.

K. Frank

The “372 GB” does indeed sound too large, so are you sure it’s not 3.72 GB, which is the expected memory requirement for the output?

# bytes in a float32 tensor of shape [1000000, 5, 1, 200], converted to GiB
(1000000 * 5 * 1 * 200 * 4) / 1024**3
# 3.725