Hi guys.
I’m studying the finite element method (FEM) in neural networks with PyTorch.
My question is whether a single function exists that combines ‘batch’, ‘sparse’, and ‘matrix multiplication’.
There are several related functions:
- torch.bmm
- torch.sparse.mm
However, I cannot find a single function that does batched sparse matrix multiplication.
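A minimal sketch of the most direct way to combine the two functions listed above, assuming each sample's matrix is stored as its own 2-D sparse COO tensor: loop over the batch and call `torch.sparse.mm` once per sample. `loop_sparse_bmm` is a hypothetical helper name. It works, but it launches one multiplication per sample from Python, so it may be slow for large batches.

```python
import torch

def loop_sparse_bmm(A_list, x):
    """Hypothetical helper: batched sparse @ dense via a Python loop.

    A_list: sequence of B 2-D sparse COO tensors, each of shape (n, m)
    x:      dense tensor of shape (B, m, k)
    returns dense tensor of shape (B, n, k)
    """
    # torch.sparse.mm handles one 2-D sparse @ 2-D dense product at a time
    return torch.stack([torch.sparse.mm(A_i, x_i) for A_i, x_i in zip(A_list, x)])
```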
Here is my data:
- batch sparse matrix size: (batch, 126, 126)
- batch dense vector size: (batch, 126, 1)
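One workaround for these shapes (a swapped-in technique, not a built-in batched sparse function) is to pack the whole batch into one block-diagonal sparse matrix of shape (batch*126, batch*126), stack the vectors into (batch*126, 1), and call `torch.sparse.mm` a single time. The sketch below assumes `A` is a 3-D sparse COO tensor of shape (batch, 126, 126), e.g. produced by `dense.to_sparse()`; `batched_sparse_mm` is a hypothetical name. Only the nonzeros are stored, so memory grows with the number of nonzeros rather than with (batch*126)².

```python
import torch

def batched_sparse_mm(A: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply a batch of sparse matrices by dense inputs.

    A: 3-D sparse COO tensor of shape (B, n, m)
    x: dense tensor of shape (B, m, k)
    returns dense tensor of shape (B, n, k)
    """
    A = A.coalesce()  # indices()/values() require a coalesced tensor
    B, n, m = A.shape
    k = x.shape[-1]
    b, r, c = A.indices()  # batch, row, and column index of every nonzero
    # Shift block b to rows [b*n, (b+1)*n) and columns [b*m, (b+1)*m)
    rows = b * n + r
    cols = b * m + c
    A_block = torch.sparse_coo_tensor(
        torch.stack([rows, cols]), A.values(), (B * n, B * m)
    )
    # Stacking x batch-major matches the column offsets b*m + c
    y = torch.sparse.mm(A_block, x.reshape(B * m, k))
    return y.reshape(B, n, k)
```

Usage would look like `y = batched_sparse_mm(A, x)` with `A` of shape (batch, 126, 126) and `x` of shape (batch, 126, 1), giving `y` of shape (batch, 126, 1).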
I expect the matrix to grow to 120,000 by 120,000, so memory will become a problem that has to be dealt with.
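A rough back-of-the-envelope estimate of why this matters, assuming float32 values and a 2-D COO layout with int64 indices (two indices plus one value per nonzero). The figure of 30 nonzeros per row is a hypothetical placeholder; the real count depends on the element type and mesh.

```python
def dense_bytes(n: int, itemsize: int = 4) -> int:
    """Memory of a dense n x n matrix (float32 by default)."""
    return n * n * itemsize

def coo_bytes(nnz: int, ndim: int = 2, itemsize: int = 4, index_size: int = 8) -> int:
    """Memory of a COO tensor: ndim int64 indices plus one value per nonzero."""
    return nnz * (ndim * index_size + itemsize)

n = 120_000
nnz = n * 30  # hypothetical: ~30 nonzeros per row
print(f"dense: {dense_bytes(n) / 1e9:.1f} GB")  # dense: 57.6 GB
print(f"COO:   {coo_bytes(nnz) / 1e6:.1f} MB")  # COO:   72.0 MB
```

Under these assumptions, the dense matrix alone needs tens of gigabytes, while the sparse one fits easily, which is why keeping the matrix sparse throughout the multiplication matters.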
Is there any approach to batched sparse matrix multiplication that avoids these memory issues?
Thank you!