Maybe the question can also be formulated as: does PyTorch support a dense GEMM with a sparse output mask, so that the output is sparse and is only computed where the mask is non-zero?
Is torch.sparse.sampled_addmm (see the PyTorch 2.0 documentation) doing this?
Hmm, it seems so, but it does not support batches: see "Implementation of torch.sparse.sampled_baddmm", Issue #105319 in pytorch/pytorch on GitHub.
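To make the intended semantics concrete, here is a minimal sketch of a batched "sampled" matmul written with ordinary indexing ops. The name sampled_bmm is hypothetical, and so is the choice of a dense boolean mask and a sparse COO result. It is not how torch.sparse.sampled_addmm is implemented; it only shows what "compute the product only where the mask is non-zero" would mean for batched inputs.

```python
import torch


def sampled_bmm(mask, a, b):
    """Hypothetical helper: (a @ b) evaluated only where mask is non-zero.

    mask: (B, m, n) dense boolean tensor
    a:    (B, m, k) dense tensor
    b:    (B, k, n) dense tensor
    Returns a sparse COO tensor of shape (B, m, n).
    """
    # Coordinates (batch, row, col) of the entries we actually want.
    bi, ri, ci = mask.nonzero(as_tuple=True)
    rows = a[bi, ri]                  # (nnz, k): one row of a per entry
    cols = b.transpose(1, 2)[bi, ci]  # (nnz, k): one column of b per entry
    values = (rows * cols).sum(dim=-1)  # one dot product per sampled entry
    indices = torch.stack([bi, ri, ci])
    return torch.sparse_coo_tensor(indices, values, mask.shape)


torch.manual_seed(0)
a = torch.randn(2, 4, 3)
b = torch.randn(2, 3, 5)
mask = torch.rand(2, 4, 5) > 0.5

out = sampled_bmm(mask, a, b)
expected = (a @ b) * mask  # dense reference: full product, then masked
```

Note that this gathers an (nnz, k) intermediate for each operand, so it trades memory for simplicity and is unlikely to match a fused kernel in speed. It only avoids computing the masked-out entries.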