Hey guys,
I have a large sparse matrix (2D), e.g. of shape [2000, 2000],
and batched data, let’s say of shape [batch_size, 2000, 3].
I need every element of the batch to be multiplied (from the left) by the same sparse matrix. Both of the following work:
x = torch.stack([torch.mm(sparse_matrix, data[i,:].float()) for i in range(batch_size)])
x = torch.matmul(sparse_matrix.to_dense(), data)
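In case anyone wants to reproduce this, here is a minimal self-contained sketch of the two working versions. The sizes, the roughly 1% density and building the sparse matrix with `to_sparse()` are my own assumptions for illustration, not my real data.

```python
import torch

# Assumed sizes for illustration: [2000, 2000] sparse matrix, batch of [2000, 3] blocks.
n, batch_size, k = 2000, 4, 3

# Random sparse COO matrix with roughly 1% non-zeros (illustrative only).
dense = torch.rand(n, n)
dense[dense > 0.01] = 0.0
sparse_matrix = dense.to_sparse()

data = torch.rand(batch_size, n, k)

# Version 1: loop over the batch with torch.mm (sparse x dense), then stack.
x_loop = torch.stack([torch.mm(sparse_matrix, data[i].float()) for i in range(batch_size)])

# Version 2: densify the matrix and let torch.matmul broadcast over the batch dim.
x_dense = torch.matmul(sparse_matrix.to_dense(), data)

print(x_loop.shape)  # torch.Size([4, 2000, 3])
print(torch.allclose(x_loop, x_dense, atol=1e-5))
```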
However, what I want is this:
x = torch.matmul(sparse_matrix, data)
In other words, I want 1) batched matrix multiplication and 2) the computation to run in parallel over the batch, while keeping the matrix sparse, so that I don’t need to collect the result for every batch element in a list (that is too slow for my use case).
I think this functionality is not implemented yet. Is that correct, or is there a workaround to get it working?
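One idea I’m not sure about yet (my own sketch, not an official batched sparse API): since the same sparse matrix multiplies every batch element from the left, the batch can be folded into the columns. Permute `data` to [2000, batch_size, 3], reshape to [2000, batch_size*3], do a single `torch.sparse.mm`, then undo the reshape. The helper name `sparse_batch_mm` is hypothetical, and the sizes below are assumed for illustration.

```python
import torch


def sparse_batch_mm(sparse_matrix: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute sparse_matrix @ data[b] for all b at once.

    sparse_matrix: sparse COO tensor of shape [m, n]
    data:          dense tensor of shape [batch_size, n, k]
    returns:       dense tensor of shape [batch_size, m, k]
    """
    batch_size, n, k = data.shape
    m = sparse_matrix.shape[0]
    # [B, n, k] -> [n, B, k] -> [n, B*k]: place each batch element's columns side by side.
    stacked = data.permute(1, 0, 2).reshape(n, batch_size * k)
    # A single sparse x dense product instead of batch_size separate ones.
    out = torch.sparse.mm(sparse_matrix, stacked)
    # [m, B*k] -> [m, B, k] -> [B, m, k]: undo the column stacking.
    return out.reshape(m, batch_size, k).permute(1, 0, 2)


# Usage with assumed sizes and a random ~1% dense sparse matrix.
n, batch_size, k = 2000, 4, 3
dense = torch.rand(n, n)
dense[dense > 0.01] = 0.0
sparse_matrix = dense.to_sparse()
data = torch.rand(batch_size, n, k)

x = sparse_batch_mm(sparse_matrix, data)
print(x.shape)  # torch.Size([4, 2000, 3])
```

The result is a non-contiguous view after the final `permute`; calling `.contiguous()` on it may help if later ops need contiguous memory.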
Thanks so much!