Batch sparse matrix multiplication

Hi guys.

I’m studying the finite element method (FEM) in neural networks with PyTorch.

My question is whether a single function exists that combines ‘batch’ + ‘sparse’ + ‘matrix multiplication’.
There are several related methods:

  • torch.bmm
  • torch.sparse.mm

However, I cannot find a single function that performs ‘batch’ + ‘sparse’ matrix multiplication.
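A straightforward workaround is to call `torch.sparse.mm` once per sample and stack the results. This is a minimal sketch, assuming each sample's matrix is stored as its own 2-D sparse COO tensor; the helper name `batched_sparse_mv_loop` and the random test data are hypothetical.

```python
import torch

def batched_sparse_mv_loop(A_list, x):
    # Hypothetical helper: A_list holds `batch` 2-D sparse COO tensors of shape (n, n),
    # x is a dense tensor of shape (batch, n, 1).
    outs = [torch.sparse.mm(A_b, x_b) for A_b, x_b in zip(A_list, x)]
    return torch.stack(outs)  # dense, shape (batch, n, 1)

# Usage with the shapes from the question (random data, for illustration only)
torch.manual_seed(0)
batch, n = 4, 126
dense = torch.randn(batch, n, n)
dense[dense.abs() < 1.5] = 0.0  # zero out most entries to mimic a sparse matrix
A_list = [dense[b].to_sparse() for b in range(batch)]
x = torch.randn(batch, n, 1)
y = batched_sparse_mv_loop(A_list, x)
print(y.shape)  # torch.Size([4, 126, 1])
```

This works, but the Python loop over the batch can become slow when the batch is large.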

Here is my data:

  • batch sparse matrix size: (batch, 126, 126)
  • batch dense vector size: (batch, 126, 1)
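One common way to avoid the loop is to place every (126, 126) block on the diagonal of one large 2-D sparse matrix of size (batch·126, batch·126) and do a single `torch.sparse.mm`. The sketch below assumes the batch is held as one 3-D sparse COO tensor (sparse in all three dimensions); the helper name `batched_sparse_mv_blockdiag` is hypothetical.

```python
import torch

def batched_sparse_mv_blockdiag(A, x):
    # Hypothetical helper.
    # A: 3-D sparse COO tensor, shape (batch, n, n); x: dense tensor, shape (batch, n, 1).
    A = A.coalesce()                   # indices() requires a coalesced tensor
    batch, n, _ = A.shape
    batch_idx, row_idx, col_idx = A.indices()  # each of shape (nnz,)
    # Shift block b to rows/cols [b*n, (b+1)*n) of a block-diagonal matrix.
    rows = batch_idx * n + row_idx
    cols = batch_idx * n + col_idx
    big = torch.sparse_coo_tensor(
        torch.stack([rows, cols]), A.values(), (batch * n, batch * n)
    )
    # x[b, i, 0] maps to row b*n + i after the reshape, matching the block layout.
    y = torch.sparse.mm(big, x.reshape(batch * n, 1))
    return y.reshape(batch, n, 1)

# Usage with the shapes from the question (random data, for illustration only)
torch.manual_seed(0)
batch, n = 4, 126
dense = torch.randn(batch, n, n)
dense[dense.abs() < 1.5] = 0.0  # zero out most entries to mimic sparsity
A = dense.to_sparse()           # 3-D sparse COO tensor
x = torch.randn(batch, n, 1)
y = batched_sparse_mv_blockdiag(A, x)
print(y.shape)  # torch.Size([4, 126, 1])
```

Only the nonzero entries are stored, so memory grows with the number of nonzeros rather than with batch × n × n.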

I expect the matrix to grow to 120,000 by 120,000, so memory usage will have to be dealt with.
Is there any way to perform ‘batch’ + ‘sparse’ matrix multiplication without running into memory issues?
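To put the 120,000 by 120,000 size in perspective, here is a rough back-of-the-envelope estimate comparing dense float32 storage with COO storage. It assumes float32 values and int64 indices (PyTorch COO tensors store int64 indices); the figure of 27 nonzeros per row is a hypothetical value chosen only for illustration, since the real count depends on the mesh and element type.

```python
def dense_bytes(n, bytes_per_value=4):
    # A dense n x n float32 matrix stores every entry.
    return n * n * bytes_per_value

def coo_bytes(nnz, bytes_per_value=4, bytes_per_index=8, ndim=2):
    # COO stores, per nonzero, one value plus one int64 index per dimension.
    return nnz * (bytes_per_value + ndim * bytes_per_index)

n = 120_000
nnz_per_row = 27  # hypothetical sparsity, for illustration only
print(dense_bytes(n) / 1e9)              # 57.6   (GB)
print(coo_bytes(n * nnz_per_row) / 1e9)  # 0.0648 (GB)
```

Under these assumptions the dense matrix needs about 57.6 GB while the sparse one needs well under 0.1 GB, which is why the matrix has to stay in a sparse format throughout.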

Thank you!