Memory-efficient way to implement masked matrix multiplication

I want to multiply two dense matrices A(N, d) and B(d, N). The resulting matrix is of size N x N, where N is very large, so I can’t use the normal matrix multiplication function to do this.
However, I only need the values at a few positions, which are specified by a sparse matrix C(N, N) with E non-zero entries.
My problem is the same as the one discussed in Matrix multiplication with values only at masked entries.
However, I can’t find a memory-efficient way to solve this problem.
The solution provided by tom has a memory complexity of O(Ed).
Splitting it into chunks does not help, as the intermediate results still occupy memory for gradient backpropagation.

matmul doesn’t store the output AB, but your subsequent indexing does. So you can try writing a custom autograd.Function that takes (A, B) and returns (A@B)[indexes], and that uses a sparse gradient of AB in backward (then does sparse-dense matmuls to compute the gradients of A and B).

PS: I mean using chunked matmuls in forward; I can’t think of a better way atm.
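For readers who want to try the custom autograd.Function idea, here is a minimal sketch, not a tested production implementation. It swaps the chunked matmul in forward for a chunked gather-and-dot over the requested entries, which avoids the O(N^2 d) cost mentioned in this thread. The names MaskedMatmul, masked_matmul, rows, cols and chunk are hypothetical; rows and cols are assumed to be int64 index tensors of length E, and B is assumed to have shape (d, N) as in the question.

```python
import torch


class MaskedMatmul(torch.autograd.Function):
    """Hypothetical helper: returns (A @ B)[rows, cols] without forming A @ B."""

    @staticmethod
    def forward(ctx, A, B, rows, cols, chunk):
        # Only the inputs are saved; no (E, d) intermediate is kept for backward.
        ctx.save_for_backward(A, B, rows, cols)
        Bt = B.t()  # (N, d) view of B, no copy
        out = A.new_empty(rows.numel())
        for s in range(0, rows.numel(), chunk):
            r, c = rows[s:s + chunk], cols[s:s + chunk]
            # Temporary tensors here are at most (chunk, d) and are freed each step.
            out[s:s + chunk] = (A[r] * Bt[c]).sum(1)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        A, B, rows, cols = ctx.saved_tensors
        N = A.shape[0]
        # Sparse gradient of AB: grad_out placed at the requested (row, col) positions.
        G = torch.sparse_coo_tensor(
            torch.stack([rows, cols]), grad_out, (N, N)
        ).coalesce()
        # Sparse-dense matmuls: dL/dA = G @ B^T, dL/dB = A^T @ G = (G^T @ A)^T.
        grad_A = torch.sparse.mm(G, B.t().contiguous())
        grad_B = torch.sparse.mm(G.t(), A).t()
        # One gradient per forward input; index tensors and chunk get None.
        return grad_A, grad_B, None, None, None


def masked_matmul(A, B, rows, cols, chunk=1_000_000):
    # Thin wrapper so that chunk can have a default value.
    return MaskedMatmul.apply(A, B, rows, cols, chunk)
```

Under these assumptions, the extra memory beyond the inputs and the output should be roughly O(chunk * d) in forward and O(E + N d) in backward, though actual usage depends on the PyTorch version and device.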

I can’t compute A@B there, as it has a time complexity of O(N^2 d).
Specifically, I have a graph with N nodes and E edges.
N is several million and E is tens of millions.

If you have a memory problem with this:

vals = (A[rows] * B.t()[cols]).sum(1)  # B is (d, N), so index the rows of B.t()

maybe make it a function and wrap it with torch.utils.checkpoint.
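A minimal sketch of that checkpoint suggestion, combined with chunking over the E entries, might look like the following. The function names edge_dots and masked_matmul_ckpt and the chunk parameter are hypothetical, B is assumed to have shape (d, N) as in the question, and use_reentrant=False assumes a reasonably recent PyTorch release.

```python
import torch
from torch.utils.checkpoint import checkpoint


def edge_dots(A, B, rows, cols):
    # One dot product per requested entry: out[e] = A[rows[e], :] . B[:, cols[e]]
    return (A[rows] * B.t()[cols]).sum(1)


def masked_matmul_ckpt(A, B, rows, cols, chunk=1_000_000):
    parts = []
    for s in range(0, rows.numel(), chunk):
        # checkpoint keeps only the inputs of edge_dots and recomputes the
        # (chunk, d) intermediates during backward instead of storing them.
        parts.append(
            checkpoint(
                edge_dots, A, B, rows[s:s + chunk], cols[s:s + chunk],
                use_reentrant=False,
            )
        )
    return torch.cat(parts)
```

The trade-off is extra compute: each chunk's forward pass runs a second time during backward, in exchange for not holding all O(Ed) intermediates at once.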

It works well with my code, thanks a lot.