Hi all,
I have recently been interested in bilinear applications. I tried using the nn.Bilinear
module, but kept running into out-of-memory runtime errors. While trying to diagnose where these errors came from, I stumbled upon a couple of problems that I don't really know how to tackle.
Firstly, torch.matmul() seems to run out of memory for reasons I don't quite understand:
>>> import torch
>>> i1 = torch.rand(5000, 1, 768)
>>> i2 = torch.rand(5000, 768, 1)
>>> W = torch.rand(768,768)
>>> i1.matmul(W).shape
torch.Size([5000, 1, 768])
>>> W.matmul(i2).shape
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: $ Torch: not enough memory: you tried to allocate 10GB. Buy new RAM! at /pytorch/aten/src/TH/THGeneral.cpp:201
Yet the second call (W.matmul(i2)) should produce a tensor containing exactly as many elements as the first. Indeed, with smaller tensors of the same layout, both calls go through:
>>> import torch
>>> i1 = torch.rand(50, 1, 76)
>>> i2 = torch.rand(50, 76, 1)
>>> W = torch.rand(76,76)
>>> i1.matmul(W).shape
torch.Size([50, 1, 76])
>>> W.matmul(i2).shape
torch.Size([50, 76, 1])
So why am I running out of RAM for W.matmul(i2) if I have enough memory to compute i1.matmul(W)?
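For what it's worth, a back-of-the-envelope calculation (my own guess, not something I've verified in the source) suggests that broadcasting W against the batch dimension of i2 would materialize a (5000, 768, 768) float32 intermediate, which lines up with the 10GB mentioned in the error:

```python
# Assumption (unverified): W (768, 768) gets broadcast/expanded against
# the batch dimension of i2 (5000, 768, 1) before the multiply,
# materializing a (5000, 768, 768) float32 tensor.
batch, n = 5000, 768
bytes_needed = batch * n * n * 4  # 4 bytes per float32 element
print(bytes_needed / 1024**3)     # ~11 GiB, close to the 10GB in the error
```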
Second, and perhaps related to that, I also noticed that torch.matmul() isn't associative:
>>> (i1.matmul(W.matmul(i2)) == i1.matmul(W).matmul(i2)).all()
tensor(0, dtype=torch.uint8)
When assessing how much the two call orders differ, I see that the difference is fairly small on average, so it might just be due to floating-point rounding:
>>> (i1.matmul(W.matmul(i2)) - i1.matmul(W).matmul(i2)).mean()
tensor(-7.3242e-06)
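For completeness, replacing the exact comparison with a tolerance-aware one (torch.allclose with its default tolerances; the manual seed is my own addition for reproducibility) does treat the two groupings as equal for me on the smaller tensors:

```python
import torch

torch.manual_seed(0)  # seed added so the numbers are reproducible
i1 = torch.rand(50, 1, 76)
i2 = torch.rand(50, 76, 1)
W = torch.rand(76, 76)

a = i1.matmul(W.matmul(i2))  # right-to-left grouping
b = i1.matmul(W).matmul(i2)  # left-to-right grouping

print((a == b).all())        # exact equality can fail on float32
print(torch.allclose(a, b))  # tolerance-aware comparison passes
```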
TL;DR:
- What factor, other than the number of elements, makes torch.matmul() run out of memory?
- How can I ensure that chaining multiple torch.matmul() calls gives the same result regardless of grouping?
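For context, what I'm ultimately trying to compute is one bilinear score per batch element. Here is a sketch of that target (the einsum one-liner is a hypothetical workaround I haven't benchmarked; the names u and v and the seed are my own additions):

```python
import torch

torch.manual_seed(0)
u = torch.rand(50, 76)   # left vectors (i1 without the singleton dim)
v = torch.rand(50, 76)   # right vectors (i2 without the singleton dim)
W = torch.rand(76, 76)

# One bilinear score u_b^T W v_b per batch element, contracted in one
# call (hypothetical workaround, not benchmarked against nn.Bilinear):
scores = torch.einsum('bi,ij,bj->b', u, W, v)

# The batched-matmul formulation from above, for comparison:
ref = u.unsqueeze(1).matmul(W).matmul(v.unsqueeze(2)).reshape(-1)

print(torch.allclose(scores, ref))
```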