I have recently been interested in bilinear maps. I tried using the nn.Bilinear module, but kept running into out-of-memory runtime errors. While trying to diagnose where these errors come from, I stumbled upon a couple of problems that I don't really know how to tackle.
First, torch.matmul() seems to run out of memory for reasons I don't quite understand:
>>> import torch
>>> i1 = torch.rand(5000, 1, 768)
>>> i2 = torch.rand(5000, 768, 1)
>>> W = torch.rand(768, 768)
>>> i1.matmul(W).shape
torch.Size([5000, 1, 768])
>>> W.matmul(i2).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: $ Torch: not enough memory: you tried to allocate 10GB. Buy new RAM! at /pytorch/aten/src/TH/THGeneral.cpp:201
But the second call to matmul (W.matmul(i2)) should produce a tensor containing just as many elements as the first, as a smaller-scale version of the same computation shows:
>>> import torch
>>> i1 = torch.rand(50, 1, 76)
>>> i2 = torch.rand(50, 76, 1)
>>> W = torch.rand(76, 76)
>>> i1.matmul(W).shape
torch.Size([50, 1, 76])
>>> W.matmul(i2).shape
torch.Size([50, 76, 1])
So why am I running out of RAM for W.matmul(i2) if I have enough memory to compute i1.matmul(W)?
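My best guess is that W (of shape (768, 768)) gets broadcast across the batch dimension of i2 and that the expanded (5000, 768, 768) intermediate is actually materialized before the batched multiply. Here is a back-of-the-envelope sketch of the sizes involved (assuming float32, i.e. 4 bytes per element; the broadcast shape is my assumption, not something I have confirmed):

batch, n = 5000, 768

# Size of the actual output of W.matmul(i2): one (768, 1) column per batch element.
output_bytes = batch * n * 1 * 4

# Size of the hypothetical broadcast intermediate: W expanded to (5000, 768, 768).
intermediate_bytes = batch * n * n * 4

print(output_bytes / 1024**3)        # ~0.014 GiB -- tiny
print(intermediate_bytes / 1024**3)  # ~11 GiB -- in the ballpark of the "10GB" in the error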
Second, and perhaps related to that, I also noticed that calls to torch.matmul() aren’t associative:
>>> (i1.matmul(W.matmul(i2)) == i1.matmul(W).matmul(i2)).all()
tensor(0, dtype=torch.uint8)
When I try to assess by how much the two call orders differ, I see that the difference is fairly small on average, so it might just be floating-point rounding:
>>> (i1.matmul(W.matmul(i2)) - i1.matmul(W).matmul(i2)).mean()
tensor(-7.3242e-06)
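If rounding is indeed the culprit, a tolerance-based comparison should succeed where exact equality fails. Here is the kind of check I have in mind, using torch.allclose (the tolerances are plausible guesses for float32, not tuned values):

import torch

torch.manual_seed(0)
i1 = torch.rand(50, 1, 76)
i2 = torch.rand(50, 76, 1)
W = torch.rand(76, 76)

a = i1.matmul(W.matmul(i2))   # right-to-left grouping
b = i1.matmul(W).matmul(i2)   # left-to-right grouping

# Exact equality is too strict for float32; compare within tolerances instead.
# If the two groupings only differ by accumulated rounding, this should print True.
print(torch.allclose(a, b, rtol=1e-4, atol=1e-6))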
- What factor other than the number of elements makes torch.matmul() run out of memory?
- How can I ensure that calling torch.matmul() multiple times behaves associatively, i.e. gives the same result regardless of how the calls are grouped?
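For completeness, the workaround I am currently considering for the memory issue relies on my broadcast guess above being right: rewrite W.matmul(i2) through the transpose identity W @ x = (x^T @ W^T)^T, so that the batched tensor ends up on the left, which is the case that already fits in memory:

import torch

i2 = torch.rand(5000, 768, 1)
W = torch.rand(768, 768)

# For each batch element x of shape (768, 1): W @ x == (x^T @ W^T)^T.
# With the batched tensor on the left, matmul can fold the batch into a
# single 2-D multiply instead of (apparently) expanding W per batch element.
out = i2.transpose(1, 2).matmul(W.t()).transpose(1, 2)
print(out.shape)  # torch.Size([5000, 768, 1])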