torch.matmul(): issues with memory and associativity

Hi all,

I have recently been interested in bilinear applications. I tried using the nn.Bilinear module, but kept running into out-of-memory runtime errors. While trying to diagnose where these errors came from, I stumbled upon a couple of problems which I don’t really know how to tackle.

Firstly, torch.matmul() seems to run out of memory for reasons I don’t quite understand:

>>> import torch
>>> i1 = torch.rand(5000, 1, 768)
>>> i2 = torch.rand(5000, 768, 1)
>>> W = torch.rand(768,768)
>>> i1.matmul(W).shape
torch.Size([5000, 1, 768])
>>> W.matmul(i2).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: $ Torch: not enough memory: you tried to allocate 10GB. Buy new RAM! at /pytorch/aten/src/TH/THGeneral.cpp:201

However, the second call to matmul (W.matmul(i2)) should produce a tensor containing as many elements as the first, as the same code with smaller shapes shows:

>>> import torch
>>> i1 = torch.rand(50, 1, 76)
>>> i2 = torch.rand(50, 76, 1)
>>> W = torch.rand(76,76)
>>> i1.matmul(W).shape
torch.Size([50, 1, 76])
>>> W.matmul(i2).shape
torch.Size([50, 76, 1])

So why am I running out of RAM for W.matmul(i2) if I have enough to compute i1.matmul(W)?

Second, and perhaps related to that, I also noticed that calls to torch.matmul() aren’t associative:

>>> (i1.matmul(W.matmul(i2)) == i1.matmul(W).matmul(i2)).all()
tensor(0, dtype=torch.uint8)

When I tried to assess how much the two call orders differ, I found that the difference is fairly small on average, and might therefore be due to rounding:

>>> (i1.matmul(W.matmul(i2)) - i1.matmul(W).matmul(i2)).mean()


  • What factor other than number of elements makes torch.matmul() run out of memory?
  • How can I ensure that the result of calling torch.matmul() multiple times is associative?

matmul is implemented inefficiently in PyTorch when it behaves like a batched matrix multiply and requires broadcasting. Please file a bug report.

The first case (i1.matmul(W)) is saved by a special case in PyTorch that treats a (B x M x K) @ (K x N) product as a single matrix multiply (B*M x K) @ (K x N), so nothing has to be broadcast.
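To make that special case concrete, here is a minimal sketch of the same folding done by hand. It is not PyTorch's internal code, only an illustration that the two computations agree; the small shapes and the names x, folded and direct are made up for the example.

```python
import torch

# Small illustrative shapes (the question uses B=5000, M=1, K=N=768).
B, M, K, N = 50, 1, 76, 76
x = torch.rand(B, M, K)
W = torch.rand(K, N)

# Fold the batch dimension into the rows, do one 2-D matrix multiply,
# then unfold the result back to (B, M, N).
folded = x.reshape(B * M, K).mm(W).reshape(B, M, N)

# The batched call that matmul handles without expanding W.
direct = x.matmul(W)

print(folded.shape)
print(torch.allclose(folded, direct, rtol=1e-4, atol=1e-5))
```

Only one copy of W is ever needed this way, which is why the memory use stays proportional to the size of the inputs and the output.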

The second case (W.matmul(i2)) expands the 768x768 tensor into a 5000x768x768 tensor, which is huge: about 2.9 billion float32 elements, or roughly 11 GiB.
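Until that is fixed, one possible workaround is to rewrite the product so that the batched tensor sits on the left, where the special case above applies. The sketch below relies on the identity W x = (x^T W^T)^T and assumes the special case behaves as described in your PyTorch version; the small shapes and the names reference, rewritten and expanded_bytes are only for illustration.

```python
import torch

# Size of the expanded tensor in the question: 5000 copies of a
# 768x768 float32 matrix, at 4 bytes per element.
expanded_bytes = 5000 * 768 * 768 * 4
print(expanded_bytes / 2**30)  # size in GiB

# Small stand-ins for the shapes in the question.
B, K = 50, 76
W = torch.rand(K, K)
i2 = torch.rand(B, K, 1)

# Broadcasting path from the question; cheap at this size only.
reference = W.matmul(i2)                                        # (B, K, 1)

# Rewritten path: (B, 1, K) @ (K, K) -> (B, 1, K), then transpose back.
rewritten = i2.transpose(1, 2).matmul(W.t()).transpose(1, 2)    # (B, K, 1)

print(rewritten.shape)
print(torch.allclose(reference, rewritten, rtol=1e-4, atol=1e-5))
```

The two results are compared with a tolerance rather than with ==, because the operations run in a different order and can round differently.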

matmul() isn’t associative because floating point arithmetic isn’t associative. This is generally the case and not specific to PyTorch.
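The effect can be reproduced without PyTorch at all. The sketch below uses plain Python floats (double precision) to show that regrouping an addition changes the last bits of the result, and that comparing with a tolerance is the usual remedy; the variable names are only for illustration.

```python
import math

a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c    # rounds after a + b, then again after adding c
right = a + (b + c)   # rounds after b + c, then again after adding a

print(left == right)                            # False
print(abs(left - right))                        # tiny, on the order of 1e-16
print(math.isclose(left, right, rel_tol=1e-9))  # True
```

For tensors, torch.allclose plays the same role as math.isclose here. Exact associativity cannot be enforced; the practical choices are to compare with a tolerance or to work in float64 to make the discrepancy smaller.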