Sparse Batch Matrix Multiplication

Hey guys,

I have a large sparse 2D matrix, e.g. of shape [2000, 2000], and batched data of shape, say, [batch_size, 2000, 3].

I need every element of the batch to be multiplied by the sparse matrix. Both of the following work:

x = torch.stack([torch.mm(sparse_matrix, data[i,:].float()) for i in range(batch_size)])
x = torch.matmul(sparse_matrix.to_dense(), data)
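For anyone who wants to reproduce this, here is a self-contained version of the two working approaches. The sizes, the random COO matrix and the number of non-zeros are made-up assumptions, not the original poster's data. Since `data` is already float32 here, the `.float()` cast from the loop is dropped.

```python
import torch

torch.manual_seed(0)

# Assumed sizes matching the shapes described above.
n, batch_size, features = 2000, 4, 3

# Hypothetical random sparse matrix in COO format (~0.1% non-zeros).
nnz = 4000
indices = torch.randint(0, n, (2, nnz))
values = torch.randn(nnz)
sparse_matrix = torch.sparse_coo_tensor(indices, values, (n, n)).coalesce()

data = torch.randn(batch_size, n, features)

# Approach 1: one sparse-dense product per batch element, then stack.
x_loop = torch.stack([torch.mm(sparse_matrix, data[i]) for i in range(batch_size)])

# Approach 2: densify the matrix and let matmul broadcast over the batch.
x_dense = torch.matmul(sparse_matrix.to_dense(), data)

print(x_loop.shape)
print(torch.allclose(x_loop, x_dense, atol=1e-5))
```

The first approach keeps the matrix sparse but runs a Python loop; the second avoids the loop but materializes a dense [2000, 2000] matrix.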

However, what I want is this:

x = torch.matmul(sparse_matrix, data)

In other words, I want to use 1) batched matrix multiplication and 2) a parallelised computation, so that I don't need to store the result for every batch element in a list (which is too slow for my use case).

I think this functionality is not implemented yet. Is that correct, or is there a workaround to get it working?
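One possible workaround, sketched here rather than taken from any official API: because the same sparse matrix S is applied to every batch element, the batch can be folded into the columns of the dense operand. Multiplying S by the side-by-side concatenation [X_1 | X_2 | ... | X_B] gives [S X_1 | ... | S X_B], so a single `torch.sparse.mm` call replaces the loop. `batched_sparse_mm` is a hypothetical helper name, and the sizes in the usage part are made up.

```python
import torch


def batched_sparse_mm(sparse_matrix: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply one sparse [n, n] matrix with every
    slice of a dense [B, n, f] batch using a single sparse-dense product."""
    batch_size, n, features = data.shape
    # [B, n, f] -> [n, B, f] -> [n, B * f]; column b * f + k holds data[b, :, k].
    folded = data.permute(1, 0, 2).reshape(n, batch_size * features)
    # One sparse @ dense call: [n, n] x [n, B * f] -> [n, B * f].
    out = torch.sparse.mm(sparse_matrix, folded)
    # Undo the folding: [n, B * f] -> [n, B, f] -> [B, n, f].
    return out.reshape(n, batch_size, features).permute(1, 0, 2)


# Usage with a small random example (assumed sizes).
torch.manual_seed(0)
n, batch_size, features = 50, 8, 3
dense = torch.randn(n, n) * (torch.rand(n, n) < 0.05).float()
sparse_matrix = dense.to_sparse()
data = torch.randn(batch_size, n, features)
x = batched_sparse_mm(sparse_matrix, data)
print(x.shape)
```

The returned tensor is a permuted view; call `.contiguous()` on it if later code needs contiguous memory.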

Thanks so much!


Bump.

I would like to do the same thing:

torch.matmul(sparse_mat, batch)

Which raises

RuntimeError Traceback (most recent call last)
in ()
----> 1 torch.matmul(sparse_mat, batch)

RuntimeError: sparse tensors do not have strides

While

torch.matmul(sparse_mat.to_dense(), batch)

runs without issue. I mentioned this in a GitHub issue and I'm thinking about opening a new one as well, if only to find a suitable alternative that avoids iterating.
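A different way to avoid iterating, offered as a sketch under my own assumptions rather than a known library feature: build a block-diagonal sparse matrix with one copy of `sparse_mat` per batch element, flatten `batch` to [B * n, f], and do a single `torch.sparse.mm`. The helper name `block_diagonal_sparse_mm` and the example sizes are hypothetical.

```python
import torch


def block_diagonal_sparse_mm(sparse_mat: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply a sparse [n, n] matrix to a dense [B, n, f]
    batch via a block-diagonal sparse matrix of size [B * n, B * n]."""
    batch_size, n, features = batch.shape
    s = sparse_mat.coalesce()
    idx, vals = s.indices(), s.values()  # idx: [2, nnz], vals: [nnz]
    nnz = idx.shape[1]
    # Copy b of the matrix has its row and column indices shifted by b * n.
    offsets = (torch.arange(batch_size, device=idx.device) * n).view(1, batch_size, 1)
    block_idx = (idx.unsqueeze(1) + offsets).reshape(2, batch_size * nnz)
    block_vals = vals.repeat(batch_size)  # copy b occupies positions b*nnz .. (b+1)*nnz - 1
    size = batch_size * n
    block = torch.sparse_coo_tensor(block_idx, block_vals, (size, size))
    # Flattening stacks the batch along rows: row b * n + i is batch[b, i].
    out = torch.sparse.mm(block, batch.reshape(size, features))
    return out.reshape(batch_size, n, features)


# Usage with a small random example (assumed sizes).
torch.manual_seed(0)
n, batch_size, features = 40, 5, 3
dense = torch.randn(n, n) * (torch.rand(n, n) < 0.1).float()
sparse_mat = dense.to_sparse()
batch = torch.randn(batch_size, n, features)
x = block_diagonal_sparse_mm(sparse_mat, batch)
print(x.shape)
```

This stores B copies of the non-zeros, so for a single shared matrix it uses more memory than strictly necessary. Its main appeal is that the same construction works when each batch element has its own sparse matrix: shift and concatenate each matrix's indices instead of repeating one.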
