Matrix multiplication with values only at masked entries

I want to multiply two dense matrices A(1000000,300) and B(1000000,300), i.e. compute A B^T. The resulting matrix would be of size million x million, which of course won't fit into memory and would take forever to compute. However, I only need values at a few elements, which are specified by a sparse matrix C(1000000,1000000).
Is there an efficient way (ideally one that uses GPUs) to do this in PyTorch?
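For scale, assuming float32 entries (the question doesn't state the dtype), here is what the full product would cost:

```latex
\underbrace{10^6 \cdot 10^6}_{\text{entries of } A B^T} \cdot 4\,\text{bytes} = 4 \times 10^{12}\,\text{bytes} \approx 4\,\text{TB},
\qquad
\underbrace{10^6 \cdot 10^6 \cdot 300}_{\text{multiply-adds}} = 3 \times 10^{14}.
```

Computing only the k entries marked in C costs 300k multiply-adds instead.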

Did you try just iterating over the required elements of C and taking the dot product of the corresponding row of A and row of B (i.e. column of B^T)? Of course, I believe there can be more efficient solutions, but I think this is the most basic one.
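A minimal sketch of that loop, assuming C is a torch sparse COO tensor and the product wanted is A @ B.t() (both inputs are 1000000 x 300). The function name `masked_matmul_loop` is hypothetical. It is slow, since it does one Python-level dot product per nonzero, but it makes a handy correctness reference for faster versions.

```python
import torch

def masked_matmul_loop(A, B, C):
    """Hypothetical helper: entries of A @ B.t() only where sparse C is nonzero.

    For each nonzero (i, j) of C, computes dot(A[i], B[j]) one at a time.
    """
    C = C.coalesce()
    idx = C.indices()  # shape (2, nnz): row indices, then column indices
    vals = torch.empty(idx.shape[1], dtype=A.dtype, device=A.device)
    for k in range(idx.shape[1]):
        i, j = idx[0, k].item(), idx[1, k].item()
        vals[k] = torch.dot(A[i], B[j])  # row i of A times column j of B^T
    return torch.sparse_coo_tensor(idx, vals, (A.shape[0], B.shape[0]))
```

For example, with `C = torch.sparse_coo_tensor(torch.tensor([[0, 2, 5], [1, 4, 0]]), torch.ones(3), (6, 5))`, the call `masked_matmul_loop(A, B, C)` returns a sparse (6, 5) tensor holding three dot products.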

While iterating, you can also use multiprocessing to compute different required elements of C on different GPUs.

Yes, but that takes a lot of time. I wanted a vectorized solution that uses built-in functions.

There’s a way to do this if we restrict your mask M:
Let's say you have a mask M such that no two 1s in the mask are in the same row or column.

Then the following should be equivalent to (z @ y) * M, where z and y are the two dense factors (A and B^T in your notation) and the @ sign is matrix multiplication:
(z.t() * (y @ M.t())).to_dense().sum(dim=0).diag() @ M

I think PyTorch supports sparse x dense products via torch.mm, although the result is dense rather than sparse. If you need dense x sparse (because M will probably be sparse), you can use the identity AB = ((AB)^T)^T = (B^T A^T)^T.
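A sketch of that transpose identity, assuming S is a 2-D sparse COO tensor. The helper name `dense_times_sparse` is hypothetical. `torch.sparse.mm(sparse, dense)` returns a dense tensor, so the result here is dense too.

```python
import torch

def dense_times_sparse(D, S):
    """Hypothetical helper: D @ S for dense D and sparse COO S.

    Uses (D S)^T = S^T D^T, so the sparse operand comes first in torch.sparse.mm.
    """
    return torch.sparse.mm(S.t(), D.t()).t()
```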

I haven’t actually tried out the expression above for your use case, though.
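A small dense check of the expression, assuming z is (n, k), y is (k, m) and M is a dense 0/1 (n, m) mask with at most one 1 per row and per column. The function name is hypothetical, and `.to_dense()` is dropped because everything here is already dense.

```python
import torch

def masked_product_one_per_row_col(z, y, M):
    """Hypothetical helper: (z @ y) * M for a mask with at most one 1 per row/column.

    Column i of y @ M.t() is y[:, j] where M[i, j] == 1 (zeros if row i is empty).
    Summing z.t() * (y @ M.t()) over dim 0 gives v[i] = z[i] . y[:, j], and
    diag(v) @ M places v[i] back at position (i, j).
    """
    v = (z.t() * (y @ M.t())).sum(dim=0)
    return torch.diag(v) @ M
```

Note that `torch.diag(v)` is n x n when dense, so at n = 1000000 you would want `v.unsqueeze(1) * M` (same result) or sparse tensors instead.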

Thanks a lot, Richard! Unfortunately, my mask has multiple ones in each row. For instance, for an M of size million x million, there are about 300 ones in each row.
Is there any solution for this?

Any more suggestions for solving this?

If you have the locations of the 1s in your mask as rows and cols LongTensors, for example (your matrices are larger, but for demonstration purposes):

import torch

A = torch.randn(5, 10)
B = torch.randn(8, 10)
# Positions of the 1s in the mask, as int64 row/column index tensors
rows = torch.arange(16) % 5
cols = (torch.arange(16) // 2) % 8   # integer division keeps the indices integral
idxes = torch.stack([rows, cols], dim=0)
mask = torch.sparse_coo_tensor(idxes, torch.ones(16), (5, 8)).to_dense()

how about doing:

vals = (A[rows] * B[cols]).sum(1)   # row-wise dot products A[rows[k]] . B[cols[k]]
res = torch.sparse_coo_tensor(idxes, vals, (5, 8))

Then you see that the difference from the masked product is 0 up to machine precision:

print((res.to_dense()-(A@B.t()) * mask).abs().max())

Now, this will have quite a bit of memory overhead, because A[rows] and B[cols] each materialize an (nnz x 300) tensor, but you could split the nonzero entries into chunks and torch.cat the results together afterwards if needed.
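A sketch of that chunking, assuming rows and cols are int64 index tensors on the same device as A and B (e.g. all moved to the GPU with `.cuda()`). For scale: with about 300 ones per row and 10^6 rows, nnz is about 3 x 10^8, so in float32 A[rows] alone would be roughly 3 x 10^8 x 300 x 4 bytes, about 360 GB. The function name `masked_matmul_chunked` and the default `chunk_size` are hypothetical; tune the chunk size to your memory.

```python
import torch

def masked_matmul_chunked(A, B, rows, cols, chunk_size=100_000):
    """Hypothetical helper: (A @ B.t())[rows[k], cols[k]] for every k, in chunks.

    Each chunk gathers at most chunk_size rows of A and of B, instead of
    materializing two full (nnz, d) tensors. Returns a sparse COO tensor of
    shape (A.shape[0], B.shape[0]).
    """
    vals = []
    for start in range(0, rows.numel(), chunk_size):
        r = rows[start:start + chunk_size]
        c = cols[start:start + chunk_size]
        # Row-wise dot products of the gathered rows: (chunk, d) -> (chunk,)
        vals.append((A[r] * B[c]).sum(dim=1))
    vals = torch.cat(vals) if vals else A.new_zeros(0)
    idx = torch.stack([rows, cols], dim=0)
    return torch.sparse_coo_tensor(idx, vals, (A.shape[0], B.shape[0]))
```

The result matches the unchunked version; only peak memory changes.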

Best regards

Thomas
