I’ve been seeing some strange behaviour on the GPU with torch.matmul for arguments with more than three dimensions. The following throws an out-of-memory error for an attempted allocation of 372 GB (!).
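import torch
# y: [1000000, 5, 1, 100], x: [5, 100, 200]
y = torch.randn(1000000, 5, 1, 100).to(torch.device('cuda'))
x = torch.randn(5, 100, 200).to(torch.device('cuda'))
z = torch.matmul(y, x)  # this line throws the out-of-memory error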
# reducing number of dimensions solves the problem
z2 = torch.matmul(y.reshape((1000000, 1, 500)), x.reshape((500, 200)))
Note that I am on quite a large GPU (48 GB), so z2 might throw OOM for some people. If so, try reducing the first dimension of y.
Is this expected behaviour? I am trying to build a multi-head network so having that extra “head” dimension of 5 is helpful.
The shape of z will be [1000000, 5, 1, 200]. That is 10^9 float32
elements, or about 4 GB, so you do need gigabytes of memory just for z.
Depending on what else is already on your GPU, this could reasonably
push you over the edge and cause an out-of-memory error, even with 48 GB.
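To make the sizes concrete, here is a rough back-of-the-envelope sketch in plain Python (no GPU needed). It assumes float32 storage (the torch.randn default) and dense tensors. The helper names tensor_bytes and gib are made up for illustration. The third shape is an assumption on my part: it is what x would look like if it were physically expanded across y's batch dimensions [1000000, 5], which is one plausible explanation for the allocation size in the error, not something confirmed from the matmul internals.

```python
from math import prod

BYTES_PER_FLOAT32 = 4  # torch.randn produces float32 by default


def tensor_bytes(shape, itemsize=BYTES_PER_FLOAT32):
    """Bytes needed to store a dense tensor of the given shape."""
    return prod(shape) * itemsize


def gib(n_bytes):
    """Convert a byte count to GiB (2**30 bytes)."""
    return n_bytes / 2**30


z_shape = (1_000_000, 5, 1, 200)    # result of matmul(y, x)
z2_shape = (1_000_000, 1, 200)      # result of the reshaped matmul
# Assumption: x materialized once per batch entry of y.
x_expanded_shape = (1_000_000, 5, 100, 200)

for name, shape in [("z", z_shape),
                    ("z2", z2_shape),
                    ("x expanded over batch dims", x_expanded_shape)]:
    print(f"{name}: {gib(tensor_bytes(shape)):.2f} GiB")
```

Under these assumptions z itself needs roughly 4 GB, while the expanded copy of x comes to roughly 372 GiB, which lines up with the figure in the error message.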
Furthermore, the shape of z2 will be [1000000, 1, 200], so one-fifth
the size of z. That could conceivably be enough of a reduction to not cause the