Gated matmul implementation

I want to implement a gated matrix multiplication, in which the matrix multiplication is skipped whenever the gate’s value is 0. So far I have tried to implement it in Python, but it throws a CUDA out-of-memory error when the tensors have more than 2 dimensions:

import torch

x = torch.rand([70, 20, 1024])
g = torch.rand([70, 20, 96, 1, 1])
w = torch.rand([96, 128, 128])
g = g * w # cuda out of memory
g = g.view(70, 20, 1536, 1024)
res = (g@x.unsqueeze(-1)).squeeze(-1)
print(res.shape)

So I am trying to implement my own version, because the problem comes from the broadcasting in g * w. I looked for the matmul implementation in C++ and found it here, but I don’t know how to access the at::mm_out function.
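
For reference, a rough size estimate of the broadcasted product g * w, using the shapes from the snippet above and assuming float32 (the default dtype of torch.rand):

```python
import math

# Shape of g * w after broadcasting [70, 20, 96, 1, 1] with [96, 128, 128].
broadcast_shape = (70, 20, 96, 128, 128)
bytes_per_element = 4  # assumption: float32

n_elements = math.prod(broadcast_shape)
size_gib = n_elements * bytes_per_element / 2**30
print(f"{n_elements:,} elements, about {size_gib:.1f} GiB")
```

The single intermediate tensor alone needs several GiB, before counting the copy made by the following matmul.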


The equivalent of mm_out is torch.mm(..., out=...).
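
As a minimal sketch of what that looks like from Python (the tensor sizes here are arbitrary placeholders, not taken from the question), the result buffer is allocated once and torch.mm writes into it:

```python
import torch

a = torch.rand(4, 3)
b = torch.rand(3, 5)

# Preallocate the result; torch.mm writes into it instead of returning a new tensor.
out = torch.empty(4, 5)
torch.mm(a, b, out=out)

# Sanity check against the usual operator form.
print(torch.allclose(out, a @ b))
```

Note that torch.mm only handles 2-D inputs, so on its own it does not remove the broadcasting problem.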

So is the following correct? You have

x = x.view(70, 20, 1, 1, 1024)
g = g.view(70, 20, 96, 1, 1)
w = w.view(1, 1, 96, 16, 1024)

and want to conceptually compute

(x*g*w).sum(-1)

In this case, you might take a look at how F.bilinear is implemented (in Linear.cpp): it uses a for loop over the 96 entries. This avoids instantiating a too-large tensor.
One easy way to do something like this is to do

res = torch.empty(70, 20, 96, 16)
for i in range(96):
    res[:,:, i] = (x[:,:,0]*g[:,:,i]*w[:,:,i]).sum(-1)

Quite likely, the backward isn’t 100% efficient, but it should give you reasonable performance.

For a small additional speedup, you can use matmul instead of multiplying and summing.
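
A minimal sketch of the matmul variant, checked against the multiply-and-sum loop. The sizes are small stand-ins chosen so the check runs quickly (an assumption; the thread uses 70, 20, 96, 16 and 1024), and the einsum line is just one possible loop-free alternative, not something suggested in the thread:

```python
import torch

# Stand-in sizes (assumption): batch, sequence, gates, rows per gate, features.
B, S, G, O, D = 3, 2, 4, 5, 8

x = torch.rand(B, S, D)
g = torch.rand(B, S, G)
w = torch.rand(G, O, D)

# Reference: per-gate multiply and sum, equivalent to the loop above.
ref = torch.empty(B, S, G, O)
for i in range(G):
    ref[:, :, i] = (x[:, :, None, :] * g[:, :, i, None, None] * w[i]).sum(-1)

# matmul variant: x @ w[i].t() is [B, S, O], so the gate scales a much
# smaller tensor and the [B, S, O, D] intermediate is never created.
res = torch.empty(B, S, G, O)
for i in range(G):
    res[:, :, i] = (x @ w[i].t()) * g[:, :, i, None]

# Loop-free alternative using einsum.
res_einsum = torch.einsum('bsd,god,bsg->bsgo', x, w, g)

print(torch.allclose(ref, res, atol=1e-5), torch.allclose(ref, res_einsum, atol=1e-5))
```

Because the gate is a scalar per (batch, position, gate) entry, applying it after the matmul gives the same result as applying it to the weights first.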

Best regards

Thomas


@tom Thanks for your reply, and thanks also to the PyTorch team for such an amazing community.

The shape of w must be [96, 128, 128] when it is multiplied by g, so I can’t reshape it to [96, 16, 1024]. You gave me a good pointer, though.

But what do you expect to gain from keeping that shape? From the shapes you describe, it should not make a difference to the multiplication. (You can, of course, view w back to the original shape if you need it afterwards.)
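
To illustrate, here is a small check that the two formulations agree. It is a sketch with scaled-down stand-in sizes (an assumption; in the thread they are 70, 20, 96, 128 and 1024), chosen so that each gate’s H x H block still splits evenly into rows of length D:

```python
import torch

# Stand-in sizes (assumption); thread values: 70, 20, 96, 128, 1024.
B, S, G, H, D = 2, 3, 3, 4, 8
O = H * H // D  # rows per gate after the view: 2 here, 16 in the thread

x = torch.rand(B, S, D)
g = torch.rand(B, S, G, 1, 1)
w = torch.rand(G, H, H)

# Original formulation: broadcast g * w, view to [B, S, G*O, D], then matmul.
full = (g * w).view(B, S, G * O, D)
orig = (full @ x.unsqueeze(-1)).squeeze(-1)  # [B, S, G*O]

# Reshaped formulation: view w as [G, O, D]; the broadcast product is never built.
w2 = w.view(G, O, D)
alt = torch.einsum('bsd,god->bsgo', x, w2) * g.view(B, S, G, 1)  # [B, S, G, O]

print(torch.allclose(orig, alt.reshape(B, S, G * O), atol=1e-5))
```

The view only regroups the same contiguous weights, so the [G, H, H] layout and the [G, O, D] layout produce identical outputs.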

Best regards

Thomas