 # Gated matmul implementation

I want to implement a gated matrix multiplication: when the gate's value is 0, the matrix multiplication for that entry is skipped. So far I have tried to implement it in Python, but it throws a CUDA out-of-memory error once the dimensions get larger:

```python
import torch

x = torch.rand([70, 20, 1024])
g = torch.rand([70, 20, 96, 1, 1])
w = torch.rand([96, 128, 128])
g = g * w  # broadcasts to [70, 20, 96, 128, 128] -> CUDA out of memory
g = g.view(70, 20, 1536, 1024)
res = (g @ x.unsqueeze(-1)).squeeze(-1)
print(res.shape)
```
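To see why the broadcast runs out of memory, it helps to count elements: `g * w` materializes an intermediate of shape `[70, 20, 96, 128, 128]`. A quick back-of-the-envelope check (a sketch, not part of the original post):

```python
# Element count of the broadcast result g * w, shape [70, 20, 96, 128, 128]
elements = 70 * 20 * 96 * 128 * 128
bytes_needed = elements * 4  # float32 is 4 bytes per element
print(elements)                          # 2202009600
print(f"{bytes_needed / 2**30:.1f} GiB")  # 8.2 GiB for this one intermediate
```

That single intermediate is already larger than the memory of many GPUs, which explains the OOM.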

So I tried to implement my own version, because the problem occurs due to the broadcasting in `g * w`. I tried to find the `matmul` implementation in C++. I found the C++ implementation here, but I don't know how to access the `at::mm_out` function. (The Python equivalent of `mm_out` is `torch.mm(..., out=...)`.)
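For reference, the `out=` variant writes into a preallocated tensor instead of allocating a new result; a minimal illustration of the equivalence:

```python
import torch

a = torch.rand(4, 5)
b = torch.rand(5, 3)

# Preallocate the result buffer and let mm write into it;
# this is what at::mm_out does on the C++ side.
out = torch.empty(4, 3)
torch.mm(a, b, out=out)

# Same values as the allocating call
assert torch.equal(out, torch.mm(a, b))
```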

So, if I understand you correctly: you have

```python
x = x.view(70, 20, 1, 1, 1024)
g = g.view(70, 20, 96, 1, 1)
w = w.view(1, 1, 96, 16, 1024)
```

and want to conceptually compute

```python
(x * g * w).sum(-1)
```
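At a small scale, the memory-heavy version from the question and this conceptual formula can be checked against each other (a sketch with shrunken, made-up dimensions so everything fits in memory):

```python
import torch

B, T, G = 2, 3, 4          # shrunken analogs of 70, 20, 96
x = torch.rand(B, T, 8)    # analog of [70, 20, 1024]
g = torch.rand(B, T, G, 1, 1)
w = torch.rand(G, 4, 4)    # 4 * 4 == 2 * 8, analog of 128 * 128 == 16 * 1024

# Memory-heavy version from the question: broadcast, flatten, matmul
gw = (g * w).view(B, T, G * 2, 8)
res1 = (gw @ x.unsqueeze(-1)).squeeze(-1)       # [B, T, G * 2]

# Conceptual formula with the reshaped views
xv = x.view(B, T, 1, 1, 8)
wv = w.view(1, 1, G, 2, 8)
res2 = (xv * g * wv).sum(-1).view(B, T, G * 2)  # same result

print(torch.allclose(res1, res2))  # True
```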

In this case, you might take a look at how `F.bilinear` is implemented (in Linear.cpp), using a for loop over the `96` entries. This avoids instantiating a too-large tensor.
One easy way to do something like this is to do

```python
res = torch.empty(70, 20, 96, 16)
for i in range(96):
    # one slice at a time: the intermediate is only [70, 20, 16, 1024]
    res[:, :, i] = (x[:, :, 0] * g[:, :, i] * w[:, :, i]).sum(-1)
```

Quite likely the backward isn't 100% efficient, and for a small additional speedup you might use `matmul` instead of summing, but this should already give you reasonable performance.
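A matmul-based variant could look like the following sketch. It replaces the elementwise multiply-and-sum with `torch.einsum` (which dispatches to batched matmuls), under the assumption that the per-entry result computed by the loop above is what you want:

```python
import torch

x = torch.rand(70, 20, 1024)
g = torch.rand(70, 20, 96, 1, 1)
w = torch.rand(96, 128, 128)

# res[b, t, i, r] = g[b, t, i] * sum_d w[i, r, d] * x[b, t, d]
wv = w.view(96, 16, 1024)
res = torch.einsum('btd,ird->btir', x, wv) * g.view(70, 20, 96, 1)
print(res.shape)  # torch.Size([70, 20, 96, 16])
```

Because the gate is applied after the contraction, entries where `g` is 0 simply zero out the corresponding slice of the result.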

Best regards

Thomas


@tom Thanks for your reply, and thanks also to the PyTorch team for such an amazing community.

The `w` shape must be `[96, 128, 128]` when it is multiplied by `g`, so I can't reshape it to `[96, 16, 1024]`. You gave me a good pointer, though.

But what is your expectation there? From the shapes you describe, it should not make a difference to the multiplication: a `view` only changes how the same data is indexed. (You can, of course, undo the view if you need the original shape afterwards.)
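To illustrate (a small sketch, not from the original thread): `view` does not copy or reorder the underlying storage, so reshaping `w` for the computation and viewing it back afterwards is lossless:

```python
import torch

w = torch.rand(96, 128, 128)

# view reinterprets the same storage: 128 * 128 == 16 * 1024
wv = w.view(96, 16, 1024)
assert wv.data_ptr() == w.data_ptr()   # no copy was made

# viewing back restores the original shape exactly
assert torch.equal(wv.view(96, 128, 128), w)
```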

Best regards

Thomas