I want to implement a gated matrix multiplication. In this version of the matrix multiplication, when the gate's value is 0 the corresponding matrix multiplication is skipped. So far I have tried to implement it in Python, but it throws a CUDA out-of-memory error once the tensors have more than two dimensions:
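```python
import torch

device = "cuda"  # the OOM only shows up when the tensors live on the GPU
x = torch.rand([70, 20, 1024], device=device)
g = torch.rand([70, 20, 96, 1, 1], device=device)
w = torch.rand([96, 128, 128], device=device)
# Broadcasting materializes a [70, 20, 96, 128, 128] tensor:
# 70 * 20 * 96 * 128 * 128 = 2,202,009,600 float32 values, about 8.8 GB
g = g * w  # CUDA out of memory
g = g.view(70, 20, 1536, 1024)
res = (g @ x.unsqueeze(-1)).squeeze(-1)
print(res.shape)
```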
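Since each gate is a single scalar for one 128×128 block of `w`, the broadcast can probably be avoided altogether. With the `.view(70, 20, 1536, 1024)` used above, block `k` occupies 16 consecutive rows of the 1536×1024 matrix (128 * 128 / 1024 = 16), so output row `r` is scaled by the gate of block `r // 16`. That means the gate can be applied to the output after an ungated matmul. The sketch below assumes exactly that layout; `gated_matmul` is a hypothetical helper name, and it does not yet skip any work for zero gates, it only removes the huge intermediate tensor (for the sizes in the question the product is a [70, 20, 1536] tensor, about 8.6 MB).

```python
import torch


def gated_matmul(x, g, w):
    # Hypothetical helper: same result as
    #   ((g * w).view(*x.shape[:-1], -1, x.shape[-1]) @ x.unsqueeze(-1)).squeeze(-1)
    # but without materializing the broadcast g * w.
    # x: [..., in_features]; g: [..., num_blocks, 1, 1]; w: [num_blocks, h, d]
    in_features = x.shape[-1]
    num_blocks = w.shape[0]
    # Each block must map onto whole rows of the 2-D view, as in the original .view()
    assert w[0].numel() % in_features == 0
    w2d = w.reshape(-1, in_features)             # [out_features, in_features]
    rows_per_block = w2d.shape[0] // num_blocks  # 16 for the sizes in the question
    y = x @ w2d.T                                # ungated product: [..., out_features]
    # One scalar gate per block, repeated over the rows that block produced
    gate = g.reshape(*g.shape[:-2]).repeat_interleave(rows_per_block, dim=-1)
    return gate * y


# Small-size check against the original broadcasting formulation
torch.manual_seed(0)
x = torch.rand(3, 4, 64)
g = torch.rand(3, 4, 6, 1, 1)
w = torch.rand(6, 32, 32)  # 6 * 32 * 32 values = 96 rows of 64
ref = ((g * w).view(3, 4, 96, 64) @ x.unsqueeze(-1)).squeeze(-1)
out = gated_matmul(x, g, w)
print(out.shape, torch.allclose(out, ref, atol=1e-4))  # torch.Size([3, 4, 96]) True
```

For the real sizes, `gated_matmul(x, g, w)` with the CUDA tensors from the question should return the same `[70, 20, 1536]` result as `res`.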
The problem comes from the broadcasting in `g * w`, so I am trying to implement my own version. I looked for the matmul implementation in C++ and found it here, but I don't know how to access the `at::mm_out` function.
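Calling `at::mm_out` directly may not be necessary for the skipping behaviour: from Python, the same out-variant overload is reachable as `torch.mm(a, b, out=buf)`, and skipping can be expressed at the block level. The sketch below (hypothetical helper `gated_matmul_skip`, assuming the same `[num_blocks, h, d]` weight layout and `.view` ordering as the code above) computes each block's product only for the samples whose gate for that block is non-zero, and skips the block completely when all of its gates are 0. On a GPU, `torch.nonzero` forces a host/device synchronization and the Python loop launches one small matmul per block, so whether this beats the dense version depends on how many gates are actually zero.

```python
import torch


def gated_matmul_skip(x, g, w):
    # Hypothetical helper: same output as the original code, but a block's product
    # is only computed for the samples whose gate for that block is non-zero.
    in_features = x.shape[-1]
    num_blocks = w.shape[0]
    w_blocks = w.reshape(num_blocks, -1, in_features)  # [num_blocks, rows_per_block, in_features]
    rows_per_block = w_blocks.shape[1]
    x_flat = x.reshape(-1, in_features)                # [N, in_features]
    gates = g.reshape(-1, num_blocks)                  # [N, num_blocks]
    out = x.new_zeros(x_flat.shape[0], num_blocks, rows_per_block)
    for k in range(num_blocks):
        # Indices of the samples whose gate for block k is non-zero
        idx = torch.nonzero(gates[:, k], as_tuple=True)[0]
        if idx.numel() == 0:
            continue  # every gate for this block is 0: skip its matmul entirely
        y = x_flat[idx] @ w_blocks[k].T                # [n_k, rows_per_block]
        out[idx, k] = gates[idx, k].unsqueeze(-1) * y
    # Row r of the result comes from block r // rows_per_block, as in the .view()
    return out.reshape(*x.shape[:-1], num_blocks * rows_per_block)


torch.manual_seed(0)
x = torch.rand(3, 4, 64)
g = torch.rand(3, 4, 6, 1, 1)
g[g < 0.5] = 0.0     # zero out roughly half of the gates
g[:, :, 2] = 0.0     # block 2 is gated off for every sample
w = torch.rand(6, 32, 32)
ref = ((g * w).view(3, 4, 96, 64) @ x.unsqueeze(-1)).squeeze(-1)
out = gated_matmul_skip(x, g, w)
print(torch.allclose(out, ref, atol=1e-4))  # True
```

The zero-initialized output buffer is what makes skipping safe: rows belonging to a gated-off block simply stay 0, which is exactly what multiplying by a zero gate would have produced.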