torch.sparse.mm autograd support

My NN model contains a torch.sparse.mm() operation, and when I try to train it I get the following error:
RuntimeError: Expected object of backend CPU but got backend SparseCPU for argument #2 'mat2'

I’ve seen various posts related to this, but I’m unclear on what the workaround or solution is, other than converting the sparse matrix to a dense one and performing standard dense operations. The goal is to use torch.sparse.mm() for its performance.
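For reference, here is a minimal sketch of the kind of setup in question. It assumes a reasonably recent PyTorch release in which torch.sparse.mm supports backward when the sparse COO tensor is the first argument. The matrices S, W and X and their values are hypothetical. The error message complains that argument #2 ('mat2') is sparse, which suggests a dense @ sparse product. The second half of the sketch shows one possible rewrite that keeps the sparse operand first, using the identity X @ S = (S^T @ X^T)^T.

```python
import torch

torch.manual_seed(0)

# Hypothetical 4x4 sparse COO matrix S (e.g. a fixed adjacency matrix).
indices = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
S = torch.sparse_coo_tensor(indices, values, (4, 4))

# Hypothetical dense, trainable weight matrix.
W = torch.randn(4, 3, requires_grad=True)

# Case 1: sparse @ dense. The sparse operand is the first argument,
# and gradients flow back to the dense W.
out = torch.sparse.mm(S, W)
out.sum().backward()

# Case 2: dense @ sparse. Instead of passing S as the second argument,
# transpose both operands so S comes first, then transpose the result:
# X @ S == (S^T @ X^T)^T
X = torch.randn(2, 4, requires_grad=True)
out2 = torch.sparse.mm(S.t(), X.t()).t()
out2.sum().backward()
```

This avoids densifying S; only the (already dense) result and the dense operand are transposed, and .t() on a 2-D sparse tensor returns another sparse tensor.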

Thanks!