Multiplying two sparse matrices when one of them is diagonal

Hi. I have been looking for a sparse-sparse matrix multiplication, mm(S, S)->S, for a long time, and it seems there is no implementation of it. However, my problem can be reduced to multiplying a sparse matrix S by a sparse diagonal matrix D, where only D requires a gradient. One workable solution is to define custom forward and backward methods, using scipy for the multiplication and computing the gradient manually (since D is just a diagonal matrix, that is quite straightforward). The drawback is that this does not let me use the GPU. Are there other ways to make the multiplication more efficient? I would like to use some math trick here, but I cannot find one.

BTW, my sparse matrices are all complex, and their dimensions are easily too large for them to be converted to dense matrices on my local machines… Thanks, everyone! :smiley:
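
One observation that may help here: right-multiplying by a diagonal matrix only rescales the columns of S, so the product keeps the sparsity pattern of S, and the stored value at position (i, j) is simply multiplied by the j-th diagonal entry of D. Below is a minimal sketch of that idea, assuming S is held in COO form as an index tensor plus a value tensor and D is stored as a 1-D tensor `d` of its diagonal. The function name `sparse_times_diag` and the sample numbers are made up for illustration. Since it uses only indexing and elementwise multiplication, autograd should be able to produce the gradient with respect to `d`, and it should run on the GPU if the tensors are placed there.

```python
import torch

def sparse_times_diag(indices, values, d):
    """Values of S @ diag(d), with S given in COO form (hypothetical helper).

    indices: (2, nnz) long tensor of (row, col) positions
    values:  (nnz,) tensor with the nonzero entries of S
    d:       (n,) tensor holding the diagonal of D
    The sparsity pattern does not change, so only new values are returned.
    """
    cols = indices[1]
    return values * d[cols]  # column j of S is scaled by d[j]

# Small complex example with made-up numbers.
indices = torch.tensor([[0, 0, 1, 2],
                        [0, 2, 1, 2]])
values = torch.tensor([1 + 2j, 3 - 1j, 0.5j, 2 + 0j], dtype=torch.complex64)
d = torch.tensor([1 + 1j, 2 + 0j, -1j], dtype=torch.complex64,
                 requires_grad=True)

new_values = sparse_times_diag(indices, values, d)

# Same indices, rescaled values: this is S @ D as a sparse tensor.
result = torch.sparse_coo_tensor(indices, new_values, (3, 3))

# Any real-valued loss built from the new values backpropagates into d.
loss = new_values.abs().pow(2).sum()
loss.backward()
```

For the product in the other order, D @ S, the rows are rescaled instead, so the same sketch applies with `indices[0]` in place of `indices[1]`.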